From 08a90e6b04ccc565e780acc743d0c0a4f693168a Mon Sep 17 00:00:00 2001
From: Jerry Wu
Date: Thu, 10 Aug 2023 20:06:58 +0000
Subject: [PATCH] Add LLVMCPUDecomposeBatchMmt4D pass

---
 .../iree/compiler/Codegen/Common/BUILD.bazel  |   1 +
 .../compiler/Codegen/Common/CMakeLists.txt    |   1 +
 .../Codegen/Common/DecomposeBatchMmt4DOps.cpp | 159 ++++++++++++++++++
 .../src/iree/compiler/Codegen/Common/Passes.h |   3 +
 .../iree/compiler/Codegen/Common/Passes.td    |   7 +
 5 files changed, 171 insertions(+)
 create mode 100644 compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp

diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 75b0e01c2975d..d87c096a066e8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -153,6 +153,7 @@ iree_compiler_cc_library(
         "ConvertBf16ToUInt16Buffers.cpp",
         "ConvertToDestinationPassingStylePass.cpp",
         "DecomposeAffineOpsPass.cpp",
+        "DecomposeBatchMmt4DOps.cpp",
         "DecomposeConvolutionToLowerDimOps.cpp",
         "DecomposeLinalgGeneric.cpp",
         "DecomposePackUnPackOps.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 6df643e9288b4..0dd1a068d8504 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -128,6 +128,7 @@ iree_cc_library(
     "ConvertBf16ToUInt16Buffers.cpp"
     "ConvertToDestinationPassingStylePass.cpp"
     "DecomposeAffineOpsPass.cpp"
+    "DecomposeBatchMmt4DOps.cpp"
     "DecomposeConvolutionToLowerDimOps.cpp"
     "DecomposeLinalgGeneric.cpp"
     "DecomposePackUnPackOps.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp
new file mode 100644
index 0000000000000..071ca202c2698
--- /dev/null
+++ 
b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp
@@ -0,0 +1,159 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/PassDetail.h"
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-codegen-decompose-batch-mmt4d-ops"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Pattern to convert linalg.batch_mmt4d with batch dim = 1 into mmt4d.
+struct ConvertBatchMmt4DtoMmt4DPattern
+    : public OpRewritePattern<linalg::BatchMmt4DOp> {
+  using OpRewritePattern<linalg::BatchMmt4DOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::BatchMmt4DOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto lhs = op.getDpsInputOperand(0)->get();
+    auto rhs = op.getDpsInputOperand(1)->get();
+    auto out = op.getDpsInitOperand(0)->get();
+
+    auto outType = out.getType().cast<RankedTensorType>();
+    // Skip if the batch dim isn't tiled to 1.
+    if (outType.getShape()[0] != 1) {
+      return failure();
+    }
+    auto reducedOutType = RankedTensorType::Builder(outType).dropDim(0);
+    Value reducedOut;
+    Value initTensor;
+    // If the init operand is a linalg.fill op, create a new linalg.fill op with
+    // the batch dim dropped, so it is easier to identify fill + mmt4d cases.
+    if (auto oldFillOp = out.getDefiningOp<linalg::FillOp>()) {
+      initTensor = oldFillOp.output();
+      auto newInit = tensor::createCanonicalRankReducingExtractSliceOp(
+          rewriter, loc, initTensor, reducedOutType);
+      reducedOut =
+          rewriter
+              .create<linalg::FillOp>(loc, ValueRange{oldFillOp.value()},
+                                      ValueRange{newInit})
+              .result();
+    } else {
+      reducedOut = tensor::createCanonicalRankReducingExtractSliceOp(
+          rewriter, loc, out, reducedOutType);
+      initTensor = out;
+    }
+
+    auto lhsType = lhs.getType().cast<RankedTensorType>();
+    auto reducedLhsType = RankedTensorType::Builder(lhsType).dropDim(0);
+    auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, lhs, reducedLhsType);
+
+    auto rhsType = rhs.getType().cast<RankedTensorType>();
+    auto reducedRhsType = RankedTensorType::Builder(rhsType).dropDim(0);
+    auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, rhs, reducedRhsType);
+
+    auto mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
+        loc, reducedOut.getType(), ValueRange{reducedLhs, reducedRhs},
+        ValueRange{reducedOut});
+
+    auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
+        rewriter, loc, mmt4DOp.getResult(0), initTensor);
+    rewriter.replaceOp(op, insertSliceOp);
+    return success();
+  }
+};
+
+struct DecomposeBatchMmt4DOpsPass
+    : public DecomposeBatchMmt4DOpsBase<DecomposeBatchMmt4DOpsPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void DecomposeBatchMmt4DOpsPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  auto funcOp = getOperation();
+
+  // First tile the batch dim of linalg.batch_mmt4d into 1.
+  {
+    auto tileAndFuseOptions = scf::SCFTileAndFuseOptions().setTilingOptions(
+        scf::SCFTilingOptions().setTileSizes({1}));
+    IRRewriter rewriter(ctx);
+    funcOp->walk([&](linalg::BatchMmt4DOp op) {
+      FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
+          scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
+              rewriter, cast<TilingInterface>(op.getOperation()),
+              tileAndFuseOptions);
+      if (failed(tileAndFuseResult)) {
+        return signalPassFailure();
+      }
+
+      SmallVector<Value> replacements;
+      replacements.resize(op->getNumResults());
+      for (const auto &[index, result] : llvm::enumerate(op->getResults())) {
+        replacements[index] = tileAndFuseResult->replacements[result];
+      }
+      op->replaceAllUsesWith(replacements);
+    });
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "--- After tiling batch dim to 1 ---\n";
+      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+  }
+
+  // Canonicalize tiled ops.
+  {
+    RewritePatternSet patterns(ctx);
+    linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
+    memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+    ctx->getOrLoadDialect<tensor::TensorDialect>()->getCanonicalizationPatterns(
+        patterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+
+  // Convert linalg.batch_mmt4d with batch dim = 1 into linalg.mmt4d.
+  {
+    RewritePatternSet patterns(ctx);
+    patterns.add<ConvertBatchMmt4DtoMmt4DPattern>(ctx);
+    // Canonicalize extract and insert slice ops created during the conversion.
+    tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+    tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
+    tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "--- After converting batch_mmt4d into mmt4d ---\n";
+      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+  }
+}
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createDecomposeBatchMmt4DOpsPass() {
+  return std::make_unique<DecomposeBatchMmt4DOpsPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index bd23f33805891..0f48f4ce70dd3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -79,6 +79,9 @@ createConvertToDestinationPassingStylePass(
 // hoisted in different loops.
 std::unique_ptr<Pass> createDecomposeAffineOpsPass();
 
+// Decomposes batch mmt4d op into mmt4d by tiling the batch dim to 1.
+std::unique_ptr<OperationPass<func::FuncOp>> createDecomposeBatchMmt4DOpsPass();
+
 // Decomposes high-D convolution ops into low-D ones.
 std::unique_ptr<Pass> createDecomposeConvolutionToLowerDimOpsPass();
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index dbed920ce661d..27024895034ab 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -119,6 +119,13 @@ def DecomposeAffineOps: Pass<"decompose-affine-ops"> {
   ];
 }
 
+def DecomposeBatchMmt4DOps
+    : Pass<"iree-codegen-decompose-batch-mmt4d-ops", "func::FuncOp"> {
+  let summary = "Decompose batch_mmt4d ops into mmt4d ops";
+  let constructor =
+      "mlir::iree_compiler::createDecomposeBatchMmt4DOpsPass()";
+}
+
 def DecomposeConvolutionToLowerDimOps :
     Pass<"iree-codegen-decompose-convolution-to-lower-dim-ops", ""> {
   let summary = "Decomposes linalg convolution ops to lower dim ops";