refine unrolling factors and tile dimensions
bjacob committed Sep 19, 2024
1 parent a21a6dc commit c8dc5d1
Showing 7 changed files with 68 additions and 46 deletions.
@@ -55,27 +55,30 @@ chooseDataTiledMMAAttr(TypeRange elementTypes, IREE::GPU::TargetAttr target) {
   Type lhs = elementTypes[0];
   Type rhs = elementTypes[1];
   Type out = elementTypes[2];
-  auto match = [=](MMAIntrinsic intrinsic, int unrollM, int unrollN,
+  auto match = [=](MMAIntrinsic intrinsic, int unrollM, int unrollMToThreads,
+                   int unrollN, int unrollNToThreads,
                    int unrollK) -> std::optional<DataTiledMMAAttr> {
     if (!hasIntrinsic(target, intrinsic)) {
       return std::nullopt;
     }
     auto candidate = DataTiledMMAAttr::get(
-        ctx, MMAIntrinsicAttr::get(ctx, intrinsic), unrollM, unrollN, unrollK);
+        ctx, MMAIntrinsicAttr::get(ctx, intrinsic), /*unroll_m=*/unrollM,
+        /*unroll_m_to_threads=*/unrollMToThreads, /*unroll_n=*/unrollN,
+        /*unroll_n_to_threads=*/unrollNToThreads, /*unroll_k=*/unrollK);
     auto [candidateLhs, candidateRhs, candidateOut] =
         candidate.getABCElementTypes();
     if (candidateLhs != lhs || candidateRhs != rhs || candidateOut != out) {
       return std::nullopt;
     }
     return candidate;
   };
-  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x4_F32, 8, 8, 4)) {
+  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x4_F32, 8, 1, 2, 4, 4)) {
     return m;
   }
-  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x16_F16, 8, 8, 2)) {
+  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x16_F16, 8, 1, 2, 4, 2)) {
     return m;
   }
-  if (auto m = match(MMAIntrinsic::MFMA_I32_16x16x32_I8, 8, 8, 2)) {
+  if (auto m = match(MMAIntrinsic::MFMA_I32_16x16x32_I8, 8, 1, 2, 4, 2)) {
     return m;
   }
   // Fallback - no architecture-optimized tile size for this case.
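Note on the new factor choices: the overall data-tiled shapes are unchanged; part of the N unrolling simply moves from a single thread onto more threads. A minimal standalone sketch of the arithmetic (not part of this commit; intrinsic shapes read off the MFMA intrinsic names):

#include <cassert>

int main() {
  // MFMA_F32_16x16x4_F32 has intrinsic shape M=16, N=16, K=4.
  // New factors: unroll_m=8, unroll_m_to_threads=1,
  //              unroll_n=2, unroll_n_to_threads=4, unroll_k=4.
  assert(16 * 8 * 1 == 128); // M tile, same as the old unroll_m = 8
  assert(16 * 2 * 4 == 128); // N tile, same as the old unroll_n = 8
  assert(4 * 4 == 16);       // K tile, same as the old unroll_k = 4
  return 0;
}

The F16 and I8 cases work out the same way: 8 * 1 along M and 2 * 4 along N, so each tile keeps its former 128x128 footprint.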
@@ -195,6 +195,10 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
       unroll(swizzle, 0, mma.getUnrollM(), /*cross_thread=*/false,
              /*cross_intrinsic=*/true);
     }
+    if (mma.getUnrollMToThreads() > 1) {
+      unroll(swizzle, 0, mma.getUnrollMToThreads(), /*cross_thread=*/true,
+             /*cross_intrinsic=*/true);
+    }
     break;
   case IREE::GPU::MMAFragment::Rhs:
     // B-matrix (RHS). Since the pack ops already took care of transposing B,
@@ -212,6 +216,10 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
       unroll(swizzle, 0, mma.getUnrollN(), /*cross_thread=*/false,
              /*cross_intrinsic=*/true);
     }
+    if (mma.getUnrollNToThreads() > 1) {
+      unroll(swizzle, 0, mma.getUnrollNToThreads(), /*cross_thread=*/true,
+             /*cross_intrinsic=*/true);
+    }
     break;
   case IREE::GPU::MMAFragment::Acc:
     // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
@@ -220,10 +228,18 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
       unroll(swizzle, 1, mma.getUnrollN(), /*cross_thread=*/false,
              /*cross_intrinsic=*/true);
     }
+    if (mma.getUnrollNToThreads() > 1) {
+      unroll(swizzle, 1, mma.getUnrollNToThreads(), /*cross_thread=*/true,
+             /*cross_intrinsic=*/true);
+    }
     if (mma.getUnrollM() > 1) {
       unroll(swizzle, 0, mma.getUnrollM(), /*cross_thread=*/false,
              /*cross_intrinsic=*/true);
     }
+    if (mma.getUnrollMToThreads() > 1) {
+      unroll(swizzle, 0, mma.getUnrollMToThreads(), /*cross_thread=*/true,
+             /*cross_intrinsic=*/true);
+    }
     break;
   }
   return swizzle;
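Judging by the updated expand_shape expectations in the tests below, each unroll() call inserts a new outermost dimension at the given source index, so the cross-thread factor, applied last, ends up outermost. For the f32 RHS, for example, the N extent 128 expands as

  128 = 4 (unroll_n_to_threads) * 2 (unroll_n) * 16 (intrinsic N)

and the K extent 16 as 4 (unroll_k) * 4 (intrinsic K), giving the tensor<5x16x4x2x16x4x4xf32> shape that the tests check for.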
@@ -63,11 +63,11 @@ func.func @set_encoding_RHS_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 // CHECK-SAME: inner_tiles = [128, 16]
 // CHECK-SAME: : tensor<255x513xf32> -> tensor<5x16x128x16xf32>
 // CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<5x16x128x16xf32> into tensor<5x16x8x16x4x4xf32>
+// CHECK-SAME : tensor<5x16x128x16xf32> into tensor<5x16x4x2x16x4x4xf32>
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x16x8x16x4x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<5x16x8x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 4]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x16x4x2x16x4x4xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<5x16x4x2x4x16x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 3, 6, 4, 5]
 // CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -96,11 +96,11 @@ func.func @set_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 // CHECK-SAME: inner_tiles = [128, 128]
 // CHECK-SAME: : tensor<255x513xf32> -> tensor<2x5x128x128xf32>
 // CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<2x5x128x128xf32> into tensor<2x5x8x4x4x8x16xf32>
+// CHECK-SAME : tensor<2x5x128x128xf32> into tensor<2x5x8x4x4x4x2x16xf32>
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x8x16xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x8x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 6, 4]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x2x4x16x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 6, 3, 7, 4]
 // CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -124,11 +124,11 @@ func.func @unset_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 
 // CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x8x4x16x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x8x16xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 4, 6, 3, 5]
+// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
 // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME: : tensor<2x5x8x4x4x8x16xf32> into tensor<2x5x128x128xf32>
+// CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xf32> into tensor<2x5x128x128xf32>
 // CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
 // CHECK-SAME: outer_dims_perm = [0, 1]
 // CHECK-SAME: inner_dims_pos = [0, 1]
@@ -187,12 +187,12 @@ func.func @matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 // CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
 // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
 // CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x4xf32>
-// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x4xf32>
-// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x8x4x16x4xf32>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x4xf32>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xf32>
 // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 8, unroll_k = 4>
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 2, unroll_n_to_threads = 4, unroll_k = 4>
 // CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
 
 
@@ -257,11 +257,11 @@ func.func @set_encoding_RHS_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 // CHECK-SAME: inner_tiles = [128, 64]
 // CHECK-SAME: : tensor<255x513xi8> -> tensor<5x4x128x64xi8>
 // CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<5x4x128x64xi8> into tensor<5x4x8x16x2x4x8xi8>
+// CHECK-SAME : tensor<5x4x128x64xi8> into tensor<5x4x4x2x16x2x4x8xi8>
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x4x8x16x2x4x8xi8>)
-// CHECK-SAME: outs({{.*}} : tensor<5x4x8x4x16x2x8xi8>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 4, 6]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x4x4x2x16x2x4x8xi8>)
+// CHECK-SAME: outs({{.*}} : tensor<5x4x4x2x4x16x2x8xi8>)
+// CHECK-SAME: permutation = [0, 1, 2, 3, 6, 4, 5, 7]
 // CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -290,11 +290,11 @@ func.func @set_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 // CHECK-SAME: inner_tiles = [128, 128]
 // CHECK-SAME: : tensor<255x513xi32> -> tensor<2x5x128x128xi32>
 // CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<2x5x128x128xi32> into tensor<2x5x8x4x4x8x16xi32>
+// CHECK-SAME : tensor<2x5x128x128xi32> into tensor<2x5x8x4x4x4x2x16xi32>
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x8x16xi32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x8x4x16x4xi32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 6, 4]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xi32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x2x4x16x4xi32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 6, 3, 7, 4]
 // CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -318,11 +318,11 @@ func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 
 // CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x8x4x16x4xi32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x8x16xi32>)
-// CHECK-SAME: permutation = [0, 1, 2, 4, 6, 3, 5]
+// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xi32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xi32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
 // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME: : tensor<2x5x8x4x4x8x16xi32> into tensor<2x5x128x128xi32>
+// CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xi32> into tensor<2x5x128x128xi32>
 // CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
 // CHECK-SAME: outer_dims_perm = [0, 1]
 // CHECK-SAME: inner_dims_pos = [0, 1]
@@ -382,10 +382,10 @@ func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 // CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
 // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
 // CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x2x8xi8>
-// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x2x8xi8>
-// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x8x4x16x4xi32>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x2x8xi8>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xi32>
 // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 8, unroll_k = 2>
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, unroll_n_to_threads = 4, unroll_k = 2>
 // CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
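A quick consistency check on the updated i8 layouts above: the element count of each data tile is unchanged, only its factorization differs.

  LHS: 8 * 4 * 16 * 2 * 8     = 8192  = 128 (M) * 64 (K)
  RHS: 4 * 2 * 4 * 16 * 2 * 8 = 8192  = 128 (N) * 64 (K)
  ACC: 8 * 4 * 2 * 4 * 16 * 4 = 16384 = 128 (M) * 128 (N)

The RHS and ACC tiles now carry a 4x2 split of the former 8-way N unrolling, while the LHS keeps its 8-way on-thread M unrolling and is untouched by this commit.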
@@ -894,7 +894,8 @@ std::tuple<Type, Type, Type> DataTiledMMAAttr::getABCElementTypes() const {
 std::tuple<int64_t, int64_t, int64_t> DataTiledMMAAttr::getMNKShape() const {
   MLIRContext *ctx = getContext();
   auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue());
-  return {opaqueLayout.mSize * getUnrollM(), opaqueLayout.nSize * getUnrollN(),
+  return {opaqueLayout.mSize * getUnrollM() * getUnrollMToThreads(),
+          opaqueLayout.nSize * getUnrollN() * getUnrollNToThreads(),
           opaqueLayout.kSize * getUnrollK()};
 }
 
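As a check on the updated formula, for MFMA_I32_16x16x32_I8 (intrinsic M=16, N=16, K=32) with the factors chosen in this commit:

  M = 16 * 8 * 1 = 128,  N = 16 * 2 * 4 = 128,  K = 32 * 2 = 64

matching the tensor<2x5x128x128xi32> accumulator tiles in the materialization tests above.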
@@ -252,9 +252,11 @@ def IREEGPU_DataTiledMMAAttr :
 
   let parameters = (ins
     "::mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr":$intrinsic,
-    "int64_t":$unroll_m,
-    "int64_t":$unroll_n,
-    "int64_t":$unroll_k
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the M dimension, on the same thread.">:$unroll_m,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the M dimension, distributed across this many more threads.">:$unroll_m_to_threads,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the N dimension, on the same thread.">:$unroll_n,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the N dimension, distributed across this many more threads.">:$unroll_n_to_threads,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the K dimension, on the same thread, with interleaved layout.">:$unroll_k
   );
 }
 
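Because every unroll factor now defaults to 1, factors equal to 1 may be omitted when writing the attribute, and the printer elides them; the updated roundtrip tests below exercise exactly this. For example, these two spellings denote the same layout:

#iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>
#iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32>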
@@ -29,21 +29,21 @@ module {
 
 module {
   func.func @test_data_tiled_mfma_f32_16x16x4_f32() attributes {
-      mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>} {
+      mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 4, unroll_m_to_threads = 2, unroll_k = 1>} {
     return
   }
 }
 // CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x4_f32
-// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 4, unroll_m_to_threads = 2>
 
 module {
   func.func @test_data_tiled_mfma_f32_16x16x16_f16() attributes {
-      mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 1, unroll_n = 1, unroll_k = 1>} {
+      mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 1, unroll_n_to_threads = 2, unroll_k = 2>} {
     return
   }
 }
 // CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x16_f16
-// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_n_to_threads = 2, unroll_k = 2>
 
 module {
   func.func @test_data_tiled_mfma_i32_16x16x32_i8() attributes {
@@ -52,7 +52,7 @@ module {
   }
 }
 // CHECK-LABEL: func @test_data_tiled_mfma_i32_16x16x32_i8
-// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8>
 
 module {
   func.func @test_any_lowering_config() attributes {
@@ -213,7 +213,7 @@ func.func @data_tiled_1x1x1_tensor_multi_mma(%lhs: tensor<?x?x4x16x1x1xf32>, %rh
   %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
     indexing_maps = #contraction_accesses,
     iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
-    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32>
   } : tensor<?x?x4x16x1x1xf32>, tensor<?x?x4x16x1x1xf32> into tensor<?x?x4x16x4x1xf32>
   return %0 : tensor<?x?x4x16x4x1xf32>
 }
@@ -226,7 +226,7 @@ func.func @data_tiled_1x1x1_tensor_multi_mma(%lhs: tensor<?x?x4x16x1x1xf32>, %rh
 // CHECK: iree_gpu.multi_mma %arg0, %arg1, %arg2
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 1, unroll_n = 1, unroll_k = 1>
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32>
 // CHECK-SAME: : tensor<?x?x4x16x1x1xf32>, tensor<?x?x4x16x1x1xf32> into tensor<?x?x4x16x4x1xf32>
 
 // -----
