Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
Enable device-specific compilation feature MACRO using Option DEVICE …
Browse files Browse the repository at this point in the history
…with target_compile_definitions (GPU_ARCH, MMA_ENGINE)

Signed-off-by: Qun Gao <qun.gao@intel.com>
  • Loading branch information
qgao007 committed Jun 25, 2024
1 parent 81f132d commit 6327396
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 8 deletions.
2 changes: 2 additions & 0 deletions tests/integration/gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
include_directories(${CMAKE_SOURCE_DIR}/tests/integration/gemm)
if (DEVICE STREQUAL "mtl")
add_subdirectory(int4_dequantization)
add_subdirectory(int4_dequantization_bias)
else()
add_subdirectory(bf16)
add_subdirectory(stream_k)
Expand Down
7 changes: 6 additions & 1 deletion tests/integration/gemm/int4_dequantization/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,9 @@ string(REPLACE " " "_" ProjectId ${ProjectId})
string(PREPEND ProjectId "gemm_")

FILE(GLOB src main.cpp)
add_integration_test(${ProjectId} ${src})
add_integration_test(${ProjectId} ${src})
if (DEVICE STREQUAL "mtl")
target_compile_definitions(${ProjectId} PRIVATE MMA_ENGINE=fpu GPU_ARCH=XeLpg)
else()
target_compile_definitions(${ProjectId} PRIVATE MMA_ENGINE=xmx GPU_ARCH=XeHpg)
endif()
8 changes: 4 additions & 4 deletions tests/integration/gemm/int4_dequantization/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,16 +234,16 @@ void dequantize_gemm_run(uint32_t iter) {
data_type_zero_pt,
gpu::xetla::group::quant_mode::S4_ASYM,
dequant_s,
mma_engine::xmx,
gpu_arch::XeHpg>;
mma_engine::MMA_ENGINE,
gpu_arch::GPU_ARCH>;
using gemm_t = xetla::group::
gemm_t<compute_policy, tile_shape, mem_desc_a_t, mem_desc_b_t>;

using epilogue_t = xetla::group::epilogue_t<
xetla::group::epilogue_policy_default<gpu_arch::XeHpg>,
xetla::group::epilogue_policy_default<gpu_arch::GPU_ARCH>,
tile_shape,
mem_desc_c_t>;
using group_swizzle = xetla::kernel::group_swizzle_default<gpu_arch::XeHpg>;
using group_swizzle = xetla::kernel::group_swizzle_default<gpu_arch::GPU_ARCH>;
using gemm_op_t = xetla::kernel::gemm_universal_t<
gpu::xetla::kernel::dispatch_policy_int4_dequantize_kslicing<
group_swizzle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ set(ProjectIdXe ${ProjectId})
string(PREPEND ProjectIdClient "gemm_client_")
string(PREPEND ProjectIdXe "gemm_xe_")

FILE(GLOB src_client main_client.cpp)
if (DEVICE STREQUAL "mtl")
FILE(GLOB src_client main_client.cpp)
add_integration_test(${ProjectIdClient} ${src_client})
else()
FILE(GLOB src_xe main_xe.cpp)
add_integration_test(${ProjectIdXe} ${src_xe})
endif()
FILE(GLOB src_xe main_xe.cpp)
add_integration_test(${ProjectIdXe} ${src_xe})

0 comments on commit 6327396

Please sign in to comment.