From 1aa58257eb65db2a01227fb60b409b219f4790ff Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Fri, 25 Oct 2024 16:43:50 +0100 Subject: [PATCH] [LLVMGPU] Combine parallel and reduction padding in LLVMGPUPadAndVectorDistribute (#18771) Since https://github.com/iree-org/iree/pull/18748 tensor.pad can be fused in with tiling. This patch combines the parallel and reduction padding passes into a single pass that pads at once, and the pads are later fused during tiling. --- .../LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp | 131 +++--------------- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 9 +- .../iree/compiler/Codegen/LLVMGPU/Passes.h | 4 - .../iree/compiler/Codegen/LLVMGPU/Passes.td | 13 -- .../pipeline_vector_distribute_gfx940.mlir | 8 +- .../test/promote_matmul_to_fit_mma.mlir | 129 +++-------------- 6 files changed, 41 insertions(+), 253 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp index dbcc5b1e54b6..24214940f30e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp @@ -27,25 +27,18 @@ class LLVMGPUPromoteMatmulToFitMMAPass final public: using impl::LLVMGPUPromoteMatmulToFitMMAPassBase< LLVMGPUPromoteMatmulToFitMMAPass>::LLVMGPUPromoteMatmulToFitMMAPassBase; - explicit LLVMGPUPromoteMatmulToFitMMAPass( - const LLVMGPUMatmulPadOption &option) { - this->targetDimensions.setValue(option); - } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op, - ArrayRef paddingDims, - ArrayRef padToMultipleOf, bool noFold) const { - assert(paddingDims.size() == padToMultipleOf.size() && - "invalid pad multiples for padding dimensions"); - + ArrayRef padToMultipleOf) const { LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n"); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(op); - SmallVector nofoldFlags(op.getNumDpsInputs(), noFold); + SmallVector paddingDims = + llvm::to_vector(llvm::seq(padToMultipleOf.size())); SmallVector paddingValueAttributes; for (auto &operand : op->getOpOperands()) { @@ -58,7 +51,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final .setPaddingDimensions(paddingDims) .setPaddingValues(paddingValueAttributes) .setPadToMultipleOf(padToMultipleOf) - .setNofoldFlags(nofoldFlags) .setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None); FailureOr result = @@ -72,26 +64,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final MLIRContext *ctx = &getContext(); auto funcOp = getOperation(); - // Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so - // we can kick canonicalization patterns to fold outer tensor.pad ops away. - bool noFold = false; - utils::IteratorType targetIterType = utils::IteratorType::parallel; - switch (targetDimensions) { - case LLVMGPUMatmulPadOption::ParallelDims: - LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n"); - targetIterType = utils::IteratorType::parallel; - noFold = false; - break; - case LLVMGPUMatmulPadOption::ReductionDims: - LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n"); - targetIterType = utils::IteratorType::reduction; - noFold = true; - break; - default: // Unreachable. 
- assert(false); - break; - }; - SmallVector candidates; funcOp->walk([&](linalg::LinalgOp op) { if (linalg::isaContractionOpInterface(op)) { @@ -101,46 +73,27 @@ class LLVMGPUPromoteMatmulToFitMMAPass final IRRewriter rewriter(ctx); for (linalg::LinalgOp op : candidates) { - SmallVector padMultiples(op.getNumLoops(), 1); auto config = dyn_cast_or_null( getLoweringConfig(op)); - if (config) { - switch (targetDimensions) { - case LLVMGPUMatmulPadOption::ParallelDims: - padMultiples = config.getStaticTilingLevelSizes( - static_cast(IREE::GPU::TilingLevel::Workgroup), op); - break; - case LLVMGPUMatmulPadOption::ReductionDims: - padMultiples = config.getStaticTilingLevelSizes( - static_cast(IREE::GPU::TilingLevel::Reduction), op); - break; - default: - assert(false && "Unexpected target dimensions"); - break; - } + if (!config) { + continue; } - // Populate padding dimensions. - SmallVector paddingDimensions; - for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) { - if (iter == targetIterType) { - paddingDimensions.push_back(idx); - } - } + SmallVector wgTiles = config.getStaticTilingLevelSizes( + static_cast(IREE::GPU::TilingLevel::Workgroup), op); + SmallVector redTiles = config.getStaticTilingLevelSizes( + static_cast(IREE::GPU::TilingLevel::Reduction), op); - // Populate tile sizes. We pad to multiples of workgroup/reduction - // tile sizes based on the selected target tiling dimensions. - // This pass is ran after the select target tiling is done to pad - // all dimensions to the select tile sizes. - SmallVector padToMultipleOf; - for (int64_t dim : paddingDimensions) { - if (padMultiples[dim] != 0) { - padToMultipleOf.push_back(padMultiples[dim]); - } + // Populate padding dimensions to maximum of possible tile sizes. + SmallVector padToMultipleOf(op.getNumLoops(), 1); + for (auto [wgTile, redTile, padMultiple] : + llvm::zip_equal(wgTiles, redTiles, padToMultipleOf)) { + padMultiple = std::max({wgTile, redTile, padMultiple}); } + SmallVector paddingDimensions = + llvm::to_vector(llvm::seq(op.getNumLoops())); - padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf, - noFold); + padWithZeroValue(rewriter, op, padToMultipleOf); } { @@ -156,58 +109,8 @@ class LLVMGPUPromoteMatmulToFitMMAPass final return signalPassFailure(); } } - - // XXX(hanchung): This is needed for pad op fusion, which will remove - // outer pad ops. I.e., it mainly wants to remove first pad op in the - // pad->extract_slice->pad chain, while the canonicalization pattern can - // only recognize slice->pad->slice->pad. 
- { - SmallVector padOps; - funcOp.walk([&](tensor::PadOp op) { padOps.push_back(op); }); - for (auto op : padOps) { - auto srcExtractSliceOp = - op.getSource().getDefiningOp(); - if (!srcExtractSliceOp) { - continue; - } - auto producerPadOp = - srcExtractSliceOp.getSource().getDefiningOp(); - if (!producerPadOp) { - continue; - } - auto src = producerPadOp.getSource() - .getDefiningOp(); - if (!src) { - continue; - } - - rewriter.setInsertionPointAfter(src); - SmallVector sizes = - tensor::getMixedSizes(rewriter, op.getLoc(), src); - SmallVector offsets(sizes.size(), - rewriter.getIndexAttr(0)); - SmallVector strides(sizes.size(), - rewriter.getIndexAttr(1)); - auto extractSliceOp = rewriter.create( - op.getLoc(), src.getResult(), offsets, sizes, strides); - rewriter.startOpModification(op); - producerPadOp.getSourceMutable().assign(extractSliceOp.getResult()); - rewriter.finalizeOpModification(op); - } - - RewritePatternSet patterns(ctx); - tensor::PadOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return signalPassFailure(); - } - } } }; } // namespace -std::unique_ptr> -createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option) { - return std::make_unique(option); -} - } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 76b1af3204be..51fcc6b3996c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -858,25 +858,20 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createCSEPass()); if (usePadToModelSharedMemcpy) { - LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ParallelDims; - funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option)); + funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass()); } // Tile to reduction loops. { GPUApplyTilingLevelPassOptions options; options.tilingLevel = IREE::GPU::TilingLevel::Reduction; + options.allowZeroSlices = true; funcPassManager.addPass(createGPUApplyTilingLevelPass(options)); funcPassManager.addPass(affine::createLoopCoalescingPass()); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); } - if (usePadToModelSharedMemcpy) { - LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ReductionDims; - funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option)); - } - funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass()); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h index c1181776e8f7..d9325647a50d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h @@ -103,10 +103,6 @@ verifyGPUMatmulPipeline(Operation *op, // Wrappers that not use tablegen options. 
//------------------------------------------------------------------------------ -enum class LLVMGPUMatmulPadOption { ParallelDims, ReductionDims }; -std::unique_ptr> -createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option); - enum class GPUTensorCoreType { WMMA = 0, MMA_SYNC = 1, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td index ef51a6a9a883..815a82f28d8d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td @@ -105,19 +105,6 @@ def LLVMGPUPrefetchSharedMemoryPass : def LLVMGPUPromoteMatmulToFitMMAPass : InterfacePass<"iree-llvmgpu-promote-matmul-to-fit-mma", "mlir::FunctionOpInterface"> { let summary = "Pass to promote contraction ops to fit mma shapes"; - let options = [ - Option<"targetDimensions", "target-dimensions", "mlir::iree_compiler::LLVMGPUMatmulPadOption", - /*default=*/"mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims", - "Select the strategy to control how multi_reduction is lowered.", - [{::llvm::cl::values( - clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims, - "parallel", - "Pad all the parallel dims for contraction ops."), - clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ReductionDims, - "reduction", - "Pad all the reduction dims for contraction ops.") - )}]> - ]; } def LLVMGPUSelectLoweringStrategyPass : diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir index 610e114c81e3..d21faf8867b1 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir @@ -511,7 +511,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // CHECK: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]} // CHECK: vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]] // CHECK: vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]] -// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1265 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>) +// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1280 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>) // CHECK-DAG: %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]] // CHECK-DAG: %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]] // CHECK: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]] @@ -581,9 +581,11 @@ hal.executable public @pad_batch_matmul { // CHECK-SAME: memref<196x16x24xf32 // CHECK-SAME: vector<1x1x1xf32> // RHS +// The dynamic dimension should be removed after: +// https://github.com/llvm/llvm-project/pull/112236 // CHECK: vector.transfer_read -// CHECK-SAME: in_bounds = [true, true, false] -// CHECK-SAME: memref<1x8x24xf32 +// CHECK-SAME: in_bounds = [true, false, false] +// CHECK-SAME: memref<1x?x24xf32 // CHECK-SAME: vector<1x1x2xf32> // CHECK: scf.yield // OUTPUT diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir index bda4836eaec3..21bc2fc3cac3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir @@ -1,5 +1,4 @@ -// 
RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=parallel}))" %s | FileCheck %s --check-prefixes=ALL,PARALLEL -// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=reduction}))" %s | FileCheck %s --check-prefixes=ALL,REDUCTION +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma))" %s | FileCheck %s #pipeline_layout = #hal.pipeline.layout, @@ -34,114 +33,20 @@ func.func @batch_matmul_f16() { flow.dispatch.tensor.store %11, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor> return } -// ALL-LABEL: func.func @batch_matmul_f16 -// ALL: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> -// ALL-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]] -// ALL-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]] -// PARALLEL: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]] -// PARALLEL: } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16> -// PARALLEL: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]] -// PARALLEL: } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16> -// PARALLEL: %[[FILL:.+]] = linalg.fill -// PARALLEL: %[[GEMM:.+]] = linalg.batch_matmul -// PARALLEL-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] -// PARALLEL-SAME: outs(%[[FILL]] +// CHECK-LABEL: func.func @batch_matmul_f16 +// CHECK: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> +// CHECK: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> +// CHECK: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> +// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]] +// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]] +// CHECK: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]] +// CHECK: } : tensor<1x?x1281xf16> to tensor<1x64x1296xf16> +// CHECK: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]] +// CHECK: } : tensor<1x1281x?xf16> to tensor<1x1296x128xf16> +// CHECK: %[[FILL:.+]] = linalg.fill +// CHECK: %[[GEMM:.+]] = linalg.batch_matmul +// CHECK-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] +// CHECK-SAME: outs(%[[FILL]] -// The reduction dim is not tiled in the test case, so it pads it to the -// matmul intrinsic k. 
-// REDUCTION-DAG: %[[FILL_DEST:.+]] = flow.dispatch.tensor.load %[[OUT_HANDLE]] -// REDUCTION: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[FILL_DEST]] -// REDUCTION: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]] -// REDUCTION: } : tensor<1x?x1281xf16> to tensor<1x?x1296xf16> -// REDUCTION: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]] -// REDUCTION: } : tensor<1x1281x?xf16> to tensor<1x1296x?xf16> -// REDUCTION: %[[GEMM:.+]] = linalg.batch_matmul -// REDUCTION-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] -// REDUCTION-SAME: outs(%[[FILL]] - -// ALL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[GEMM]] -// ALL: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]] - -// ----- - -#pipeline_layout = #hal.pipeline.layout, - #hal.pipeline.binding, - #hal.pipeline.binding -]> -#map = affine_map<()[s0] -> (s0 * 64)> -#map1 = affine_map<()[s0] -> (s0 * 128)> -#map2 = affine_map<()[s0] -> (s0 * -64 + 968, 64)> -#map3 = affine_map<()[s0] -> (s0 * -128 + 1281, 128)> -#map4 = affine_map<()[s0] -> (-s0 + 64)> -#map5 = affine_map<()[s0] -> (-s0 + 128)> -#map6 = affine_map<(d0) -> (-d0 + 1281, 64)> -func.func @batch_matmul_pad_reduction_after_tiling() { - %c64 = arith.constant 64 : index - %c1281 = arith.constant 1281 : index - %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.000000e+00 : f16 - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %workgroup_id_z = hal.interface.workgroup.id[2] : index - %workgroup_id_y = hal.interface.workgroup.id[1] : index - %3 = affine.apply #map()[%workgroup_id_y] - %workgroup_id_x = hal.interface.workgroup.id[0] : index - %4 = affine.apply #map1()[%workgroup_id_x] - %5 = affine.min #map2()[%workgroup_id_y] - %6 = affine.min #map3()[%workgroup_id_x] - %7 = flow.dispatch.tensor.load %0, offsets = [%workgroup_id_z, %3, 0], sizes = [1, %5, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x?x1281xf16> - %dim = tensor.dim %7, %c1 : tensor<1x?x1281xf16> - %8 = flow.dispatch.tensor.load %1, offsets = [%workgroup_id_z, 0, %4], sizes = [1, 1281, %6], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x1281x?xf16> - %dim_0 = tensor.dim %8, %c2 : tensor<1x1281x?xf16> - %9 = affine.apply #map4()[%5] - %padded = tensor.pad %7 low[0, 0, 0] high[0, %9, 0] { - ^bb0(%arg0: index, %arg1: index, %arg2: index): - tensor.yield %cst : f16 - } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16> - %10 = affine.apply #map5()[%6] - %padded_2 = tensor.pad %8 low[0, 0, 0] high[0, 0, %10] { - ^bb0(%arg0: index, %arg1: index, %arg2: index): - tensor.yield %cst : f16 - } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16> - %11 = tensor.empty() : tensor<1x64x128xf16> - %12 = linalg.fill ins(%cst : f16) outs(%11 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16> - %13 = scf.for %arg0 = %c0 to %c1281 step %c64 iter_args(%arg1 = %12) -> (tensor<1x64x128xf16>) { - %14 = affine.min #map6(%arg0) - %extracted_slice_4 = tensor.extract_slice %padded[0, 0, %arg0] [1, 64, %14] [1, 1, 1] : tensor<1x64x1281xf16> to tensor<1x64x?xf16> - %extracted_slice_5 = tensor.extract_slice %padded_2[0, %arg0, 0] [1, %14, 128] [1, 1, 1] : tensor<1x1281x128xf16> to tensor<1x?x128xf16> - %15 = 
linalg.batch_matmul ins(%extracted_slice_4, %extracted_slice_5 : tensor<1x64x?xf16>, tensor<1x?x128xf16>) outs(%arg1 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16> - scf.yield %15 : tensor<1x64x128xf16> - } - %extracted_slice_3 = tensor.extract_slice %13[0, 0, 0] [1, %5, %6] [1, 1, 1] : tensor<1x64x128xf16> to tensor<1x?x?xf16> - flow.dispatch.tensor.store %extracted_slice_3, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor> - return -} -// The padding on parallel dims is a nop because they are already padded. Skip -// the check for the testcase. -// ALL-LABEL: func.func @batch_matmul_pad_reduction_after_tiling -// ALL: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> -// ALL-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]] -// ALL-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]] -// REDUCTION: %[[INIT:.+]] = tensor.empty() : tensor<1x64x128xf16> -// REDUCTION: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[INIT]] -// REDUCTION: %[[RES:.+]] = scf.for {{.+}} iter_args(%[[ITER:.+]] = %[[FILL]]) -// REDUCTION: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]] -// REDUCTION: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS_SLICE]] -// REDUCTION: } : tensor<1x?x?xf16> to tensor<1x64x64xf16> -// REDUCTION: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]] -// REDUCTION: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS_SLICE]] -// REDUCTION: } : tensor<1x?x?xf16> to tensor<1x64x128xf16> -// REDUCTION: %[[GEMM:.+]] = linalg.batch_matmul -// REDUCTION-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] -// REDUCTION-SAME: outs(%[[ITER]] -// REDUCTION: scf.yield %[[GEMM]] -// REDUCTION: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[RES]] -// REDUCTION: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]] +// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[GEMM]] +// CHECK: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]]
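
For reference, below is a minimal sketch of the IR the combined pass now produces, distilled from the updated batch_matmul_f16 test above. It is illustrative only (the function name and the %m_pad/%n_pad operands are stand-ins, not part of the patch): the parallel dims are padded to the workgroup tile sizes (M to 64, N to 128) and the reduction dim to a multiple of the reduction tile (K from 1281 to 1296) by a single pass, and the resulting tensor.pad ops are fused later when tiling to reduction loops.

// Minimal sketch (not part of the patch): shapes are taken from the
// batch_matmul_f16 test above; %m_pad and %n_pad stand in for the dynamic
// padding amounts the pass derives from the workgroup tile sizes.
func.func @promote_matmul_to_fit_mma_sketch(
    %lhs: tensor<1x?x1281xf16>, %rhs: tensor<1x1281x?xf16>,
    %m_pad: index, %n_pad: index) -> tensor<1x64x128xf16> {
  %cst = arith.constant 0.000000e+00 : f16
  // Parallel (M) and reduction (K) dims of the LHS are padded in one shot.
  %padded_lhs = tensor.pad %lhs low[0, 0, 0] high[0, %m_pad, 15] {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %cst : f16
  } : tensor<1x?x1281xf16> to tensor<1x64x1296xf16>
  // Likewise for the reduction (K) and parallel (N) dims of the RHS.
  %padded_rhs = tensor.pad %rhs low[0, 0, 0] high[0, 15, %n_pad] {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %cst : f16
  } : tensor<1x1281x?xf16> to tensor<1x1296x128xf16>
  %empty = tensor.empty() : tensor<1x64x128xf16>
  %fill = linalg.fill ins(%cst : f16) outs(%empty : tensor<1x64x128xf16>) -> tensor<1x64x128xf16>
  %gemm = linalg.batch_matmul
      ins(%padded_lhs, %padded_rhs : tensor<1x64x1296xf16>, tensor<1x1296x128xf16>)
      outs(%fill : tensor<1x64x128xf16>) -> tensor<1x64x128xf16>
  return %gemm : tensor<1x64x128xf16>
}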