Add CDNA3 MFMA BF16 intrinsics. (#18892)
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
bjacob authored Oct 25, 2024
1 parent 3b751a4 commit 1fc6e5b
Showing 6 changed files with 155 additions and 4 deletions.
4 changes: 2 additions & 2 deletions compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -15,7 +15,7 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
@@ -26,7 +26,7 @@
// GFX941-SAME: features = "+sramecc,-xnack"

// GFX940: target = #iree_gpu.target<arch = "gfx940",
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],

// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>]
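The two intrinsics added to the gfx940/gfx942 mma lists above, MFMA_F32_16x16x16_BF16 and MFMA_F32_32x32x8_BF16, follow the usual MxNxK naming scheme: a matrix-core tile with bf16 A/B operands accumulated into f32 (the gfx1100 list is unchanged since it only advertises WMMA). As a rough reference for what one 16x16x16 tile computes, a scalar C++ sketch that ignores hardware packing and rounding details; the bf16 conversion helper is a stand-in, not IREE code:

#include <cstdint>
#include <cstring>

// Treat a bf16 value as the high 16 bits of an IEEE binary32 float.
static float bf16_to_float(uint16_t v) {
  uint32_t bits = static_cast<uint32_t>(v) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

// Reference semantics of one MFMA_F32_16x16x16_BF16 tile:
// C (16x16, f32) += A (16x16, bf16) * B (16x16, bf16).
void mfma_f32_16x16x16_bf16_reference(const uint16_t A[16][16],
                                      const uint16_t B[16][16],
                                      float C[16][16]) {
  for (int m = 0; m < 16; ++m)
    for (int n = 0; n < 16; ++n)
      for (int k = 0; k < 16; ++k)
        C[m][n] += bf16_to_float(A[m][k]) * bf16_to_float(B[k][n]);
}

MFMA_F32_32x32x8_BF16 is the same idea with a 32x32 accumulator tile and K = 8.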
@@ -1130,3 +1130,63 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x32_F8E4M3FNUZ() {
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x32_F8E4M3FNUZ, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]

// -----

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2]>
#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2]>
#encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [bf16, bf16, f32], user_indexing_maps = [#map, #map1, #map2]>
#pipeline_layout_4 = #hal.pipeline.layout<constants = 4, bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16() {
%c0 = arith.constant 0 : index
%B = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(0) : index
%M = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(1) : index
%N = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(2) : index
%K = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(3) : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(0) alignment(64) offset(%c0)
: !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_lhs>>{%B, %M, %K}
%1 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(1) alignment(64) offset(%c0)
: !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_rhs>>{%B, %K, %N}
%2 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(2) alignment(64) offset(%c0)
: !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [%B, %M, %K], strides = [1, 1, 1]
: !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_lhs>>{%B, %M, %K}
-> tensor<?x?x?xbf16, #encoding_lhs>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [%B, %K, %N], strides = [1, 1, 1]
: !flow.dispatch.tensor<readonly:tensor<?x?x?xbf16, #encoding_rhs>>{%B, %K, %N}
-> tensor<?x?x?xbf16, #encoding_rhs>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1]
: !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
-> tensor<?x?x?xf32, #encoding_result>
%6 = linalg.batch_matmul
ins(%3, %4 : tensor<?x?x?xbf16, #encoding_lhs>,
tensor<?x?x?xbf16, #encoding_rhs>)
outs(%5 : tensor<?x?x?xf32, #encoding_result>)
-> tensor<?x?x?xf32, #encoding_result>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1]
: tensor<?x?x?xf32, #encoding_result>
-> !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x2x4xbf16>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x2x4xbf16>
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x8x4x2x4x16x4xf32>
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_BF16, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
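A quick consistency check on the shapes in the CHECK lines above: assuming the usual reading of the unroll factors (an interpretation, not something stated in this diff), the data-tiled tile is 128x128x32 in M x N x K, which matches the element counts of the inner dimensions of the loaded tensors. A minimal C++ sketch of that arithmetic:

// Sketch: tile sizes implied by the data_tiled_mma_layout above, under the
// assumed reading of the unroll factors.
constexpr int intrinsicM = 16, intrinsicN = 16, intrinsicK = 16; // MFMA_F32_16x16x16_BF16
constexpr int unrollM = 8, unrollN = 2, unrollNToSubgroups = 4, unrollK = 2;
constexpr int tileM = intrinsicM * unrollM;                      // 128
constexpr int tileN = intrinsicN * unrollN * unrollNToSubgroups; // 128
constexpr int tileK = intrinsicK * unrollK;                      // 32
// The inner dims of the CHECK'd tensors account for exactly these tiles:
static_assert(8 * 4 * 16 * 2 * 4 == tileM * tileK, "LHS 8x4x16x2x4 = 128x32");
static_assert(4 * 2 * 4 * 16 * 2 * 4 == tileN * tileK, "RHS 4x2x4x16x2x4 = 128x32");
static_assert(8 * 4 * 2 * 4 * 16 * 4 == tileM * tileN, "ACC 8x4x2x4x16x4 = 128x128");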
60 changes: 58 additions & 2 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -214,6 +214,7 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
Type f8E5M2FNUZ = Float8E5M2FNUZType::get(context);
Type f16 = Float16Type::get(context);
Type bf16 = BFloat16Type::get(context);
Type f32 = Float32Type::get(context);

Type i8 = IntegerType::get(context, 8);
@@ -229,6 +230,12 @@
case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
return OpaqueMmaLayout{16, 16, 16, bf16, bf16, f32};
}
case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
return OpaqueMmaLayout{32, 32, 8, bf16, bf16, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
@@ -336,6 +343,45 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
// #layout_b = #iree_vector_ext.layout<#inner, #outer>
// #layout_c = #iree_vector_ext.layout<#inner, #outer>

auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4});
auto aMLayout = outer;
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
auto cMLayout = inner;
auto cNLayout = outer;
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]>
// #inner1 = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 4]>
// #inner2 = #iree_vector_ext.per_dim_layout<[VECTORY, LANEY, VECTORX],
// [4, 2, 4]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner1>
// #layout_b = #iree_vector_ext.layout<#inner1, #outer>
// #layout_c = #iree_vector_ext.layout<#inner2, #outer>

auto outer = PerDimLayoutAttr::get(context, {laneX}, {32});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {2, 4});
auto aMLayout = outer;
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
auto cMLayout =
PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {4, 2, 4});
auto cNLayout = outer;
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
@@ -462,14 +508,16 @@ MMAAttr::getABCVectorTypes() const {
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
auto cType = VectorType::get({16}, getCType());
@@ -519,8 +567,10 @@ int64_t MMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
@@ -540,8 +590,10 @@ static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
switch (intrinsic) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
@@ -584,6 +636,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
}
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16},
@@ -597,6 +650,7 @@
}
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
switch (fragment) {
case MMAFragment::Lhs:
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32},
@@ -704,8 +758,10 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
}
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
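A side note on the getABCVectorTypes() arms the BF16 cases join above: the per-lane vector lengths are what you get by spreading each MxK, KxN, and MxN fragment evenly over the 64 lanes of one subgroup. A small standalone C++ sketch of that arithmetic, as an illustration rather than IREE code:

// Per-lane element counts, assuming each fragment is distributed evenly over
// a 64-lane subgroup.
constexpr int kSubgroupSize = 64;
constexpr int perLane(int rows, int cols) { return rows * cols / kSubgroupSize; }

// MFMA_F32_16x16x16_BF16 -> vector<4xbf16>, vector<4xbf16>, vector<4xf32>
static_assert(perLane(16, 16) == 4, "A (16x16), B (16x16), and C (16x16)");
// MFMA_F32_32x32x8_BF16 -> vector<4xbf16>, vector<4xbf16>, vector<16xf32>
static_assert(perLane(32, 8) == 4, "A (32x8) and B (8x32)");
static_assert(perLane(32, 32) == 16, "C (32x32)");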
@@ -121,6 +121,8 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0x0900>;
def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 0x0910>;
def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x0911>;
def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x0920>;
def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>;
def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>;
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x0940>;
def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x0980>;
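The new case values slot BF16 (0x0920, 0x0921) between the F16 (0x0910, 0x0911) and F8 (0x0930, 0x0940) groups, keeping the apparent grouping by operand type, with the low bits distinguishing the 16x16 and 32x32 tile shapes. For the cases visible in this hunk, TableGen presumably generates a C++ enum along these lines; a sketch only, the exact generated form may differ:

#include <cstdint>

// Values copied from the I32EnumAttrCase defs above; cases outside this hunk omitted.
enum class MMAIntrinsic : uint32_t {
  MFMA_F32_16x16x4_F32 = 0x0900,
  MFMA_F32_16x16x16_F16 = 0x0910,
  MFMA_F32_32x32x8_F16 = 0x0911,
  MFMA_F32_16x16x16_BF16 = 0x0920, // added in this commit
  MFMA_F32_32x32x8_BF16 = 0x0921,  // added in this commit
  MFMA_F32_16x16x32_F8E5M2FNUZ = 0x0930,
  MFMA_F32_16x16x32_F8E4M3FNUZ = 0x0940,
  MFMA_I32_16x16x32_I8 = 0x0980,
};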
@@ -143,6 +145,8 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
MFMA_F32_16x16x4_F32,
MFMA_F32_16x16x16_F16,
MFMA_F32_32x32x8_F16,
MFMA_F32_16x16x16_BF16,
MFMA_F32_32x32x8_BF16,
MFMA_F32_16x16x32_F8E4M3FNUZ,
MFMA_F32_16x16x32_F8E5M2FNUZ,
MFMA_I32_16x16x32_I8,
@@ -136,6 +136,8 @@ const WgpDetails *getCDNA3WgpDetails() {
MMAIntrinsic::MFMA_F32_16x16x4_F32,
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
MMAIntrinsic::MFMA_F32_16x16x16_BF16,
MMAIntrinsic::MFMA_F32_32x32x8_BF16,
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ,
MMAIntrinsic::MFMA_I32_16x16x32_I8,
29 changes: 29 additions & 0 deletions tests/e2e/matmul/CMakeLists.txt
@@ -1570,6 +1570,35 @@ iree_generated_e2e_runner_test(
"requires-gpu-cdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_bf16_cdna3_mfma_data_tiled
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=bf16"
"--acc_type=f32"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"rocm"
DRIVERS
"hip"
COMPILER_FLAGS
${IREE_HIP_TEST_COMPILER_FLAGS}
"--iree-opt-data-tiling"
"--iree-global-opt-experimental-rocm-data-tiling"
"--iree-global-opt-enable-early-materialization=true"
LABELS
"noasan"
"nomsan"
"notsan"
"noubsan"
"requires-gpu-cdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_i8_cdna3_mfma_data_tiled