Skip to content

Commit

Permalink
Add LLVMCPUDecomposeBatchMmt4D pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry Wu committed Aug 10, 2023
1 parent 91099ae commit 08a90e6
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ iree_compiler_cc_library(
"ConvertBf16ToUInt16Buffers.cpp",
"ConvertToDestinationPassingStylePass.cpp",
"DecomposeAffineOpsPass.cpp",
"DecomposeBatchMmt4DOps.cpp",
"DecomposeConvolutionToLowerDimOps.cpp",
"DecomposeLinalgGeneric.cpp",
"DecomposePackUnPackOps.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ iree_cc_library(
"ConvertBf16ToUInt16Buffers.cpp"
"ConvertToDestinationPassingStylePass.cpp"
"DecomposeAffineOpsPass.cpp"
"DecomposeBatchMmt4DOps.cpp"
"DecomposeConvolutionToLowerDimOps.cpp"
"DecomposeLinalgGeneric.cpp"
"DecomposePackUnPackOps.cpp"
Expand Down
159 changes: 159 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-decompose-batch-mmt4d-ops"

namespace mlir {
namespace iree_compiler {

namespace {

/// Pattern to convert linalg.batch_mmt4d with batch dim = 1 into mmt4d.
struct ConvertBatchMmt4DtoMmt4DPattern
: public OpRewritePattern<linalg::BatchMmt4DOp> {
using OpRewritePattern<linalg::BatchMmt4DOp>::OpRewritePattern;

LogicalResult matchAndRewrite(linalg::BatchMmt4DOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto lhs = op.getDpsInputOperand(0)->get();
auto rhs = op.getDpsInputOperand(1)->get();
auto out = op.getDpsInitOperand(0)->get();

auto outType = out.getType().cast<RankedTensorType>();
// Skip if the batch dim isn't tiled to 1.
if (outType.getShape()[0] != 1) {
return failure();
}
auto reducedOutType = RankedTensorType::Builder(outType).dropDim(0);
Value reducedOut;
Value initTensor;
// If the init operand is a linalg.fill op, create a new linalg.fill op with
// the batch dim dropped, so it is easier to identify fill + mmt4d cases.
if (auto oldFillOp = out.getDefiningOp<linalg::FillOp>()) {
initTensor = oldFillOp.output();
auto newInit = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, initTensor, reducedOutType);
reducedOut =
rewriter
.create<linalg::FillOp>(loc, ValueRange{oldFillOp.value()},
ValueRange{newInit})
.result();
} else {
reducedOut = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, out, reducedOutType);
initTensor = out;
}

auto lhsType = lhs.getType().cast<RankedTensorType>();
auto reducedLhsType = RankedTensorType::Builder(lhsType).dropDim(0);
auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, lhs, reducedLhsType);

auto rhsType = rhs.getType().cast<RankedTensorType>();
auto reducedRhsType = RankedTensorType::Builder(rhsType).dropDim(0);
auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, rhs, reducedRhsType);

auto mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
loc, reducedOut.getType(), ValueRange{reducedLhs, reducedRhs},
ValueRange{reducedOut});

auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
rewriter, loc, mmt4DOp.getResult(0), initTensor);
rewriter.replaceOp(op, insertSliceOp);
return success();
}
};

struct DecomposeBatchMmt4DOpsPass
: public DecomposeBatchMmt4DOpsBase<DecomposeBatchMmt4DOpsPass> {
void runOnOperation() override;
};

} // namespace

void DecomposeBatchMmt4DOpsPass::runOnOperation() {
MLIRContext *ctx = &getContext();
auto funcOp = getOperation();

// First tile the batch dim of linalg.batch_mmt4d into 1.
{
auto tileAndFuseOptions = scf::SCFTileAndFuseOptions().setTilingOptions(
scf::SCFTilingOptions().setTileSizes({1}));
IRRewriter rewriter(ctx);
funcOp->walk([&](linalg::BatchMmt4DOp op) {
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, cast<TilingInterface>(op.getOperation()),
tileAndFuseOptions);
if (failed(tileAndFuseResult)) {
return signalPassFailure();
}

SmallVector<Value> replacements;
replacements.resize(op->getNumResults());
for (const auto &[index, result] : llvm::enumerate(op->getResults())) {
replacements[index] = tileAndFuseResult->replacements[result];
}
op->replaceAllUsesWith(replacements);
});

LLVM_DEBUG({
llvm::dbgs() << "--- After tiling batch dim to 1 ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}

// Canonicalize tiled ops.
{
RewritePatternSet patterns(ctx);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
ctx->getOrLoadDialect<tensor::TensorDialect>()->getCanonicalizationPatterns(
patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

// Convert linalg.batch_mmt4d with batch dim = 1 into linalg.mmt4d.
{
RewritePatternSet patterns(ctx);
patterns.add<ConvertBatchMmt4DtoMmt4DPattern>(ctx);
// Canonicalize extract and insert slice ops created during the conversion.
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}

LLVM_DEBUG({
llvm::dbgs() << "--- After converting batch_mmt4d into mmt4d ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
}

std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeBatchMmt4DOpsPass() {
return std::make_unique<DecomposeBatchMmt4DOpsPass>();
}

} // namespace iree_compiler
} // namespace mlir
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ createConvertToDestinationPassingStylePass(
// hoisted in different loops.
std::unique_ptr<Pass> createDecomposeAffineOpsPass();

// Decomposes batch mmt4d op into mmt4d by tiling the batch dim to 1.
std::unique_ptr<OperationPass<func::FuncOp>> createDecomposeBatchMmt4DOpsPass();

// Decomposes high-D convolution ops into low-D ones.
std::unique_ptr<Pass> createDecomposeConvolutionToLowerDimOpsPass();

Expand Down
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ def DecomposeAffineOps: Pass<"decompose-affine-ops"> {
];
}

def DecomposeBatchMmt4DOps
: Pass<"iree-codegen-decompose-batch-mmt4d-ops", "func::FuncOp"> {
let summary = "Decompose batch_mmt4d ops into mmt4d ops";
let constructor =
"mlir::iree_compiler::createDecomposeBatchMmt4DOpsPass()";
}

def DecomposeConvolutionToLowerDimOps :
Pass<"iree-codegen-decompose-convolution-to-lower-dim-ops", ""> {
let summary = "Decomposes linalg convolution ops to lower dim ops";
Expand Down

0 comments on commit 08a90e6

Please sign in to comment.