Decompose batch_mmt4d to mmt4d
Jerry Wu committed Aug 9, 2023
1 parent a0aebfc commit d6b0730
Showing 5 changed files with 160 additions and 0 deletions.
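
This commit adds an LLVMCPU pass that rewrites each linalg.batch_mmt4d op into a linalg.mmt4d op: the batch dimension is tiled to 1 with tile-and-fuse, the resulting unit batch dimension is dropped with rank-reducing tensor.extract_slice ops (recreating a producing linalg.fill, if any, on the rank-reduced init), and the result is written back with a rank-expanding tensor.insert_slice. A hypothetical before/after sketch of the rewrite, with illustrative shapes and names not taken from this commit (the actual output additionally goes through the canonicalization steps below):

    // Before: a batch_mmt4d over batch dimension B = 4.
    func.func @example(%lhs: tensor<4x8x16x8x1xf32>, %rhs: tensor<4x8x16x8x1xf32>,
                       %acc: tensor<4x8x8x8x8xf32>) -> tensor<4x8x8x8x8xf32> {
      %0 = linalg.batch_mmt4d
          ins(%lhs, %rhs : tensor<4x8x16x8x1xf32>, tensor<4x8x16x8x1xf32>)
          outs(%acc : tensor<4x8x8x8x8xf32>) -> tensor<4x8x8x8x8xf32>
      return %0 : tensor<4x8x8x8x8xf32>
    }

    // After: an scf.for over the batch dimension; each iteration runs a plain
    // mmt4d on rank-reduced slices of the operands.
    func.func @example(%lhs: tensor<4x8x16x8x1xf32>, %rhs: tensor<4x8x16x8x1xf32>,
                       %acc: tensor<4x8x8x8x8xf32>) -> tensor<4x8x8x8x8xf32> {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      %c4 = arith.constant 4 : index
      %0 = scf.for %b = %c0 to %c4 step %c1 iter_args(%out = %acc)
          -> (tensor<4x8x8x8x8xf32>) {
        %lhs_s = tensor.extract_slice %lhs[%b, 0, 0, 0, 0] [1, 8, 16, 8, 1]
            [1, 1, 1, 1, 1] : tensor<4x8x16x8x1xf32> to tensor<8x16x8x1xf32>
        %rhs_s = tensor.extract_slice %rhs[%b, 0, 0, 0, 0] [1, 8, 16, 8, 1]
            [1, 1, 1, 1, 1] : tensor<4x8x16x8x1xf32> to tensor<8x16x8x1xf32>
        %out_s = tensor.extract_slice %out[%b, 0, 0, 0, 0] [1, 8, 8, 8, 8]
            [1, 1, 1, 1, 1] : tensor<4x8x8x8x8xf32> to tensor<8x8x8x8xf32>
        %mm = linalg.mmt4d
            ins(%lhs_s, %rhs_s : tensor<8x16x8x1xf32>, tensor<8x16x8x1xf32>)
            outs(%out_s : tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
        %ins = tensor.insert_slice %mm into %out[%b, 0, 0, 0, 0] [1, 8, 8, 8, 8]
            [1, 1, 1, 1, 1] : tensor<8x8x8x8xf32> into tensor<4x8x8x8x8xf32>
        scf.yield %ins : tensor<4x8x8x8x8xf32>
      }
      return %0 : tensor<4x8x8x8x8xf32>
    }

Once registered, the pass should also be runnable in isolation, e.g. with a (hypothetical) invocation like iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-decompose-batch-mmt4d))" input.mlir.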
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -56,6 +56,7 @@ iree_cc_library(
"LLVMCPUAssignConstantOrdinals.cpp"
"LLVMCPUAssignImportOrdinals.cpp"
"LLVMCPUCheckIRBeforeLLVMConversion.cpp"
"LLVMCPUDecomposeBatchMmt4d.cpp"
"LLVMCPUEmitVectorizationRemarks.cpp"
"LLVMCPULinkExecutables.cpp"
"LLVMCPULowerExecutableTarget.cpp"
147 changes: 147 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUDecomposeBatchMmt4d.cpp
@@ -0,0 +1,147 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-llvmcpu-decompose-batch-mmt4d"

namespace mlir {
namespace iree_compiler {

namespace {
struct LLVMCPUDecomposeBatchMmt4dPass
: public LLVMCPUDecomposeBatchMmt4dBase<LLVMCPUDecomposeBatchMmt4dPass> {
void runOnOperation() override;
};
} // namespace

void LLVMCPUDecomposeBatchMmt4dPass::runOnOperation() {
MLIRContext *ctx = &getContext();
auto funcOp = getOperation();

{
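    // Step 1: tile each batch_mmt4d along its batch dimension with tile size
    // 1, greedily fusing producers into the generated scf.for loop.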
auto tileAndFuseOptions = scf::SCFTileAndFuseOptions().setTilingOptions(
scf::SCFTilingOptions().setTileSizes({1}));
IRRewriter rewriter(ctx);
funcOp->walk([&](linalg::BatchMmt4DOp op) {
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, cast<TilingInterface>(op.getOperation()),
tileAndFuseOptions);
if (failed(tileAndFuseResult)) {
return signalPassFailure();
}

SmallVector<Value> replacements;
replacements.resize(op->getNumResults());
for (const auto &[index, result] : llvm::enumerate(op->getResults())) {
replacements[index] = tileAndFuseResult->replacements[result];
}
op->replaceAllUsesWith(replacements);
});

    LLVM_DEBUG({
      llvm::dbgs() << "--- After tiling the batch dimension to 1 ---\n";
      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
      llvm::dbgs() << "\n\n";
    });
}

// Canonicalize tiled ops.
{
RewritePatternSet patterns(ctx);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
ctx->getOrLoadDialect<tensor::TensorDialect>()->getCanonicalizationPatterns(
patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

{
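    // Step 2: each remaining batch_mmt4d now has a unit batch dimension; drop
    // it with rank-reducing extract_slice ops and emit a plain mmt4d.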
funcOp->walk([&](linalg::BatchMmt4DOp op) {
IRRewriter rewriter(ctx);
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
auto loc = op.getLoc();

auto lhs = op.getDpsInputOperand(0)->get();
auto rhs = op.getDpsInputOperand(1)->get();
auto out = op.getDpsInitOperand(0)->get();

auto lhsType = lhs.getType().cast<RankedTensorType>();
auto reducedLhsType = RankedTensorType::Builder(lhsType).dropDim(0);
auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, lhs, reducedLhsType);

auto rhsType = rhs.getType().cast<RankedTensorType>();
auto reducedRhsType = RankedTensorType::Builder(rhsType).dropDim(0);
auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, rhs, reducedRhsType);

auto outType = out.getType().cast<RankedTensorType>();
auto reducedOutType = RankedTensorType::Builder(outType).dropDim(0);

Value originalOut = out;
Value reducedOut;
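      // If the accumulator comes from a linalg.fill, recreate the fill on the
      // rank-reduced init so the new mmt4d still consumes a fill directly.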
if (auto oldFillOp = out.getDefiningOp<linalg::FillOp>()) {
originalOut = oldFillOp.output();
auto newInit = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, originalOut, reducedOutType);
reducedOut =
rewriter
.create<linalg::FillOp>(loc, ValueRange{oldFillOp.value()},
ValueRange{newInit})
.result();
      } else {
        reducedOut = tensor::createCanonicalRankReducingExtractSliceOp(
            rewriter, loc, out, reducedOutType);
      }

auto mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
loc, reducedOut.getType(), ValueRange{reducedLhs, reducedRhs},
ValueRange{reducedOut});

      // Write the mmt4d result back into the original accumulator, restoring
      // the unit batch dimension with a rank-expanding insert_slice.
      auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
          rewriter, loc, mmt4DOp.getResult(0), originalOut);
      rewriter.replaceOp(op, insertSliceOp);
});

    LLVM_DEBUG({
      llvm::dbgs() << "--- After converting batch_mmt4d into mmt4d ---\n";
      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
      llvm::dbgs() << "\n\n";
    });
}

// Canonicalize extract and insert slice ops.
{
RewritePatternSet patterns(ctx);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
}

std::unique_ptr<OperationPass<func::FuncOp>>
createLLVMCPUDecomposeBatchMmt4dPass() {
return std::make_unique<LLVMCPUDecomposeBatchMmt4dPass>();
}

} // namespace iree_compiler
} // namespace mlir
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -607,6 +607,8 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();

if (enableMicrokernels) {
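    // Decompose batch_mmt4d into mmt4d first so the ukernel lowering below
    // only needs to match the non-batched form.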
nestedModulePM.addNestedPass<func::FuncOp>(
createLLVMCPUDecomposeBatchMmt4dPass());
nestedModulePM.addPass(
createLLVMCPULowerToUKernelsPass(clSkipIntermediateRoundings));
} else {
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -51,6 +51,9 @@ createLLVMCPULowerToUKernelsPass(bool skipIntermediateRoundings = true);
std::unique_ptr<OperationPass<func::FuncOp>>
createLLVMCPUMmt4dVectorLoweringPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createLLVMCPUDecomposeBatchMmt4dPass();

/// Pass to perform peeling on non-distributed loops.
std::unique_ptr<OperationPass<func::FuncOp>> createLLVMCPUPeelPass();

7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
@@ -93,6 +93,13 @@ def LLVMCPUMmt4dVectorLowering
"mlir::iree_compiler::createLLVMCPUMmt4dVectorLoweringPass()";
}

def LLVMCPUDecomposeBatchMmt4d
: Pass<"iree-llvmcpu-decompose-batch-mmt4d", "func::FuncOp"> {
let summary = "Decompose batch_mmt4d ops into mmt4d ops";
let constructor =
"mlir::iree_compiler::createLLVMCPUDecomposeBatchMmt4dPass()";
}

def LLVMCPUPeel :
Pass<"iree-llvmcpu-peel", "func::FuncOp"> {
let summary = "Pass to perform peeling on non-distributed loops.";
