From 3e2b92e026388133ad2057c5b2484ec0b64b32ff Mon Sep 17 00:00:00 2001
From: Jerry Wu
Date: Thu, 10 Aug 2023 20:06:58 +0000
Subject: [PATCH] Add LLVMCPUDecomposeBatchMmt4D pass

---
 .../iree/compiler/Codegen/Common/BUILD.bazel        |   1 +
 .../compiler/Codegen/Common/CMakeLists.txt          |   1 +
 .../Codegen/Common/DecomposeBatchMmt4DOps.cpp       | 147 ++++++++++++++++++
 .../src/iree/compiler/Codegen/Common/Passes.h       |   4 +
 .../iree/compiler/Codegen/Common/Passes.td          |   7 +
 5 files changed, 160 insertions(+)
 create mode 100644 compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp

diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 75b0e01c2975d..b78d6667c57ce 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -154,6 +154,7 @@ iree_compiler_cc_library(
         "ConvertToDestinationPassingStylePass.cpp",
         "DecomposeAffineOpsPass.cpp",
         "DecomposeConvolutionToLowerDimOps.cpp",
+        "DecomposeBatchMmt4DOps.cpp",
        "DecomposeLinalgGeneric.cpp",
         "DecomposePackUnPackOps.cpp",
         "EmulateNarrowType.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 6df643e9288b4..0dd1a068d8504 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -128,6 +128,7 @@ iree_cc_library(
     "ConvertBf16ToUInt16Buffers.cpp"
     "ConvertToDestinationPassingStylePass.cpp"
     "DecomposeAffineOpsPass.cpp"
+    "DecomposeBatchMmt4DOps.cpp"
     "DecomposeConvolutionToLowerDimOps.cpp"
     "DecomposeLinalgGeneric.cpp"
     "DecomposePackUnPackOps.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp
new file mode 100644
index 0000000000000..61462f0db968e
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp
@@ -0,0 +1,147 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/PassDetail.h"
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-codegen-decompose-batch-mmt4d-ops"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+// Decomposes linalg.batch_mmt4d ops into linalg.mmt4d ops by tiling the batch
+// dimension to 1 and dropping it with rank-reducing tensor slices.
+struct DecomposeBatchMmt4DOpsPass
+    : public DecomposeBatchMmt4DOpsBase<DecomposeBatchMmt4DOpsPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void DecomposeBatchMmt4DOpsPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  auto funcOp = getOperation();
+
+  // Tile the batch dimension of each batch_mmt4d op to 1.
+  {
+    auto tileAndFuseOptions = scf::SCFTileAndFuseOptions().setTilingOptions(
+        scf::SCFTilingOptions().setTileSizes({1}));
+    IRRewriter rewriter(ctx);
+    funcOp->walk([&](linalg::BatchMmt4DOp op) {
+      FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
+          scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
+              rewriter, cast<TilingInterface>(op.getOperation()),
+              tileAndFuseOptions);
+      if (failed(tileAndFuseResult)) {
+        return signalPassFailure();
+      }
+
+      SmallVector<Value> replacements;
+      replacements.resize(op->getNumResults());
+      for (const auto &[index, result] : llvm::enumerate(op->getResults())) {
+        replacements[index] = tileAndFuseResult->replacements[result];
+      }
+      op->replaceAllUsesWith(replacements);
+    });
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "--- After applying tiling that makes batch dim be 1"
+                      " ---\n";
+      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+  }
+
+  // Canonicalize tiled ops.
+  {
+    RewritePatternSet patterns(ctx);
+    linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
+    memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+    ctx->getOrLoadDialect<tensor::TensorDialect>()
+        ->getCanonicalizationPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+
+  // Rewrite each unit-batch batch_mmt4d op into a rank-reduced mmt4d op.
+  {
+    funcOp->walk([&](linalg::BatchMmt4DOp op) {
+      IRRewriter rewriter(ctx);
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(op);
+      auto loc = op.getLoc();
+
+      auto lhs = op.getDpsInputOperand(0)->get();
+      auto rhs = op.getDpsInputOperand(1)->get();
+      auto out = op.getDpsInitOperand(0)->get();
+
+      auto lhsType = lhs.getType().cast<RankedTensorType>();
+      auto reducedLhsType = RankedTensorType::Builder(lhsType).dropDim(0);
+      auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp(
+          rewriter, loc, lhs, reducedLhsType);
+
+      auto rhsType = rhs.getType().cast<RankedTensorType>();
+      auto reducedRhsType = RankedTensorType::Builder(rhsType).dropDim(0);
+      auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp(
+          rewriter, loc, rhs, reducedRhsType);
+
+      auto outType = out.getType().cast<RankedTensorType>();
+      auto reducedOutType = RankedTensorType::Builder(outType).dropDim(0);
+
+      Value originalOut = out;
+      Value reducedOut;
+      // If the output is initialized by a linalg.fill, fill the rank-reduced
+      // init tensor directly so the fill is decomposed as well.
+      if (auto oldFillOp = out.getDefiningOp<linalg::FillOp>()) {
+        originalOut = oldFillOp.output();
+        auto newInit = tensor::createCanonicalRankReducingExtractSliceOp(
+            rewriter, loc, originalOut, reducedOutType);
+        reducedOut =
+            rewriter
+                .create<linalg::FillOp>(loc, ValueRange{oldFillOp.value()},
+                                        ValueRange{newInit})
+                .result();
+      } else {
+        reducedOut = tensor::createCanonicalRankReducingExtractSliceOp(
+            rewriter, loc, out, reducedOutType);
+        originalOut = out;
+      }
+
+      auto mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
+          loc, reducedOut.getType(), ValueRange{reducedLhs, reducedRhs},
+          ValueRange{reducedOut});
+
+      auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
+          rewriter, loc, mmt4DOp.getResult(0), originalOut);
+      rewriter.replaceOp(op, insertSliceOp);
+    });
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "--- After converting batch_mmt4d into mmt4d ---\n";
+      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+  }
+
+  // Canonicalize extract and insert slice ops.
+  {
+    RewritePatternSet patterns(ctx);
+    tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+    tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
+    tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+}
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createDecomposeBatchMmt4DOpsPass() {
+  return std::make_unique<DecomposeBatchMmt4DOpsPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index bd23f33805891..079f36b71d247 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -79,6 +79,10 @@ createConvertToDestinationPassingStylePass(
 // hoisted in different loops.
 std::unique_ptr<Pass> createDecomposeAffineOpsPass();
 
+// Decomposes batch mmt4d op into mmt4d by tiling the batch dim to 1.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createDecomposeBatchMmt4DOpsPass();
+
 // Decomposes high-D convolution ops into low-D ones.
 std::unique_ptr<Pass> createDecomposeConvolutionToLowerDimOpsPass();
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index dbed920ce661d..e313420b5d4f2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -119,6 +119,13 @@ def DecomposeAffineOps: Pass<"decompose-affine-ops"> {
   ];
 }
 
+def DecomposeBatchMmt4DOps
+    : Pass<"iree-codegen-decompose-batch-mmt4d-ops", "func::FuncOp"> {
+  let summary = "Decomposes batch mmt4d ops into mmt4d ops";
+  let constructor =
+      "mlir::iree_compiler::createDecomposeBatchMmt4DOpsPass()";
+}
+
 def DecomposeConvolutionToLowerDimOps :
     Pass<"iree-codegen-decompose-convolution-to-lower-dim-ops", ""> {
   let summary = "Decomposes linalg convolution ops to lower dim ops";
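
Note (illustration, not part of the patch): below is a minimal sketch of how the new pass could be plugged into a pass pipeline, assuming a function-level pipeline builder. The helper name addBatchMmt4DDecomposition and its placement are hypothetical; only createDecomposeBatchMmt4DOpsPass() is defined by this patch, while OpPassManager and addNestedPass are standard MLIR APIs.

  // Hypothetical registration sketch; the surrounding pipeline helper is
  // illustrative only and not added by this patch.
  #include "iree/compiler/Codegen/Common/Passes.h"
  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/Pass/PassManager.h"

  void addBatchMmt4DDecomposition(mlir::OpPassManager &passManager) {
    // Nest on func.func, since Passes.td declares the pass on func::FuncOp.
    passManager.addNestedPass<mlir::func::FuncOp>(
        mlir::iree_compiler::createDecomposeBatchMmt4DOpsPass());
  }

Given the commit title, the likely consumer is an LLVMCPU lowering pipeline, but wiring the pass into that pipeline is not part of this patch.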