From cf2c68e2dc563d1a97a7b695cefd89f54d070396 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Thu, 10 Aug 2023 20:06:58 +0000 Subject: [PATCH 1/7] Add LLVMCPUDecomposeBatchMmt4D pass --- .../iree/compiler/Codegen/Common/BUILD.bazel | 1 + .../compiler/Codegen/Common/CMakeLists.txt | 1 + .../Codegen/Common/DecomposeBatchMmt4DOps.cpp | 159 ++++++++++++++++++ .../src/iree/compiler/Codegen/Common/Passes.h | 3 + .../iree/compiler/Codegen/Common/Passes.td | 7 + 5 files changed, 171 insertions(+) create mode 100644 compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 75b0e01c2975..d87c096a066e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -153,6 +153,7 @@ iree_compiler_cc_library( "ConvertBf16ToUInt16Buffers.cpp", "ConvertToDestinationPassingStylePass.cpp", "DecomposeAffineOpsPass.cpp", + "DecomposeBatchMmt4DOps.cpp", "DecomposeConvolutionToLowerDimOps.cpp", "DecomposeLinalgGeneric.cpp", "DecomposePackUnPackOps.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 6df643e9288b..0dd1a068d850 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -128,6 +128,7 @@ iree_cc_library( "ConvertBf16ToUInt16Buffers.cpp" "ConvertToDestinationPassingStylePass.cpp" "DecomposeAffineOpsPass.cpp" + "DecomposeBatchMmt4DOps.cpp" "DecomposeConvolutionToLowerDimOps.cpp" "DecomposeLinalgGeneric.cpp" "DecomposePackUnPackOps.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp new file mode 100644 index 000000000000..071ca202c269 --- /dev/null +++ 
b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp @@ -0,0 +1,159 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/PassDetail.h" +#include "iree/compiler/Codegen/Common/Passes.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-codegen-decompose-batch-mmt4d-ops" + +namespace mlir { +namespace iree_compiler { + +namespace { + +/// Pattern to convert linalg.batch_mmt4d with batch dim = 1 into mmt4d. +struct ConvertBatchMmt4DtoMmt4DPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::BatchMmt4DOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = op.getDpsInputOperand(0)->get(); + auto rhs = op.getDpsInputOperand(1)->get(); + auto out = op.getDpsInitOperand(0)->get(); + + auto outType = out.getType().cast(); + // Skip if the batch dim isn't tiled to 1. + if (outType.getShape()[0] != 1) { + return failure(); + } + auto reducedOutType = RankedTensorType::Builder(outType).dropDim(0); + Value reducedOut; + Value initTensor; + // If the init operand is a linalg.fill op, create a new linalg.fill op with + // the batch dim dropped, so it is easier to identify fill + mmt4d cases. 
+ if (auto oldFillOp = out.getDefiningOp()) { + initTensor = oldFillOp.output(); + auto newInit = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, initTensor, reducedOutType); + reducedOut = + rewriter + .create(loc, ValueRange{oldFillOp.value()}, + ValueRange{newInit}) + .result(); + } else { + reducedOut = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, out, reducedOutType); + initTensor = out; + } + + auto lhsType = lhs.getType().cast(); + auto reducedLhsType = RankedTensorType::Builder(lhsType).dropDim(0); + auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, lhs, reducedLhsType); + + auto rhsType = rhs.getType().cast(); + auto reducedRhsType = RankedTensorType::Builder(rhsType).dropDim(0); + auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, rhs, reducedRhsType); + + auto mmt4DOp = rewriter.create( + loc, reducedOut.getType(), ValueRange{reducedLhs, reducedRhs}, + ValueRange{reducedOut}); + + auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp( + rewriter, loc, mmt4DOp.getResult(0), initTensor); + rewriter.replaceOp(op, insertSliceOp); + return success(); + } +}; + +struct DecomposeBatchMmt4DOpsPass + : public DecomposeBatchMmt4DOpsBase { + void runOnOperation() override; +}; + +} // namespace + +void DecomposeBatchMmt4DOpsPass::runOnOperation() { + MLIRContext *ctx = &getContext(); + auto funcOp = getOperation(); + + // First tile the batch dim of linalg.batch_mmt4d into 1. 
+ { + auto tileAndFuseOptions = scf::SCFTileAndFuseOptions().setTilingOptions( + scf::SCFTilingOptions().setTileSizes({1})); + IRRewriter rewriter(ctx); + funcOp->walk([&](linalg::BatchMmt4DOp op) { + FailureOr tileAndFuseResult = + scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( + rewriter, cast(op.getOperation()), + tileAndFuseOptions); + if (failed(tileAndFuseResult)) { + return signalPassFailure(); + } + + SmallVector replacements; + replacements.resize(op->getNumResults()); + for (const auto &[index, result] : llvm::enumerate(op->getResults())) { + replacements[index] = tileAndFuseResult->replacements[result]; + } + op->replaceAllUsesWith(replacements); + }); + + LLVM_DEBUG({ + llvm::dbgs() << "--- After tiling batch dim to 1 ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + + // Canonicalize tiled ops. + { + RewritePatternSet patterns(ctx); + linalg::populateLinalgTilingCanonicalizationPatterns(patterns); + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); + ctx->getOrLoadDialect()->getCanonicalizationPatterns( + patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + } + + // Convert linalg.batch_mmt4d with batch dim = 1 into linalg.mmt4d. + { + RewritePatternSet patterns(ctx); + patterns.add(ctx); + // Canonicalize extract and insert slice ops created during the conversion. 
+ tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx); + tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "--- After converting batch_mmt4d into mmt4d ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } +} + +std::unique_ptr> +createDecomposeBatchMmt4DOpsPass() { + return std::make_unique(); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h index bd23f3380589..0f48f4ce70dd 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h @@ -79,6 +79,9 @@ createConvertToDestinationPassingStylePass( // hoisted in different loops. std::unique_ptr createDecomposeAffineOpsPass(); +// Decomposes batch mmt4d op into mmt4d by tiling the batch dim to 1. +std::unique_ptr> createDecomposeBatchMmt4DOpsPass(); + // Decomposes high-D convolution ops into low-D ones. 
std::unique_ptr createDecomposeConvolutionToLowerDimOpsPass(); diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index dbed920ce661..27024895034a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -119,6 +119,13 @@ def DecomposeAffineOps: Pass<"decompose-affine-ops"> { ]; } +def DecomposeBatchMmt4DOps + : Pass<"iree-codegen-decompose-batch-mmt4d-ops", "func::FuncOp"> { + let summary = "Decompose batch_mmt4d ops into mmt4d ops"; + let constructor = + "mlir::iree_compiler::createDecomposeBatchMmt4DOpsPass()"; +} + def DecomposeConvolutionToLowerDimOps : Pass<"iree-codegen-decompose-convolution-to-lower-dim-ops", ""> { let summary = "Decomposes linalg convolution ops to lower dim ops"; From f460d04ea6ea4f511db5c0ab3d40d53056956d78 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Thu, 10 Aug 2023 22:59:30 +0000 Subject: [PATCH 2/7] Add tests --- .../Codegen/Common/DecomposeBatchMmt4DOps.cpp | 6 ++ .../compiler/Codegen/Common/test/BUILD.bazel | 1 + .../Codegen/Common/test/CMakeLists.txt | 1 + .../test/decompose_batch_mmt4d_ops.mlir | 69 +++++++++++++++++++ 4 files changed, 77 insertions(+) create mode 100644 compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp index 071ca202c269..d3f366b12006 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp @@ -80,6 +80,12 @@ struct ConvertBatchMmt4DtoMmt4DPattern struct DecomposeBatchMmt4DOpsPass : public DecomposeBatchMmt4DOpsBase { + void getDependentDialects(DialectRegistry &registry) const override { + registry + .insert(); + } + void runOnOperation() override; }; diff --git
a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index ecd46603c00f..99b76c2f18f6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -30,6 +30,7 @@ iree_lit_test_suite( "convolutions.mlir", "erase_dead_alloc_and_stores.mlir", "decompose_affine_ops.mlir", + "decompose_batch_mmt4d_ops.mlir", "decompose_linalg_generic.mlir", "decompose_pack_unpack_ops.mlir", "eliminate_empty_tensors.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index e1159169b363..64aab43ee4aa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -25,6 +25,7 @@ iree_lit_test_suite( "convert_to_destination_passing_style.mlir" "convolutions.mlir" "decompose_affine_ops.mlir" + "decompose_batch_mmt4d_ops.mlir" "decompose_linalg_generic.mlir" "decompose_pack_unpack_ops.mlir" "eliminate_empty_tensors.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir new file mode 100644 index 000000000000..ec6295f99d5a --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir @@ -0,0 +1,69 @@ +// RUN: iree-opt --iree-codegen-decompose-batch-mmt4d-ops --split-input-file %s | FileCheck %s + +func.func @batch_mmt4d_with_fill(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> + %1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<128x10x32x8x1xf32>, tensor<128x80x32x4x1xf32>) outs(%0 : 
tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> + return %1 : tensor<128x10x80x8x4xf32> +} + +// CHECK: func.func @batch_mmt4d_with_fill +// CHECK-SAME: %[[LHS:.+]]: tensor<128x10x32x8x1xf32>, +// CHECK-SAME: %[[RHS:.+]]: tensor<128x80x32x4x1xf32>, +// CHECK-SAME: %[[OUT:.+]]: tensor<128x10x80x8x4xf32> +// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[OUT]]) +// CHECK: %[[EXT_OUT:.+]] = tensor.extract_slice %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<128x10x80x8x4xf32> to tensor<10x80x8x4xf32> +// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> +// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][%[[I]], 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<128x10x32x8x1xf32> to tensor<10x32x8x1xf32> +// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][%[[I]], 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<128x80x32x4x1xf32> to tensor<80x32x4x1xf32> +// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[FILL]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> +// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<128x10x80x8x4xf32> +// CHECK: scf.yield %[[INS]] : tensor<128x10x80x8x4xf32> +// CHECK: } +// CHECK: return %[[RES]] : tensor<128x10x80x8x4xf32> + +// ----- + +func.func @batch_mmt4d_with_no_fill(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> { + %1 = linalg.batch_mmt4d 
ins(%arg0, %arg1 : tensor<128x10x32x8x1xf32>, tensor<128x80x32x4x1xf32>) outs(%arg2 : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> + return %1 : tensor<128x10x80x8x4xf32> +} + +// CHECK: func.func @batch_mmt4d_with_no_fill +// CHECK-SAME: %[[LHS:.+]]: tensor<128x10x32x8x1xf32>, +// CHECK-SAME: %[[RHS:.+]]: tensor<128x80x32x4x1xf32>, +// CHECK-SAME: %[[OUT:.+]]: tensor<128x10x80x8x4xf32> +// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[OUT]]) +// CHECK: %[[EXT_OUT:.+]] = tensor.extract_slice %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<128x10x80x8x4xf32> to tensor<10x80x8x4xf32> +// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][%[[I]], 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<128x10x32x8x1xf32> to tensor<10x32x8x1xf32> +// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][%[[I]], 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<128x80x32x4x1xf32> to tensor<80x32x4x1xf32> +// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> +// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<128x10x80x8x4xf32> +// CHECK: scf.yield %[[INS]] : tensor<128x10x80x8x4xf32> +// CHECK: } +// CHECK: return %[[RES]] : tensor<128x10x80x8x4xf32> + +// ----- + +func.func @batch_mmt4d_with_unit_batch(%arg0: tensor<1x10x32x8x1xf32>, %arg1: tensor<1x80x32x4x1xf32>, %arg2: tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> { + %1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<1x10x32x8x1xf32>, tensor<1x80x32x4x1xf32>) outs(%arg2 : 
tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> + return %1 : tensor<1x10x80x8x4xf32> +} + +// CHECK: func.func @batch_mmt4d_with_unit_batch +// CHECK-SAME: %[[LHS:.+]]: tensor<1x10x32x8x1xf32>, +// CHECK-SAME: %[[RHS:.+]]: tensor<1x80x32x4x1xf32>, +// CHECK-SAME: %[[OUT:.+]]: tensor<1x10x80x8x4xf32> +// CHECK-DAG: %[[EXT_OUT:.+]] = tensor.extract_slice %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<1x10x80x8x4xf32> to tensor<10x80x8x4xf32> +// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<1x10x32x8x1xf32> to tensor<10x32x8x1xf32> +// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<1x80x32x4x1xf32> to tensor<80x32x4x1xf32> +// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> +// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<1x10x80x8x4xf32> +// CHECK: return %[[INS]] : tensor<1x10x80x8x4xf32> From 48c9570c8c6e116643894ab5cad7e59e83776ada Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Fri, 11 Aug 2023 16:53:33 +0000 Subject: [PATCH 3/7] Fix stack out of scope --- .../iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp index d3f366b12006..3466ad89541d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp @@ -37,7 +37,8 @@ struct ConvertBatchMmt4DtoMmt4DPattern if (outType.getShape()[0] != 1) { return failure(); } - auto reducedOutType = 
RankedTensorType::Builder(outType).dropDim(0); + RankedTensorType reducedOutType = + RankedTensorType::Builder(outType).dropDim(0); Value reducedOut; Value initTensor; // If the init operand is a linalg.fill op, create a new linalg.fill op with From eda8687167b6bec8b2ad95e132a77a1e3bc38d36 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Fri, 11 Aug 2023 17:13:00 +0000 Subject: [PATCH 4/7] Add dynamic test --- .../Codegen/Common/DecomposeBatchMmt4DOps.cpp | 4 +-- .../test/decompose_batch_mmt4d_ops.mlir | 28 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp index 3466ad89541d..1a23d02f32b7 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -129,9 +130,8 @@ void DecomposeBatchMmt4DOpsPass::runOnOperation() { { RewritePatternSet patterns(ctx); linalg::populateLinalgTilingCanonicalizationPatterns(patterns); + scf::populateSCFForLoopCanonicalizationPatterns(patterns); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); - ctx->getOrLoadDialect()->getCanonicalizationPatterns( - patterns); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir index ec6295f99d5a..a2513e8790ba 100644 ---
a/compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir @@ -67,3 +67,31 @@ func.func @batch_mmt4d_with_unit_batch(%arg0: tensor<1x10x32x8x1xf32>, %arg1: te // CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> // CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<1x10x80x8x4xf32> // CHECK: return %[[INS]] : tensor<1x10x80x8x4xf32> + +// ----- + +func.func @batch_mmt4d_with_dynamic_batch(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor) -> tensor + %1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor, tensor) outs(%0 : tensor) -> tensor + return %1 : tensor +} + +// CHECK: func.func @batch_mmt4d_with_dynamic_batch +// CHECK-SAME: %[[LHS:.+]]: tensor, +// CHECK-SAME: %[[RHS:.+]]: tensor, +// CHECK-SAME: %[[OUT:.+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[LHS]], %[[C0]] : tensor +// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[DIM]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[OUT]]) +// CHECK: %[[EXT_OUT:.+]] = tensor.extract_slice %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor to tensor<10x80x8x4xf32> +// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> +// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][%[[I]], 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor to tensor<10x32x8x1xf32> +// CHECK-DAG: %[[EXT_RHS:.+]] = 
tensor.extract_slice %[[RHS]][%[[I]], 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor to tensor<80x32x4x1xf32> +// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[FILL]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> +// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor +// CHECK: scf.yield %[[INS]] : tensor +// CHECK: } +// CHECK: return %[[RES]] : tensor From 95102060045bb745becd509c833dc4cbecd28e0a Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Fri, 11 Aug 2023 17:58:11 +0000 Subject: [PATCH 5/7] Fix ASAN --- .../iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp index 1a23d02f32b7..0863ede83dc1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp @@ -60,12 +60,14 @@ struct ConvertBatchMmt4DtoMmt4DPattern } auto lhsType = lhs.getType().cast(); - auto reducedLhsType = RankedTensorType::Builder(lhsType).dropDim(0); + RankedTensorType reducedLhsType = + RankedTensorType::Builder(lhsType).dropDim(0); auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp( rewriter, loc, lhs, reducedLhsType); auto rhsType = rhs.getType().cast(); - auto reducedRhsType = RankedTensorType::Builder(rhsType).dropDim(0); + RankedTensorType reducedRhsType = + RankedTensorType::Builder(rhsType).dropDim(0); auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp( rewriter, loc, rhs, reducedRhsType); From cab05b726d657aa972ed425e174956ab85a5eb23 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Fri, 11 Aug 2023 18:53:26 +0000 Subject: [PATCH 6/7] Fix GCC build --- 
.../iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp index 0863ede83dc1..0f282f661005 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp @@ -101,8 +101,9 @@ void DecomposeBatchMmt4DOpsPass::runOnOperation() { // First tile the batch dim of linalg.batch_mmt4d into 1. { + SmallVector tileSizes({1}); auto tileAndFuseOptions = scf::SCFTileAndFuseOptions().setTilingOptions( - scf::SCFTilingOptions().setTileSizes({1})); + scf::SCFTilingOptions().setTileSizes(tileSizes)); IRRewriter rewriter(ctx); funcOp->walk([&](linalg::BatchMmt4DOp op) { FailureOr tileAndFuseResult = From affc000aa3ea99b903579379c7d394330dbc08de Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Fri, 11 Aug 2023 20:44:59 +0000 Subject: [PATCH 7/7] Remove tiling in pass --- .../Codegen/Common/DecomposeBatchMmt4DOps.cpp | 81 +++-------------- .../test/decompose_batch_mmt4d_ops.mlir | 90 ++++--------------- 2 files changed, 29 insertions(+), 142 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp index 0f282f661005..7a0b99725363 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp @@ -6,16 +6,11 @@ #include "iree/compiler/Codegen/Common/PassDetail.h" #include "iree/compiler/Codegen/Common/Passes.h" -#include "llvm/Support/Debug.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/MemRef/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Transforms/Patterns.h" -#include
"mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#define DEBUG_TYPE "iree-codegen-decompose-batch-mmt4d-ops" - namespace mlir { namespace iree_compiler { @@ -34,9 +29,9 @@ struct ConvertBatchMmt4DtoMmt4DPattern auto out = op.getDpsInitOperand(0)->get(); auto outType = out.getType().cast(); - // Skip if the batch dim isn't tiled to 1. + // Batch dim needs to be tiled to 1 first. if (outType.getShape()[0] != 1) { - return failure(); + return rewriter.notifyMatchFailure(op, "batch dim needs to be 1"); } RankedTensorType reducedOutType = RankedTensorType::Builder(outType).dropDim(0); @@ -85,9 +80,8 @@ struct ConvertBatchMmt4DtoMmt4DPattern struct DecomposeBatchMmt4DOpsPass : public DecomposeBatchMmt4DOpsBase { void getDependentDialects(DialectRegistry &registry) const override { - registry - .insert(); + registry.insert(); } void runOnOperation() override; @@ -99,64 +93,15 @@ void DecomposeBatchMmt4DOpsPass::runOnOperation() { MLIRContext *ctx = &getContext(); auto funcOp = getOperation(); - // First tile the batch dim of linalg.batch_mmt4d into 1.
- { - SmallVector tileSizes({1}); - auto tileAndFuseOptions = scf::SCFTileAndFuseOptions().setTilingOptions( - scf::SCFTilingOptions().setTileSizes(tileSizes)); - IRRewriter rewriter(ctx); - funcOp->walk([&](linalg::BatchMmt4DOp op) { - FailureOr tileAndFuseResult = - scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( - rewriter, cast(op.getOperation()), - tileAndFuseOptions); - if (failed(tileAndFuseResult)) { - return signalPassFailure(); - } - - SmallVector replacements; - replacements.resize(op->getNumResults()); - for (const auto &[index, result] : llvm::enumerate(op->getResults())) { - replacements[index] = tileAndFuseResult->replacements[result]; - } - op->replaceAllUsesWith(replacements); - }); - - LLVM_DEBUG({ - llvm::dbgs() << "--- After tiling batch dim to 1 ---\n"; - funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n\n"; - }); - } - - // Canonicalize tiled ops. - { - RewritePatternSet patterns(ctx); - linalg::populateLinalgTilingCanonicalizationPatterns(patterns); - scf::populateSCFForLoopCanonicalizationPatterns(patterns); - memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return signalPassFailure(); - } - } - // Convert linalg.batch_mmt4d with batch dim = 1 into linalg.mmt4d. - { - RewritePatternSet patterns(ctx); - patterns.add(ctx); - // Canonicalize extract and insert slice ops created during the conversion. 
- tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); - tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx); - tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return signalPassFailure(); - } - - LLVM_DEBUG({ - llvm::dbgs() << "--- After converting batch_mmt4d into mmt4d ---\n"; - funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n\n"; - }); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + // Canonicalize extract and insert slice ops created during the conversion. + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx); + tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); } } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir index a2513e8790ba..30795d6573da 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_batch_mmt4d_ops.mlir @@ -1,63 +1,33 @@ // RUN: iree-opt --iree-codegen-decompose-batch-mmt4d-ops --split-input-file %s | FileCheck %s -func.func @batch_mmt4d_with_fill(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> { +func.func @batch_mmt4d_with_fill(%arg0: tensor<1x10x32x8x1xf32>, %arg1: tensor<1x80x32x4x1xf32>, %arg2: tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> { %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> - %1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<128x10x32x8x1xf32>, 
tensor<128x80x32x4x1xf32>) outs(%0 : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> - return %1 : tensor<128x10x80x8x4xf32> + %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> + %1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<1x10x32x8x1xf32>, tensor<1x80x32x4x1xf32>) outs(%0 : tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> + return %1 : tensor<1x10x80x8x4xf32> } // CHECK: func.func @batch_mmt4d_with_fill -// CHECK-SAME: %[[LHS:.+]]: tensor<128x10x32x8x1xf32>, -// CHECK-SAME: %[[RHS:.+]]: tensor<128x80x32x4x1xf32>, -// CHECK-SAME: %[[OUT:.+]]: tensor<128x10x80x8x4xf32> -// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-SAME: %[[LHS:.+]]: tensor<1x10x32x8x1xf32>, +// CHECK-SAME: %[[RHS:.+]]: tensor<1x80x32x4x1xf32>, +// CHECK-SAME: %[[OUT:.+]]: tensor<1x10x80x8x4xf32> // CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[OUT]]) -// CHECK: %[[EXT_OUT:.+]] = tensor.extract_slice %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<128x10x80x8x4xf32> to tensor<10x80x8x4xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> -// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][%[[I]], 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<128x10x32x8x1xf32> to tensor<10x32x8x1xf32> -// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][%[[I]], 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<128x80x32x4x1xf32> to tensor<80x32x4x1xf32> -// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[FILL]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> -// CHECK: %[[INS:.+]] = 
tensor.insert_slice %[[MMT4D]] into %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<128x10x80x8x4xf32> -// CHECK: scf.yield %[[INS]] : tensor<128x10x80x8x4xf32> -// CHECK: } -// CHECK: return %[[RES]] : tensor<128x10x80x8x4xf32> - -// ----- - -func.func @batch_mmt4d_with_no_fill(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> { - %1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<128x10x32x8x1xf32>, tensor<128x80x32x4x1xf32>) outs(%arg2 : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> - return %1 : tensor<128x10x80x8x4xf32> -} - -// CHECK: func.func @batch_mmt4d_with_no_fill -// CHECK-SAME: %[[LHS:.+]]: tensor<128x10x32x8x1xf32>, -// CHECK-SAME: %[[RHS:.+]]: tensor<128x80x32x4x1xf32>, -// CHECK-SAME: %[[OUT:.+]]: tensor<128x10x80x8x4xf32> -// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[OUT]]) -// CHECK: %[[EXT_OUT:.+]] = tensor.extract_slice %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<128x10x80x8x4xf32> to tensor<10x80x8x4xf32> -// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][%[[I]], 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<128x10x32x8x1xf32> to tensor<10x32x8x1xf32> -// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][%[[I]], 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<128x80x32x4x1xf32> to tensor<80x32x4x1xf32> -// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> -// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 
1, 1] : tensor<10x80x8x4xf32> into tensor<128x10x80x8x4xf32> -// CHECK: scf.yield %[[INS]] : tensor<128x10x80x8x4xf32> -// CHECK: } -// CHECK: return %[[RES]] : tensor<128x10x80x8x4xf32> +// CHECK-DAG: %[[EXT_OUT:.+]] = tensor.extract_slice %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<1x10x80x8x4xf32> to tensor<10x80x8x4xf32> +// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> +// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<1x10x32x8x1xf32> to tensor<10x32x8x1xf32> +// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<1x80x32x4x1xf32> to tensor<80x32x4x1xf32> +// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[FILL]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> +// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<1x10x80x8x4xf32> +// CHECK: return %[[INS]] : tensor<1x10x80x8x4xf32> // ----- -func.func @batch_mmt4d_with_unit_batch(%arg0: tensor<1x10x32x8x1xf32>, %arg1: tensor<1x80x32x4x1xf32>, %arg2: tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> { +func.func @batch_mmt4d_with_no_fill(%arg0: tensor<1x10x32x8x1xf32>, %arg1: tensor<1x80x32x4x1xf32>, %arg2: tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> { %1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<1x10x32x8x1xf32>, tensor<1x80x32x4x1xf32>) outs(%arg2 : tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> return %1 : tensor<1x10x80x8x4xf32> } -// CHECK: func.func @batch_mmt4d_with_unit_batch +// CHECK: func.func @batch_mmt4d_with_no_fill // CHECK-SAME: %[[LHS:.+]]: tensor<1x10x32x8x1xf32>, // CHECK-SAME: %[[RHS:.+]]: tensor<1x80x32x4x1xf32>, // CHECK-SAME: %[[OUT:.+]]: tensor<1x10x80x8x4xf32> 
@@ -67,31 +37,3 @@ func.func @batch_mmt4d_with_unit_batch(%arg0: tensor<1x10x32x8x1xf32>, %arg1: te // CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> // CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<1x10x80x8x4xf32> // CHECK: return %[[INS]] : tensor<1x10x80x8x4xf32> - -// ----- - -func.func @batch_mmt4d_with_dynamic_batch(%arg0: tensor<?x10x32x8x1xf32>, %arg1: tensor<?x80x32x4x1xf32>, %arg2: tensor<?x10x80x8x4xf32>) -> tensor<?x10x80x8x4xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?x10x80x8x4xf32>) -> tensor<?x10x80x8x4xf32> - %1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<?x10x32x8x1xf32>, tensor<?x80x32x4x1xf32>) outs(%0 : tensor<?x10x80x8x4xf32>) -> tensor<?x10x80x8x4xf32> - return %1 : tensor<?x10x80x8x4xf32> -} - -// CHECK: func.func @batch_mmt4d_with_dynamic_batch -// CHECK-SAME: %[[LHS:.+]]: tensor<?x10x32x8x1xf32>, -// CHECK-SAME: %[[RHS:.+]]: tensor<?x80x32x4x1xf32>, -// CHECK-SAME: %[[OUT:.+]]: tensor<?x10x80x8x4xf32> -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[LHS]], %[[C0]] : tensor<?x10x32x8x1xf32> -// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[DIM]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[OUT]]) -// CHECK: %[[EXT_OUT:.+]] = tensor.extract_slice %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<?x10x80x8x4xf32> to tensor<10x80x8x4xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> -// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][%[[I]], 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<?x10x32x8x1xf32> to tensor<10x32x8x1xf32> -// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][%[[I]], 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<?x80x32x4x1xf32> to tensor<80x32x4x1xf32> -// CHECK: %[[MMT4D:.+]] = linalg.mmt4d 
ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[FILL]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> -// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<?x10x80x8x4xf32> -// CHECK: scf.yield %[[INS]] : tensor<?x10x80x8x4xf32> -// CHECK: } -// CHECK: return %[[RES]] : tensor<?x10x80x8x4xf32>