diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
index 1c3fac336ff3..9b016fb884bf 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
@@ -124,7 +124,7 @@ applyTileAndFuseToEachRoot(RewriterBase &rewriter,
       }
     }
   }
-  tilingOptions.setMapping(mapping);
+  tilingOptions.setMapping(llvm::to_vector(llvm::reverse(mapping)));
 }

 scf::SCFTileAndFuseOptions tileAndFuseOptions;
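Reviewer note (not part of the patch): with this change the mapping attribute produced at the thread/subgroup tiling level is reversed into descending relative order, so the fastest-varying tiled dimension lands on `#gpu.thread<x>`. A minimal sketch of the resulting IR, modeled on the `@add_tensor` test below; all names are illustrative:

```mlir
func.func @thread_tiled(%dest: tensor<64x256xf32>) -> tensor<64x256xf32> {
  %0 = scf.forall (%i, %j) = (0, 0) to (64, 256) step (2, 16)
      shared_outs(%out = %dest) -> (tensor<64x256xf32>) {
    %tile = tensor.extract_slice %out[%i, %j] [2, 16] [1, 1]
        : tensor<64x256xf32> to tensor<2x16xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %tile into %out[%i, %j] [2, 16] [1, 1]
          : tensor<2x16xf32> into tensor<64x256xf32>
    }
    // Previously this was emitted as [#gpu.thread<x>, #gpu.thread<y>].
  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
  return %0 : tensor<64x256xf32>
}
```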
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index 170161be77b1..1b9cc625b073 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -41,13 +41,13 @@ module {
 // THREAD: scf.forall ({{.*}}) = (0, 0) to (64, 256) step (2, 16)
 // THREAD: linalg.generic {{.*}} ins(%{{.*}}: tensor<2x16xf32>, tensor<2x16xf32>)
 // THREAD: scf.forall.in_parallel
-// THREAD: mapping = [#gpu.thread<x>, #gpu.thread<y>]
+// THREAD: mapping = [#gpu.thread<y>, #gpu.thread<x>]

 // SUBGROUP-LABEL: func.func @add_tensor
 // SUBGROUP: scf.forall ({{.*}}) = (0, 0) to (64, 256) step (2, 16)
 // SUBGROUP: linalg.generic {{.*}} ins(%{{.*}}: tensor<2x16xf32>, tensor<2x16xf32>)
 // SUBGROUP: scf.forall.in_parallel
-// SUBGROUP: mapping = [#gpu.warp<x>, #gpu.warp<y>]
+// SUBGROUP: mapping = [#gpu.warp<y>, #gpu.warp<x>]

 // -----
@@ -138,13 +138,13 @@ func.func @matmul_transpose_b() attributes {translation_info = #iree_codegen.tra
 // THREAD-LABEL: func.func @matmul_transpose_b
 // THREAD: scf.forall ({{.*}}) in (64, 4)
 // THREAD: linalg.copy
-// THREAD: mapping = [#gpu.thread<x>, #gpu.thread<y>]
+// THREAD: mapping = [#gpu.thread<y>, #gpu.thread<x>]
 // THREAD: scf.forall ({{.*}}) in (64, 4)
 // THREAD: linalg.copy
-// THREAD: mapping = [#gpu.thread<x>, #gpu.thread<y>]
+// THREAD: mapping = [#gpu.thread<y>, #gpu.thread<x>]
 // THREAD: scf.forall ({{.*}}) = (0, 0) to (64, 64) step (4, 4)
 // THREAD: linalg.matmul
-// THREAD: mapping = [#gpu.thread<x>, #gpu.thread<y>]
+// THREAD: mapping = [#gpu.thread<y>, #gpu.thread<x>]

 // -----
@@ -310,7 +310,7 @@ module {
 // THREAD: scf.forall ({{.*}}) = (0, 0) to (64, 256) step (8, 4)
 // THREAD: linalg.generic {{.*}} ins(%{{.*}}: tensor<8x4xf32>, tensor<8x4xf32>)
 // THREAD: scf.forall.in_parallel
-// THREAD: mapping = [#gpu.thread<x>, #gpu.thread<y>]
+// THREAD: mapping = [#gpu.thread<y>, #gpu.thread<x>]

 // -----
@@ -344,7 +344,7 @@ module {
 // THREAD: scf.forall ({{.*}}) = (0, 0, 0) to (2, 128, 8) step (1, 1, 4)
 // THREAD: iree_linalg_ext.im2col {{.*}} ins(%{{.*}}: tensor<1x34x34x128xf16>) outs({{.*}}: tensor<1x1x4xf16>)
 // THREAD: scf.forall.in_parallel
-// THREAD: mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>]
+// THREAD: mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>]

 // -----
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
index e1d9c4c0c50e..ceba3f086b99 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
@@ -198,6 +198,12 @@ def FuseForallOp : Op<Transform_Dialect, "iree.fuse_forall", [
+    The mapping attributes of both the producer and consumer `scf.forall` ops
+    must be in a relative descending order, for example:
+      [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>]
+    or
+      [#gpu.thread<y>, #gpu.thread<x>]
+
     NOTE: This pattern implicitly REQUIRES that the resulting scf.forall
     is capable of synchronizing all threads at the point of fusion (i.e.
     inserting a barrier). This invalidates certain kinds of lowerings of
     scf.forall ops.
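For illustration only (not from the patch), here is a producer written the way the new requirement expects; flipping the mapping to the ascending `[#gpu.thread<x>, #gpu.thread<y>]` would make `transform.iree.fuse_forall` fail to apply. Names are hypothetical:

```mlir
func.func @descending_mappings(%src: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %empty = tensor.empty() : tensor<64x64xf32>
  // Descending relative indices [y, x]: eligible for iree.fuse_forall.
  %p = scf.forall (%i, %j) in (64, 64) shared_outs(%o = %empty) -> (tensor<64x64xf32>) {
    %v = tensor.extract_slice %src[%i, %j] [1, 1] [1, 1]
        : tensor<64x64xf32> to tensor<1x1xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %v into %o[%i, %j] [1, 1] [1, 1]
          : tensor<1x1xf32> into tensor<64x64xf32>
    }
  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
  return %p : tensor<64x64xf32>
}
```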
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
index 4bf4fe1d4573..b70768df5e33 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir
@@ -37,9 +37,9 @@ module attributes { transform.with_named_sequence } {
   }
 }

-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 8 + d2)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0) -> (d0 * 2)>

 // CHECK-LABEL: func @fuse_forall
 // CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
@@ -49,14 +49,18 @@ module attributes { transform.with_named_sequence } {
 // CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
 // CHECK-DAG: %[[OUTID0:.+]] = affine.apply #[[$MAP]](%[[IDX]])
 // CHECK-DAG: %[[OUTID1:.+]] = affine.apply #[[$MAP]](%[[IDY]])
-// CHECK: %[[LINEARID:.+]] = affine.apply #[[$MAP1]](%[[IDX]], %[[IDY]])
-// CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[LINEARID]] into (%c64, %c1) : index, index
-// CHECK: %[[INID0:.+]] = affine.apply #[[$MAP2]](%[[IDS]]#0)
-// CHECK: %[[INSLICE0:.+]] = tensor.extract_slice %[[ARG0]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
-// CHECK: %[[INSLICE1:.+]] = tensor.extract_slice %[[EMPTY]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
-// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[INSLICE0]] : tensor<2x128xf32>) outs(%[[INSLICE1]] : tensor<2x128xf32>) -> tensor<2x128xf32>
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ALLOC]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+
+// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %c0 to %c64{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]]) -> (tensor<128x128xf32>)
+// CHECK: %[[LINEARID:.+]] = affine.apply #[[$MAP2]](%[[I]], %[[IDX]], %[[IDY]])
+// CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[LINEARID]] into (%c1, %c64) : index, index
+// CHECK: %[[INID0:.+]] = affine.apply #[[$MAP3]](%[[IDS]]#1)
+// CHECK: %[[INSLICE0:.+]] = tensor.extract_slice %[[ARG0]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+// CHECK: %[[INSLICE1:.+]] = tensor.extract_slice %[[ITER]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[INSLICE0]] : tensor<2x128xf32>) outs(%[[INSLICE1]] : tensor<2x128xf32>) -> tensor<2x128xf32>
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
 // CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
 // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
 // CHECK: iree_gpu.yield %[[SLICE]]
@@ -108,9 +112,9 @@ module attributes { transform.with_named_sequence } {
   }
 }

-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>

 // CHECK-LABEL: func @fuse_forall
 // CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
@@ -118,8 +122,9 @@ module attributes { transform.with_named_sequence } {
 // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
 // CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
 // CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[ALLOC]]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
 // CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
 // CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
@@ -163,9 +168,9 @@ module attributes { transform.with_named_sequence } {
   }
 }

-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>

 // CHECK-LABEL: func @fuse_forall_with_reshape
 // CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
@@ -173,8 +178,9 @@ module attributes { transform.with_named_sequence } {
 // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
 // CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
 // CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[ALLOC]]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
 // CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[INTERMEDIATE]] {{\[}}[0, 1], [2]{{\]}} output_shape [2, 64, 128]
 // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[EXPAND]][0, %{{.*}}, %{{.*}}] [1, 16, 16] [1, 1, 1] : tensor<2x64x128xf32> to tensor<16x16xf32>
@@ -227,9 +233,9 @@ module attributes { transform.with_named_sequence } {
   }
 }

-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 8 + d2 * 4)>
-// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 4 + d2 * 32 + d3 * 16)>
-// CHECK: #[[$MAP4:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 8 + d2 * 4)>
+// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0 + d1 + d2 * 4 + d3 * 32 + d4 * 16)>
+// CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0) -> (d0 * 2)>

 // CHECK-LABEL: func @fuse_thread_forall_with_warp_and_lane
 // CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
@@ -238,12 +244,124 @@ module attributes { transform.with_named_sequence } {
 // CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
 // CHECK: scf.forall (%[[W_IDX:.+]], %[[W_IDY:.+]]) in (2, 2) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
 // CHECK: scf.forall (%[[L_IDX:.+]], %[[L_IDY:.+]]) in (4, 4) {{.*}} -> (tensor<64x64xf32>)
-// CHECK-DAG: %[[FLAT_ID:.+]] = affine.apply #[[$MAP3]](%[[L_IDY]], %[[L_IDX]], %[[W_IDX]], %[[W_IDY]])
-// CHECK-DAG: %[[IDS:.+]]:2 = affine.delinearize_index %[[FLAT_ID]] into (%c64, %c1) : index, index
-// CHECK-DAG: %[[IDX:.+]] = affine.apply #[[$MAP4]](%[[IDS]]#0)
-// CHECK: %[[COPY:.+]] = linalg.copy
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ALLOC]][%[[IDX]], %[[IDS]]#1] [2, 128]
-// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+
+// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %c0 to %c64{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]]) -> (tensor<128x128xf32>)
+// CHECK: %[[FLAT_ID:.+]] = affine.apply #[[$MAP4]](%[[I]], %[[L_IDY]], %[[L_IDX]], %[[W_IDX]], %[[W_IDY]])
+// CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[FLAT_ID]] into (%c1, %c64) : index, index
+// CHECK: %[[IDX:.+]] = affine.apply #[[$MAP5]](%[[IDS]]#1)
+// CHECK: %[[COPY:.+]] = linalg.copy
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[IDX]], %[[IDS]]#0] [2, 128]
+// CHECK: scf.yield %[[INSERT]]
+
+// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
 // CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
 // CHECK: } {mapping = [#iree_gpu.lane_id<1>, #iree_gpu.lane_id<0>]}
 // CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 4)>
+#map1 = affine_map<(d0) -> (d0 * 16)>
+module {
+  func.func @fuse_forall_different_thread_count(%arg0: tensor<128x128xf32>) -> tensor<128x128xf32> {
+    %0 = tensor.empty() : tensor<128x128xf32>
+    %2 = scf.forall (%arg5) in (32) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+      %4 = affine.apply #map(%arg5)
+      %extracted_slice = tensor.extract_slice %arg0[%4, 0] [4, 128] [1, 1] : tensor<128x128xf32> to tensor<4x128xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg7[%4, 0] [4, 128] [1, 1] : tensor<128x128xf32> to tensor<4x128xf32>
+      %5 = linalg.copy ins(%extracted_slice : tensor<4x128xf32>) outs(%extracted_slice_0 : tensor<4x128xf32>) -> tensor<4x128xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %5 into %arg7[%4, 0] [4, 128] [1, 1] : tensor<4x128xf32> into tensor<128x128xf32>
+      }
+    } {mapping = [#gpu.thread<x>]}
+    %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+      %6 = affine.apply #map1(%arg5)
+      %7 = affine.apply #map1(%arg6)
+      %extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+      %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+      }
+    } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+    return %3 : tensor<128x128xf32>
+  }
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
+    %producer, %consumer = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.iree.fuse_forall %producer into %consumer : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+
+// CHECK-LABEL: func @fuse_forall_different_thread_count
+// CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
+
+// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
+// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) {{.*}} -> (tensor<128x128xf32>) {
+// CHECK: %[[LINEARID:.+]] = affine.apply #[[$MAP1]](%[[IDX]], %[[IDY]])
+// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %c32{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
+// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%c32) : index
+// CHECK: scf.yield
+// CHECK: iree_gpu.barrier_region %[[LOOP]]
+// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 4)>
+#map1 = affine_map<(d0) -> (d0 * 16)>
+module {
+  func.func @fuse_forall_dynamic_thread_count(%arg0: tensor<128x128xf32>, %x: index, %y: index, %z: index) -> tensor<128x128xf32> {
+    %0 = tensor.empty() : tensor<128x128xf32>
+    %2 = scf.forall (%arg5, %arg6, %arg7) in (%x, %y, %z) shared_outs(%arg8 = %0) -> (tensor<128x128xf32>) {
+      %slice = tensor.extract_slice %arg0[%arg5, %arg6] [4, 128] [1, 1] : tensor<128x128xf32> to tensor<4x128xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %slice into %arg8[%arg7, 0] [4, 128] [1, 1] : tensor<4x128xf32> into tensor<128x128xf32>
+      }
+    } {mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>]}
+    %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+      %6 = affine.apply #map1(%arg5)
+      %7 = affine.apply #map1(%arg6)
+      %extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+      %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+      }
+    } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+    return %3 : tensor<128x128xf32>
+  }
+}
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
+    %producer, %consumer = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.iree.fuse_forall %producer into %consumer : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<()[s0, s1, s2] -> (s2 * (s0 * s1))>
+
+// CHECK-LABEL: func @fuse_forall_dynamic_thread_count
+// CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
+// CHECK-SAME:   %[[X:[A-Za-z0-9]+]]: index
+// CHECK-SAME:   %[[Y:[A-Za-z0-9]+]]: index
+// CHECK-SAME:   %[[Z:[A-Za-z0-9]+]]: index
+
+// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
+// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) {{.*}} -> (tensor<128x128xf32>) {
+// CHECK-DAG: %[[LINEARID:.+]] = affine.apply #[[$MAP1]](%[[IDX]], %[[IDY]])
+// CHECK-DAG: %[[PRODCOUNT:.+]] = affine.apply #[[$MAP3]]()[%[[X]], %[[Y]], %[[Z]]]
+// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %[[PRODCOUNT]] step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
+// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%[[Z]], %[[Y]], %[[X]]) : index
+// CHECK: scf.yield
+// CHECK: iree_gpu.barrier_region %[[LOOP]]
+// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
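To make the new `fuse_forall_different_thread_count` test concrete: the producer has 32 workers while the consumer has 8 * 8 = 64, so each consumer thread with linear ID `id` runs `scf.for %i = id to 32 step 64`; threads 0-31 execute exactly one producer iteration and threads 32-63 execute none. A hedged sketch of the generated loop, with names illustrative and shapes taken from the test:

```mlir
#lin = affine_map<(d0, d1) -> (d0 * 8 + d1)>
#off = affine_map<(d0) -> (d0 * 4)>
func.func @sketch(%src: tensor<128x128xf32>, %alloc: tensor<128x128xf32>,
                  %idx: index, %idy: index) -> tensor<128x128xf32> {
  %c32 = arith.constant 32 : index
  %c64 = arith.constant 64 : index
  // Consumer thread (%idx, %idy) of the 8x8 loop has linear ID idx * 8 + idy.
  %id = affine.apply #lin(%idx, %idy)
  // 32 producer workers vs. 64 consumer threads: IDs 0-31 run one
  // iteration each, IDs 32-63 skip the loop entirely.
  %loop = scf.for %i = %id to %c32 step %c64
      iter_args(%iter = %alloc) -> (tensor<128x128xf32>) {
    %pid = affine.delinearize_index %i into (%c32) : index
    %row = affine.apply #off(%pid)
    %slice = tensor.extract_slice %src[%row, 0] [4, 128] [1, 1]
        : tensor<128x128xf32> to tensor<4x128xf32>
    %ins = tensor.insert_slice %slice into %iter[%row, 0] [4, 128] [1, 1]
        : tensor<4x128xf32> into tensor<128x128xf32>
    scf.yield %ins : tensor<128x128xf32>
  }
  return %loop : tensor<128x128xf32>
}
```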
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 708647119ef0..9e6ee98a80b3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -12,6 +12,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -22,6 +23,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -40,29 +42,14 @@ namespace mlir::iree_compiler::IREE::GPU {
 // Forall Fusion
 //===---------------------------------------------------------------------===//

-static FailureOr<int64_t> getTripCount(scf::ForallOp loop) {
-  ArrayRef<int64_t> lbs = loop.getStaticLowerBound();
-  ArrayRef<int64_t> ubs = loop.getStaticUpperBound();
-  ArrayRef<int64_t> steps = loop.getStaticStep();
-
-  if (ShapedType::isDynamicShape(lbs) || ShapedType::isDynamicShape(ubs) ||
-      ShapedType::isDynamicShape(steps)) {
-    return failure();
-  }
-
-  int64_t tripCount = 1;
-  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
-    tripCount *= llvm::divideCeil((ub - lb), step);
-  }
-  return tripCount;
-}
-
 static FailureOr<SmallVector<scf::ForallOp>>
 getEquivalentMappingConsumerLoopNest(scf::ForallOp producer,
                                      scf::ForallOp consumer) {
-  auto checkMappingTypes = [&](ArrayRef<Attribute> array) {
-    return llvm::all_of(array, llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
-           llvm::all_of(array, llvm::IsaPred<gpu::GPUWarpMappingAttr>);
+  auto compareMappingTypes = [&](ArrayRef<Attribute> l, ArrayRef<Attribute> r) {
+    return (llvm::all_of(l, llvm::IsaPred<gpu::GPUThreadMappingAttr>) &&
+            llvm::all_of(r, llvm::IsaPred<gpu::GPUThreadMappingAttr>)) ||
+           (llvm::all_of(l, llvm::IsaPred<gpu::GPUWarpMappingAttr>) &&
+            llvm::all_of(r, llvm::IsaPred<gpu::GPUWarpMappingAttr>));
   };

   ArrayRef<Attribute> producerMapping = producer.getMappingAttr().getValue();
@@ -72,12 +59,34 @@ getEquivalentMappingConsumerLoopNest(scf::ForallOp producer,
     return failure();
   }

-  if (producerMapping.front() == consumerMapping.front() &&
-      checkMappingTypes(producerMapping) &&
-      checkMappingTypes(consumerMapping)) {
+  auto isDescendingRelativeIndices = [&](ArrayRef<Attribute> array) {
+    int64_t prev =
+        llvm::cast<DeviceMappingAttrInterface>(array[0]).getRelativeIndex();
+    for (Attribute attr : array.drop_front()) {
+      int64_t relativeIndex =
+          llvm::cast<DeviceMappingAttrInterface>(attr).getRelativeIndex();
+      if (relativeIndex != prev - 1) {
+        return false;
+      }
+      prev = relativeIndex;
+    }
+    return true;
+  };
+
+  // Require descending relative indices so that the linearization and
+  // delinearization done in subsequent steps are valid.
+  if (!isDescendingRelativeIndices(producerMapping) ||
+      !isDescendingRelativeIndices(consumerMapping)) {
+    return failure();
+  }
+
+  // If both loops share the same kind of mapping, return the sole consumer.
+  if (compareMappingTypes(producerMapping, consumerMapping)) {
     return SmallVector<scf::ForallOp>({consumer});
   }

+  // The only other supported case is fusing a thread mapped loop into a nest
+  // of a warp and lane forall.
   if (!llvm::all_of(producerMapping,
                     llvm::IsaPred<gpu::GPUThreadMappingAttr>) ||
       !llvm::all_of(consumerMapping, llvm::IsaPred<IREE::GPU::LaneIdAttr>)) {
@@ -91,64 +100,39 @@ getEquivalentMappingConsumerLoopNest(scf::ForallOp producer,
   return SmallVector<scf::ForallOp>({outerWarpLoop, consumer});
 }

-static LogicalResult compareWorkerCounts(scf::ForallOp producer,
-                                         ArrayRef<scf::ForallOp> consumers) {
-  FailureOr<int64_t> producerTripCount = getTripCount(producer);
-  if (failed(producerTripCount)) {
-    return failure();
-  }
-  int64_t consumerTotal = 1;
-  for (auto consumer : consumers) {
-    FailureOr<int64_t> consumerTripCount = getTripCount(consumer);
-    if (failed(consumerTripCount)) {
-      return failure();
-    }
-    consumerTotal *= *consumerTripCount;
-  }
-  if (*producerTripCount != consumerTotal) {
+static FailureOr<Value> createSharedAllocDestination(RewriterBase &rewriter,
+                                                     scf::ForallOp forallOp) {
+  if (forallOp->getNumResults() != 1) {
     return failure();
   }
-  return success();
-}

-static LogicalResult
-replaceConsumerChain(RewriterBase &rewriter, Location loc, Value source,
-                     tensor::ParallelInsertSliceOp parallelInsert,
-                     SmallVector<Operation *> consumerChain) {
-  auto extractSlice = cast<tensor::ExtractSliceOp>(consumerChain.back());
-  OpBuilder::InsertionGuard g(rewriter);
-  Value shuffleDest = parallelInsert.getDest();
-  auto empty = shuffleDest.getDefiningOp<tensor::EmptyOp>();
+  auto empty = forallOp.getDpsInits()[0].getDefiningOp<tensor::EmptyOp>();
   // Fail if the destination is not a `tensor.empty` op and cannot be trivially
   // converted to a `bufferization.alloc_tensor`.
   if (!empty) {
     return failure();
   }
-  // Replace the destination with a `bufferization.alloc_tensor` op with
-  // memory space `#gpu.address_space<workgroup>`.
-  {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(empty);
-    Attribute sharedMemoryAddrSpace = gpu::AddressSpaceAttr::get(
-        rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
-    auto allocTensor = rewriter.create<bufferization::AllocTensorOp>(
-        empty->getLoc(), empty->getResultTypes()[0], empty.getDynamicSizes());
-    allocTensor.setMemorySpaceAttr(sharedMemoryAddrSpace);
-    shuffleDest = allocTensor.getResult();
-  }
+  // Create a `bufferization.alloc_tensor` op with memory space
+  // `#gpu.address_space<workgroup>`.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(empty);
+  Attribute sharedMemoryAddrSpace = gpu::AddressSpaceAttr::get(
+      rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
+  auto allocTensor = rewriter.create<bufferization::AllocTensorOp>(
+      empty->getLoc(), empty->getResultTypes()[0], empty.getDynamicSizes());
+  allocTensor.setMemorySpaceAttr(sharedMemoryAddrSpace);
+  return allocTensor.getResult();
+}

-  // Create an insert_slice for the result of the first forall op into the
-  // shared memory alloc_tensor.
-  SmallVector<OpFoldResult> sourceOffsets = parallelInsert.getMixedOffsets();
-  SmallVector<OpFoldResult> sourceSizes = parallelInsert.getMixedSizes();
-  SmallVector<OpFoldResult> sourceStrides = parallelInsert.getMixedStrides();
-  Value insertedSlice = rewriter.create<tensor::InsertSliceOp>(
-      loc, parallelInsert.getSource(), shuffleDest, sourceOffsets, sourceSizes,
-      sourceStrides);
+static void replaceConsumerChain(RewriterBase &rewriter, Location loc,
+                                 Value source, Value replacement,
+                                 SmallVector<Operation *> consumerChain) {
+  auto extractSlice = cast<tensor::ExtractSliceOp>(consumerChain.back());
+  OpBuilder::InsertionGuard g(rewriter);
   auto barrierRegionOp = rewriter.create<IREE::GPU::BarrierRegionOp>(
-      loc, extractSlice.getType(), insertedSlice);
+      loc, extractSlice.getType(), replacement);
   rewriter.setInsertionPointToStart(barrierRegionOp.getBody());
   auto terminator =
       rewriter.create<IREE::GPU::YieldOp>(loc, extractSlice.getResult());
@@ -159,7 +143,6 @@ replaceConsumerChain(RewriterBase &rewriter, Location loc, Value source,
       ->replaceUsesOfWith(source, barrierRegionOp.getBody()->getArgument(0));
   rewriter.replaceAllUsesExcept(extractSlice.getResult(), barrierRegionOp,
                                 terminator);
-  return success();
 }

 LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
@@ -191,6 +174,7 @@ LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
     });
   };

+  // Verify that both loops are normalized.
   if (!isAll(producer.getMixedStep(), 1) ||
       !isAll(producer.getMixedLowerBound(), 0)) {
     return failure();
@@ -205,54 +189,132 @@ LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,

   rewriter.setInsertionPoint(slice);

-  // Step 1. Compute the producer IDs in terms of the consumer IDs.
+  // Step 1. Get the destination of the producer loop as a shared memory
+  // allocation.
+  FailureOr<Value> sharedDest =
+      createSharedAllocDestination(rewriter, producer);
+  if (failed(sharedDest)) {
+    return failure();
+  }
+
+  // Step 2. Compute the producer IDs in terms of the consumer IDs.
+  // The producer IDs are computed as follows:
+  //
+  //   producer = [p0, ..., pn] ∈ [0, ..., 0] to [P0, ..., Pn]
+  //   consumer = [c0, ..., cn] ∈ [0, ..., 0] to [C0, ..., Cn]
+  //
+  //                            Not a real op
+  //                                  |
+  //   %ub = P0 * ... * Pn           |
+  //   %step = C0 * ... * Cn         v
+  //   %flatc = affine.linearize_index %c0, ..., %cn
+  //   scf.for %id = %flatc to %ub step %step {
+  //     %p:n = affine.delinearize_index %id into [%P0, ..., %Pn]
+  //     ...
+  //   }
+  //
+  // Note: We use 0 as the loop lower bound instead of the linearized consumer
+  // loop ID if possible to make later loop promotion patterns easier.
   MLIRContext *context = rewriter.getContext();
   Location loc = producer.getLoc();

+  // Compute the linearized consumer loop ID and total consumer loop worker
+  // count (C0 * ... * Cn).
   AffineExpr d0, d1, d2;
   bindDims(context, d0, d1, d2);
   AffineExpr mulAdd = d0 * d1 + d2;
   OpFoldResult linearId = rewriter.getIndexAttr(0);
+  OpFoldResult consumerWorkerCount = rewriter.getIndexAttr(1);
   for (auto loop : *consumerLoopNest) {
     for (auto [inductionVar, workerCount] :
          llvm::zip_equal(getAsOpFoldResult(loop.getInductionVars()),
                          loop.getMixedUpperBound())) {
       linearId = affine::makeComposedFoldedAffineApply(
           rewriter, loc, mulAdd, {linearId, workerCount, inductionVar});
+      consumerWorkerCount = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, d0 * d1, {consumerWorkerCount, workerCount});
     }
   }
-  Value linearThreadIdVal =
+
+  // Compute the total producer loop worker count (P0 * ... * Pn).
+  Value linearConsumerIdVal =
       getValueOrCreateConstantIndexOp(rewriter, loc, linearId);
-  SmallVector<Value> ranges;
-  for (auto workerCount : producer.getStaticUpperBound()) {
-    ranges.push_back(rewriter.create<arith::ConstantIndexOp>(loc, workerCount));
+  SmallVector<Value> producerRanges;
+  OpFoldResult producerWorkerCount = rewriter.getIndexAttr(1);
+  for (auto workerCount : producer.getMixedUpperBound()) {
+    producerRanges.push_back(
+        getValueOrCreateConstantIndexOp(rewriter, loc, workerCount));
+    producerWorkerCount = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0 * d1, {producerWorkerCount, workerCount});
   }
-  ValueRange newIds = rewriter
-                          .create<affine::AffineDelinearizeIndexOp>(
-                              loc, linearThreadIdVal, ranges)
-                          .getResults();
-
-  // Step 2. Inline the region of the producer.
-  SmallVector<Value> bbArgReplacements(newIds);
-  bbArgReplacements.append(producer.getOutputs().begin(),
-                           producer.getOutputs().end());
+  std::optional<int64_t> staticProducerCount =
+      getConstantIntValue(producerWorkerCount);
+  std::optional<int64_t> staticConsumerCount =
+      getConstantIntValue(consumerWorkerCount);
+  bool perfectlyDivides =
+      staticConsumerCount && staticProducerCount &&
+      staticProducerCount.value() % staticConsumerCount.value() == 0;
+
+  // Step 3. Create the `scf.for` loop for the producer.
+  // If the consumer worker count perfectly divides the producer worker count,
+  // then we can use a lower bound of 0 and keep the loop bounds static.
+  Value lb = perfectlyDivides ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
+                              : linearConsumerIdVal;
+  Value ub =
+      getValueOrCreateConstantIndexOp(rewriter, loc, producerWorkerCount);
+  Value step =
+      getValueOrCreateConstantIndexOp(rewriter, loc, consumerWorkerCount);
+  auto newProducer =
+      rewriter.create<scf::ForOp>(loc, lb, ub, step, *sharedDest);
+  Block *loopBody = newProducer.getBody();
+
+  // Get the replacement IDs for the producer loop.
+  rewriter.setInsertionPointToStart(loopBody);
+  Value newFlatProducerId =
+      perfectlyDivides
+          ? affine::makeComposedAffineApply(
+                rewriter, loc, d0 + d1,
+                {newProducer.getInductionVar(), linearConsumerIdVal})
+          : newProducer.getInductionVar();
+
+  // We require a descending relative mapping, so delinearize in reverse order.
+  auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+      loc, newFlatProducerId, llvm::to_vector(llvm::reverse(producerRanges)));
+
+  SmallVector<Value> newBlockArgs =
+      llvm::map_to_vector(llvm::reverse(delinearize.getResults()),
+                          [](OpResult r) -> Value { return r; });
+  newBlockArgs.append(newProducer.getRegionIterArgs().begin(),
+                      newProducer.getRegionIterArgs().end());
+
+  // Step 4. Inline the region of the producer and replace the terminator.
   scf::InParallelOp terminator = producer.getTerminator();
-  rewriter.inlineBlockBefore(producer.getBody(), slice, bbArgReplacements);
+  rewriter.mergeBlocks(producer.getBody(), loopBody, newBlockArgs);

   rewriter.setInsertionPointAfter(terminator);
   auto parallelInsert =
       cast<tensor::ParallelInsertSliceOp>(*terminator.getYieldingOps().begin());

-  if (failed(replaceConsumerChain(rewriter, loc, producer.getResult(0),
-                                  parallelInsert, consumerChain))) {
-    return failure();
-  }
-
+  // Create an insert_slice to yield from the loop body.
+  SmallVector<OpFoldResult> sourceOffsets = parallelInsert.getMixedOffsets();
+  SmallVector<OpFoldResult> sourceSizes = parallelInsert.getMixedSizes();
+  SmallVector<OpFoldResult> sourceStrides = parallelInsert.getMixedStrides();
+  Value insertedSlice = rewriter.create<tensor::InsertSliceOp>(
+      loc, parallelInsert.getSource(), parallelInsert.getDest(), sourceOffsets,
+      sourceSizes, sourceStrides);
+  rewriter.create<scf::YieldOp>(loc, insertedSlice);
   rewriter.eraseOp(parallelInsert);
   rewriter.eraseOp(terminator);
+
+  // Step 5. Replace the extract slice with a `barrier_region` op to indicate
+  // synchronization of the shared tensor.
+  rewriter.setInsertionPointAfter(newProducer);
+  replaceConsumerChain(rewriter, loc, producer.getResult(0),
+                       newProducer.getResult(0), consumerChain);
+
   rewriter.eraseOp(producer);
   return success();
 }
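As a worked example of the `perfectlyDivides` fast path above (illustrative, not from the patch): with 64 producer workers and 64 consumer workers the bounds stay static, the lower bound is 0, and the loop runs exactly once per thread. The flat producer ID is recovered as `%i + linear_consumer_id` and delinearized against the reversed producer bounds, matching the first `@fuse_forall` test:

```mlir
#add = affine_map<(d0, d1) -> (d0 + d1)>
#mul2 = affine_map<(d0) -> (d0 * 2)>
func.func @divides_sketch(%src: tensor<128x128xf32>, %alloc: tensor<128x128xf32>,
                          %linear_id: index) -> tensor<128x128xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c64 = arith.constant 64 : index
  // perfectlyDivides: 64 % 64 == 0, so lb = 0 and the loop runs once.
  %loop = scf.for %i = %c0 to %c64 step %c64
      iter_args(%iter = %alloc) -> (tensor<128x128xf32>) {
    // Flat producer ID = loop IV + linearized consumer ID.
    %flat = affine.apply #add(%i, %linear_id)
    // Delinearize against the reversed producer bounds (descending mapping).
    %p:2 = affine.delinearize_index %flat into (%c1, %c64) : index, index
    %row = affine.apply #mul2(%p#1)
    %slice = tensor.extract_slice %src[%row, %p#0] [2, 128] [1, 1]
        : tensor<128x128xf32> to tensor<2x128xf32>
    %ins = tensor.insert_slice %slice into %iter[%row, %p#0] [2, 128] [1, 1]
        : tensor<2x128xf32> into tensor<128x128xf32>
    scf.yield %ins : tensor<128x128xf32>
  }
  return %loop : tensor<128x128xf32>
}
```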
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
index a24612486d7d..8bdcff776733 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
@@ -38,6 +38,12 @@ namespace mlir::iree_compiler::IREE::GPU {
 /// the single consumer loop at the given |slice| within the consumer of the
 /// producer. This is managed by inserting an `iree_gpu.barrier_region` at the
 /// boundary to synchronize the workers at the fusion point.
+///
+/// The mapping attributes of both the producer and consumer `scf.forall` ops
+/// must be in a relative descending order, for example:
+///   [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>]
+/// or
+///   [#gpu.thread<y>, #gpu.thread<x>]
 LogicalResult fuseForallIntoSlice(RewriterBase &rewriter,
                                   scf::ForallOp producer,
                                   scf::ForallOp consumer,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
index a50c0e32eec3..7bb8f8fc58ac 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
@@ -23,7 +23,7 @@ func.func @forall_fuse_then_hoist(%3: tensor<128x128xf16>, %4: tensor<128x128xf1
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
     }
-  } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
   %10 = scf.forall (%arg2, %arg3) in (2, 32) shared_outs(%arg4 = %7) -> (tensor<4x128xf16>) {
     %12 = affine.apply #map(%arg2)
     %13 = affine.apply #map1(%arg3)
@@ -35,7 +35,7 @@ func.func @forall_fuse_then_hoist(%3: tensor<128x128xf16>, %4: tensor<128x128xf1
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<4x128xf16>
     }
-  } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
   %11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
     %12 = affine.apply #map4(%arg2)
     %13 = affine.apply #map4(%arg3)
@@ -46,7 +46,7 @@ func.func @forall_fuse_then_hoist(%3: tensor<128x128xf16>, %4: tensor<128x128xf1
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
     }
-  } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
     scf.yield %11 : tensor<128x128xf32>
   }
   return %8 : tensor<128x128xf32>
@@ -85,7 +85,7 @@ func.func @forall_fuse_then_hoist_mixed_mappings(%3: tensor<128x128xf16>, %5: te
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %16 into %arg5[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
     }
-  } {mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>]}
+  } {mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>]}
   %11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
     %12 = affine.apply #map3(%arg2)
     %13 = affine.apply #map3(%arg3)
@@ -96,7 +96,7 @@ func.func @forall_fuse_then_hoist_mixed_mappings(%3: tensor<128x128xf16>, %5: te
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
     }
-  } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
     scf.yield %11 : tensor<128x128xf32>
   }
   return %8 : tensor<128x128xf32>
@@ -139,7 +139,7 @@ func.func @forall_fuse_then_hoist_with_fill(%3: tensor<128x128xf16>, %4: tensor<
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<128x4xf16>
     }
-  } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
   %10 = scf.forall (%arg2, %arg3) in (2, 32) shared_outs(%arg4 = %7) -> (tensor<4x128xf16>) {
     %12 = affine.apply #map(%arg2)
     %13 = affine.apply #map1(%arg3)
@@ -151,7 +151,7 @@ func.func @forall_fuse_then_hoist_with_fill(%3: tensor<128x128xf16>, %4: tensor<
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %16 into %arg4[%12, %13] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<4x128xf16>
     }
-  } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
   %11 = scf.forall (%arg2, %arg3) in (8, 8) shared_outs(%arg4 = %arg1) -> (tensor<128x128xf32>) {
     %12 = affine.apply #map4(%arg2)
     %13 = affine.apply #map4(%arg3)
@@ -162,7 +162,7 @@ func.func @forall_fuse_then_hoist_with_fill(%3: tensor<128x128xf16>, %4: tensor<
     scf.forall.in_parallel {
      tensor.parallel_insert_slice %14 into %arg4[%12, %13] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
     }
-  } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
     scf.yield %11 : tensor<128x128xf32>
   }
   return %8 : tensor<128x128xf32>
@@ -194,11 +194,11 @@ func.func @multi_hoist_and_fuse_trailing_stuff(%2: tensor<128x128xf16>) -> tenso
       scf.forall.in_parallel {
         tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<64x64xf16>
       }
-    } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}
+    } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %10 into %arg4[%arg2, %arg3] [64, 64] [1, 1] : tensor<64x64xf16> into tensor<128x128xf16>
     }
-  } {mapping = [#gpu.warp<x>, #gpu.warp<y>]}
+  } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
     scf.yield %9 : tensor<128x128xf16>
   }
   %transpose = linalg.transpose ins(%8: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) permutation = [1, 0]
@@ -234,11 +234,11 @@ func.func @multi_hoist_and_fuse_trailing_with_producer_fusion(%2: tensor<128x128
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 4] [1, 1] : tensor<2x4xf16> into tensor<64x64xf16>
      }
-    } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}
+    } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %10 into %arg4[%arg2, %arg3] [64, 64] [1, 1] : tensor<64x64xf16> into tensor<128x128xf16>
     }
-  } {mapping = [#gpu.warp<x>, #gpu.warp<y>]}
+  } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}
     scf.yield %9 : tensor<128x128xf16>
   }
   %transpose_input = linalg.transpose ins(%3: tensor<128x128xf16>) outs(%empty: tensor<128x128xf16>) permutation = [1, 0]