[Codegen][GPU] Rework scf.forall fusion to support different thread counts (iree-org#18280)

The current fusion pattern is restricted to cases where the thread count
of each loop being fused is statically the same. This changes the
pattern to instead generate an scf.for loop within the consumer loop and
map the producer loop to the iteration space of the consumer loop. This
will allow supporting dynamic and unaligned code generation.
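
For illustration only (a hand-written sketch, not part of the diff; SSA names and constants are invented to mirror the new tests below): fusing a 32-thread producer into an 8x8-thread consumer now yields an inner scf.for in which each consumer thread strides its linearized ID across the producer's iteration space, so mismatched thread counts simply change the loop bounds.

scf.forall (%i, %j) in (8, 8) shared_outs(%out = %empty) -> (tensor<128x128xf32>) {
  // Flat consumer thread ID in [0, 64).
  %id = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)>(%i, %j)
  // Replay the 32 producer iterations across 64 threads; threads whose
  // flat ID exceeds the producer trip count run zero iterations.
  %filled = scf.for %iv = %id to %c32 step %c64 iter_args(%iter = %alloc) -> (tensor<128x128xf32>) {
    %ids = affine.delinearize_index %iv into (%c32) : index
    // ... producer body writes its slice of %iter ...
    scf.yield %updated : tensor<128x128xf32>
  }
  // Synchronize all threads before the consumer reads the shared intermediate.
  %slice = iree_gpu.barrier_region %filled {
    // ... extract this thread's tile of the intermediate ...
  } : tensor<128x128xf32> -> tensor<16x16xf32>
  // ... consumer body ...
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}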
qedawkins authored Aug 20, 2024
1 parent 87084d5 commit ab0d4c6
Showing 7 changed files with 332 additions and 140 deletions.
@@ -124,7 +124,7 @@ applyTileAndFuseToEachRoot(RewriterBase &rewriter,
}
}
}
- tilingOptions.setMapping(mapping);
+ tilingOptions.setMapping(llvm::to_vector(llvm::reverse(mapping)));
}

scf::SCFTileAndFuseOptions tileAndFuseOptions;
@@ -41,13 +41,13 @@ module {
// THREAD: scf.forall ({{.*}}) = (0, 0) to (64, 256) step (2, 16)
// THREAD: linalg.generic {{.*}} ins(%{{.*}}: tensor<2x16xf32>, tensor<2x16xf32>)
// THREAD: scf.forall.in_parallel
- // THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]
+ // THREAD: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]

// SUBGROUP-LABEL: func.func @add_tensor
// SUBGROUP: scf.forall ({{.*}}) = (0, 0) to (64, 256) step (2, 16)
// SUBGROUP: linalg.generic {{.*}} ins(%{{.*}}: tensor<2x16xf32>, tensor<2x16xf32>)
// SUBGROUP: scf.forall.in_parallel
- // SUBGROUP: mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>]
+ // SUBGROUP: mapping = [#gpu.warp<linear_dim_1>, #gpu.warp<linear_dim_0>]

// -----

@@ -138,13 +138,13 @@ func.func @matmul_transpose_b() attributes {translation_info = #iree_codegen.tra
// THREAD-LABEL: func.func @matmul_transpose_b
// THREAD: scf.forall ({{.*}}) in (64, 4)
// THREAD: linalg.copy
- // THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]
+ // THREAD: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
// THREAD: scf.forall ({{.*}}) in (64, 4)
// THREAD: linalg.copy
- // THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]
+ // THREAD: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
// THREAD: scf.forall ({{.*}}) = (0, 0) to (64, 64) step (4, 4)
// THREAD: linalg.matmul
- // THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]
+ // THREAD: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]

// -----

@@ -310,7 +310,7 @@ module {
// THREAD: scf.forall ({{.*}}) = (0, 0) to (64, 256) step (8, 4)
// THREAD: linalg.generic {{.*}} ins(%{{.*}}: tensor<8x4xf32>, tensor<8x4xf32>)
// THREAD: scf.forall.in_parallel
- // THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>]
+ // THREAD: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]

// -----

@@ -344,7 +344,7 @@ module {
// THREAD: scf.forall ({{.*}}) = (0, 0, 0) to (2, 128, 8) step (1, 1, 4)
// THREAD: iree_linalg_ext.im2col {{.*}} ins(%{{.*}}: tensor<1x34x34x128xf16>) outs({{.*}}: tensor<1x1x4xf16>)
// THREAD: scf.forall.in_parallel
- // THREAD: mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_2>]
+ // THREAD: mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]

// -----

@@ -198,6 +198,12 @@ def FuseForallOp : Op<Transform_Dialect, "iree.fuse_forall",
`extract_slice` of the consumer. If specified, uses |address_space| for
the intermediate allocation.

+ The mapping attributes of both the producer and consumer `scf.forall` ops
+ must be in a relative descending order, for example:
+   [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>]
+ or
+   [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
+
NOTE: This pattern implicitly REQUIRES that the resulting scf.forall
is capable of synchronizing all threads at the point of fusion (i.e.
inserting a barrier). This invalidates certain kinds of lowerings of
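
For reference, the tests below drive this op through the transform dialect; a minimal invocation (mirroring the new test cases in this commit) looks like:

module attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    // Match both scf.forall loops, then fuse the producer into the consumer.
    %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
    %producer, %consumer = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.iree.fuse_forall %producer into %consumer : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}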
@@ -37,9 +37,9 @@ module attributes { transform.with_named_sequence } {
}
}

- // CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
- // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
- // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+ // CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+ // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 8 + d2)>
+ // CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0) -> (d0 * 2)>

// CHECK-LABEL: func @fuse_forall
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
@@ -49,14 +49,18 @@ module attributes { transform.with_named_sequence } {
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK-DAG: %[[OUTID0:.+]] = affine.apply #[[$MAP]](%[[IDX]])
// CHECK-DAG: %[[OUTID1:.+]] = affine.apply #[[$MAP]](%[[IDY]])
- // CHECK: %[[LINEARID:.+]] = affine.apply #[[$MAP1]](%[[IDX]], %[[IDY]])
- // CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[LINEARID]] into (%c64, %c1) : index, index
- // CHECK: %[[INID0:.+]] = affine.apply #[[$MAP2]](%[[IDS]]#0)
- // CHECK: %[[INSLICE0:.+]] = tensor.extract_slice %[[ARG0]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
- // CHECK: %[[INSLICE1:.+]] = tensor.extract_slice %[[EMPTY]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
- // CHECK: %[[COPY:.+]] = linalg.copy ins(%[[INSLICE0]] : tensor<2x128xf32>) outs(%[[INSLICE1]] : tensor<2x128xf32>) -> tensor<2x128xf32>
- // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ALLOC]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1]
- // CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+
+ // CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %c0 to %c64{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]]) -> (tensor<128x128xf32>)
+ // CHECK: %[[LINEARID:.+]] = affine.apply #[[$MAP2]](%[[I]], %[[IDX]], %[[IDY]])
+ // CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[LINEARID]] into (%c1, %c64) : index, index
+ // CHECK: %[[INID0:.+]] = affine.apply #[[$MAP3]](%[[IDS]]#1)
+ // CHECK: %[[INSLICE0:.+]] = tensor.extract_slice %[[ARG0]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+ // CHECK: %[[INSLICE1:.+]] = tensor.extract_slice %[[ITER]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+ // CHECK: %[[COPY:.+]] = linalg.copy ins(%[[INSLICE0]] : tensor<2x128xf32>) outs(%[[INSLICE1]] : tensor<2x128xf32>) -> tensor<2x128xf32>
+ // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1]
+ // CHECK: scf.yield %[[INSERT]]
+
+ // CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
// CHECK: iree_gpu.yield %[[SLICE]]
@@ -108,18 +112,19 @@ module attributes { transform.with_named_sequence } {
}
}

- // CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
- // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
- // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+ // CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+ // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+ // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>

// CHECK-LABEL: func @fuse_forall
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>

// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
// CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
- // CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[ALLOC]]
- // CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+ // CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
+ // CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
+ // CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}

@@ -163,18 +168,19 @@ module attributes { transform.with_named_sequence } {
}
}

- // CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
- // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
- // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>
+ // CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 16)>
+ // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+ // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 2)>

// CHECK-LABEL: func @fuse_forall_with_reshape
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>

// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<128x128xf32>
// CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
- // CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[ALLOC]]
- // CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+ // CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
+ // CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
+ // CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[INTERMEDIATE]] {{\[}}[0, 1], [2]{{\]}} output_shape [2, 64, 128]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[EXPAND]][0, %{{.*}}, %{{.*}}] [1, 16, 16] [1, 1, 1] : tensor<2x64x128xf32> to tensor<16x16xf32>
@@ -227,9 +233,9 @@ module attributes { transform.with_named_sequence } {
}
}

- // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 8 + d2 * 4)>
- // CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 4 + d2 * 32 + d3 * 16)>
- // CHECK: #[[$MAP4:.+]] = affine_map<(d0) -> (d0 * 2)>
+ // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 8 + d2 * 4)>
+ // CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0 + d1 + d2 * 4 + d3 * 32 + d4 * 16)>
+ // CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0) -> (d0 * 2)>

// CHECK-LABEL: func @fuse_thread_forall_with_warp_and_lane
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
@@ -238,12 +244,124 @@ module attributes { transform.with_named_sequence } {
// CHECK-DAG: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[W_IDX:.+]], %[[W_IDY:.+]]) in (2, 2) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK: scf.forall (%[[L_IDX:.+]], %[[L_IDY:.+]]) in (4, 4) {{.*}} -> (tensor<64x64xf32>)
- // CHECK-DAG: %[[FLAT_ID:.+]] = affine.apply #[[$MAP3]](%[[L_IDY]], %[[L_IDX]], %[[W_IDX]], %[[W_IDY]])
- // CHECK-DAG: %[[IDS:.+]]:2 = affine.delinearize_index %[[FLAT_ID]] into (%c64, %c1) : index, index
- // CHECK-DAG: %[[IDX:.+]] = affine.apply #[[$MAP4]](%[[IDS]]#0)
- // CHECK: %[[COPY:.+]] = linalg.copy
- // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ALLOC]][%[[IDX]], %[[IDS]]#1] [2, 128]
- // CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[INSERT]]
+
+ // CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %c0 to %c64{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]]) -> (tensor<128x128xf32>)
+ // CHECK: %[[FLAT_ID:.+]] = affine.apply #[[$MAP4]](%[[I]], %[[L_IDY]], %[[L_IDX]], %[[W_IDX]], %[[W_IDY]])
+ // CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[FLAT_ID]] into (%c1, %c64) : index, index
+ // CHECK: %[[IDX:.+]] = affine.apply #[[$MAP5]](%[[IDS]]#1)
+ // CHECK: %[[COPY:.+]] = linalg.copy
+ // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[IDX]], %[[IDS]]#0] [2, 128]
+ // CHECK: scf.yield %[[INSERT]]
+
+ // CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
// CHECK: } {mapping = [#iree_gpu.lane_id<1>, #iree_gpu.lane_id<0>]}
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}

// -----

#map = affine_map<(d0) -> (d0 * 4)>
#map1 = affine_map<(d0) -> (d0 * 16)>
module {
func.func @fuse_forall_different_thread_count(%arg0: tensor<128x128xf32>) -> tensor<128x128xf32> {
%0 = tensor.empty() : tensor<128x128xf32>
%2 = scf.forall (%arg5) in (32) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
%4 = affine.apply #map(%arg5)
%extracted_slice = tensor.extract_slice %arg0[%4, 0] [4, 128] [1, 1] : tensor<128x128xf32> to tensor<4x128xf32>
%extracted_slice_0 = tensor.extract_slice %arg7[%4, 0] [4, 128] [1, 1] : tensor<128x128xf32> to tensor<4x128xf32>
%5 = linalg.copy ins(%extracted_slice : tensor<4x128xf32>) outs(%extracted_slice_0 : tensor<4x128xf32>) -> tensor<4x128xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %5 into %arg7[%4, 0] [4, 128] [1, 1] : tensor<4x128xf32> into tensor<128x128xf32>
}
} {mapping = [#gpu.thread<x>]}
%3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
%6 = affine.apply #map1(%arg5)
%7 = affine.apply #map1(%arg6)
%extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
%extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
%8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
return %3 : tensor<128x128xf32>
}
}

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
%producer, %consumer = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.iree.fuse_forall %producer into %consumer : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
transform.yield
}
}

// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>

// CHECK-LABEL: func @fuse_forall_different_thread_count
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>

// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) {{.*}} -> (tensor<128x128xf32>) {
// CHECK: %[[LINEARID:.+]] = affine.apply #[[$MAP1]](%[[IDX]], %[[IDY]])
// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %c32{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%c32) : index
// CHECK: scf.yield
// CHECK: iree_gpu.barrier_region %[[LOOP]]
// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}

// -----

#map = affine_map<(d0) -> (d0 * 4)>
#map1 = affine_map<(d0) -> (d0 * 16)>
module {
func.func @fuse_forall_dynamic_thread_count(%arg0: tensor<128x128xf32>, %x: index, %y: index, %z: index) -> tensor<128x128xf32> {
%0 = tensor.empty() : tensor<128x128xf32>
%2 = scf.forall (%arg5, %arg6, %arg7) in (%x, %y, %z) shared_outs(%arg8 = %0) -> (tensor<128x128xf32>) {
%slice = tensor.extract_slice %arg0[%arg5, %arg6] [4, 128] [1, 1] : tensor<128x128xf32> to tensor<4x128xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %slice into %arg8[%arg7, 0] [4, 128] [1, 1] : tensor<4x128xf32> into tensor<128x128xf32>
}
} {mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>]}
%3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
%6 = affine.apply #map1(%arg5)
%7 = affine.apply #map1(%arg6)
%extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
%extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
%8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
}
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
return %3 : tensor<128x128xf32>
}
}

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
%producer, %consumer = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.iree.fuse_forall %producer into %consumer : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
transform.yield
}
}

// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
// CHECK-DAG: #[[$MAP3:.+]] = affine_map<()[s0, s1, s2] -> (s2 * (s0 * s1))>

// CHECK-LABEL: func @fuse_forall_dynamic_thread_count
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[X:[A-Za-z0-9]+]]: index
// CHECK-SAME: %[[Y:[A-Za-z0-9]+]]: index
// CHECK-SAME: %[[Z:[A-Za-z0-9]+]]: index

// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<128x128xf32>
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) {{.*}} -> (tensor<128x128xf32>) {
// CHECK-DAG: %[[LINEARID:.+]] = affine.apply #[[$MAP1]](%[[IDX]], %[[IDY]])
// CHECK-DAG: %[[PRODCOUNT:.+]] = affine.apply #[[$MAP3]]()[%[[X]], %[[Y]], %[[Z]]]
// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %[[PRODCOUNT]] step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%[[Z]], %[[Y]], %[[X]]) : index
// CHECK: scf.yield
// CHECK: iree_gpu.barrier_region %[[LOOP]]
// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}