Skip to content

Commit

Permalink
Add LLVMCPUDecomposeBatchMmt4D pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry Wu committed Aug 10, 2023
1 parent af3d2a1 commit 3e2b92e
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ iree_compiler_cc_library(
"ConvertToDestinationPassingStylePass.cpp",
"DecomposeAffineOpsPass.cpp",
"DecomposeConvolutionToLowerDimOps.cpp",
"DecomposeBatchMmt4DOps.cpp",
"DecomposeLinalgGeneric.cpp",
"DecomposePackUnPackOps.cpp",
"EmulateNarrowType.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ iree_cc_library(
"ConvertBf16ToUInt16Buffers.cpp"
"ConvertToDestinationPassingStylePass.cpp"
"DecomposeAffineOpsPass.cpp"
"DecomposeBatchMmt4DOps.cpp"
"DecomposeConvolutionToLowerDimOps.cpp"
"DecomposeLinalgGeneric.cpp"
"DecomposePackUnPackOps.cpp"
Expand Down
147 changes: 147 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-decompose-batch-mmt4d-ops"

namespace mlir {
namespace iree_compiler {

namespace {
struct DecomposeBatchMmt4DOpsPass
: public DecomposeBatchMmt4DOpsBase<DecomposeBatchMmt4DOpsPass> {
void runOnOperation() override;
};
} // namespace

void DecomposeBatchMmt4DOpsPass::runOnOperation() {
MLIRContext *ctx = &getContext();
auto funcOp = getOperation();

{
auto tileAndFuseOptions = scf::SCFTileAndFuseOptions().setTilingOptions(
scf::SCFTilingOptions().setTileSizes({1}));
IRRewriter rewriter(ctx);
funcOp->walk([&](linalg::BatchMmt4DOp op) {
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, cast<TilingInterface>(op.getOperation()),
tileAndFuseOptions);
if (failed(tileAndFuseResult)) {
return signalPassFailure();
}

SmallVector<Value> replacements;
replacements.resize(op->getNumResults());
for (const auto &[index, result] : llvm::enumerate(op->getResults())) {
replacements[index] = tileAndFuseResult->replacements[result];
}
op->replaceAllUsesWith(replacements);
});

LLVM_DEBUG({
llvm::dbgs() << "--- After applying tiling that makes batch dim be 1"
" ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}

// Canonicalize tiled ops.
{
RewritePatternSet patterns(ctx);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
ctx->getOrLoadDialect<tensor::TensorDialect>()->getCanonicalizationPatterns(
patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

{
funcOp->walk([&](linalg::BatchMmt4DOp op) {
IRRewriter rewriter(ctx);
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
auto loc = op.getLoc();

auto lhs = op.getDpsInputOperand(0)->get();
auto rhs = op.getDpsInputOperand(1)->get();
auto out = op.getDpsInitOperand(0)->get();

auto lhsType = lhs.getType().cast<RankedTensorType>();
auto reducedLhsType = RankedTensorType::Builder(lhsType).dropDim(0);
auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, lhs, reducedLhsType);

auto rhsType = rhs.getType().cast<RankedTensorType>();
auto reducedRhsType = RankedTensorType::Builder(rhsType).dropDim(0);
auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, rhs, reducedRhsType);

auto outType = out.getType().cast<RankedTensorType>();
auto reducedOutType = RankedTensorType::Builder(outType).dropDim(0);

Value originalOut = out;
Value reducedOut;
if (auto oldFillOp = out.getDefiningOp<linalg::FillOp>()) {
originalOut = oldFillOp.output();
auto newInit = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, originalOut, reducedOutType);
reducedOut =
rewriter
.create<linalg::FillOp>(loc, ValueRange{oldFillOp.value()},
ValueRange{newInit})
.result();
} else {
reducedOut = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, out, reducedOutType);
originalOut = out;
}

auto mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
loc, reducedOut.getType(), ValueRange{reducedLhs, reducedRhs},
ValueRange{reducedOut});

auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
rewriter, loc, mmt4DOp.getResult(0), originalOut);
rewriter.replaceOp(op, insertSliceOp);
});

LLVM_DEBUG({
llvm::dbgs() << "--- After converting batch_mmt4d into mmt4d ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}

// Canonicalize extract and insert slice ops.
{
RewritePatternSet patterns(ctx);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
}

std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeBatchMmt4DOpsPass() {
return std::make_unique<DecomposeBatchMmt4DOpsPass>();
}

} // namespace iree_compiler
} // namespace mlir
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ createConvertToDestinationPassingStylePass(
// hoisted in different loops.
std::unique_ptr<Pass> createDecomposeAffineOpsPass();

// Decomposes batch mmt4d op into mmt4d by tiling the batch dim to 1.
std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeBatchMmt4DOpsPass();

// Decomposes high-D convolution ops into low-D ones.
std::unique_ptr<Pass> createDecomposeConvolutionToLowerDimOpsPass();

Expand Down
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ def DecomposeAffineOps: Pass<"decompose-affine-ops"> {
];
}

def DecomposeBatchMmt4DOps
: Pass<"iree-codegen-decompose-batch-mmt4d-ops", "func::FuncOp"> {
let summary = "TODO";
let constructor =
"mlir::iree_compiler::createDecomposeBatchMmt4DOpsPass()";
}

def DecomposeConvolutionToLowerDimOps :
Pass<"iree-codegen-decompose-convolution-to-lower-dim-ops", ""> {
let summary = "Decomposes linalg convolution ops to lower dim ops";
Expand Down

0 comments on commit 3e2b92e

Please sign in to comment.