Add pass to decompose batch_mmt4d ops into mmt4d #14628

Merged: 7 commits, Aug 11, 2023
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -153,6 +153,7 @@ iree_compiler_cc_library(
         "ConvertBf16ToUInt16Buffers.cpp",
         "ConvertToDestinationPassingStylePass.cpp",
         "DecomposeAffineOpsPass.cpp",
+        "DecomposeBatchMmt4DOps.cpp",
         "DecomposeConvolutionToLowerDimOps.cpp",
         "DecomposeLinalgGeneric.cpp",
         "DecomposePackUnPackOps.cpp",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -128,6 +128,7 @@ iree_cc_library(
     "ConvertBf16ToUInt16Buffers.cpp"
     "ConvertToDestinationPassingStylePass.cpp"
     "DecomposeAffineOpsPass.cpp"
+    "DecomposeBatchMmt4DOps.cpp"
     "DecomposeConvolutionToLowerDimOps.cpp"
     "DecomposeLinalgGeneric.cpp"
     "DecomposePackUnPackOps.cpp"
114 changes: 114 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/DecomposeBatchMmt4DOps.cpp
@@ -0,0 +1,114 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace iree_compiler {

namespace {

/// Pattern to convert linalg.batch_mmt4d with batch dim = 1 into mmt4d.
struct ConvertBatchMmt4DtoMmt4DPattern
    : public OpRewritePattern<linalg::BatchMmt4DOp> {
  using OpRewritePattern<linalg::BatchMmt4DOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::BatchMmt4DOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto lhs = op.getDpsInputOperand(0)->get();
    auto rhs = op.getDpsInputOperand(1)->get();
    auto out = op.getDpsInitOperand(0)->get();

    auto outType = out.getType().cast<RankedTensorType>();
    // Batch dim needs to be tiled to 1 first.
    if (outType.getShape()[0] != 1) {
      return rewriter.notifyMatchFailure(op, "batch dim needs to be 1");
    }
    RankedTensorType reducedOutType =
        RankedTensorType::Builder(outType).dropDim(0);
    Value reducedOut;
    Value initTensor;
    // If the init operand is a linalg.fill op, create a new linalg.fill op
    // with the batch dim dropped, so it is easier to identify fill + mmt4d
    // cases.
    if (auto oldFillOp = out.getDefiningOp<linalg::FillOp>()) {
      initTensor = oldFillOp.output();
      auto newInit = tensor::createCanonicalRankReducingExtractSliceOp(
          rewriter, loc, initTensor, reducedOutType);
      reducedOut =
          rewriter
              .create<linalg::FillOp>(loc, ValueRange{oldFillOp.value()},
                                      ValueRange{newInit})
              .result();
    } else {
      reducedOut = tensor::createCanonicalRankReducingExtractSliceOp(
          rewriter, loc, out, reducedOutType);
      initTensor = out;
    }

    auto lhsType = lhs.getType().cast<RankedTensorType>();
    RankedTensorType reducedLhsType =
        RankedTensorType::Builder(lhsType).dropDim(0);
    auto reducedLhs = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, lhs, reducedLhsType);

    auto rhsType = rhs.getType().cast<RankedTensorType>();
    RankedTensorType reducedRhsType =
        RankedTensorType::Builder(rhsType).dropDim(0);
    auto reducedRhs = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, rhs, reducedRhsType);

    auto mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
        loc, reducedOut.getType(), ValueRange{reducedLhs, reducedRhs},
        ValueRange{reducedOut});

    auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, mmt4DOp.getResult(0), initTensor);
    rewriter.replaceOp(op, insertSliceOp);
    return success();
  }
};

struct DecomposeBatchMmt4DOpsPass
    : public DecomposeBatchMmt4DOpsBase<DecomposeBatchMmt4DOpsPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect, func::FuncDialect,
                    arith::ArithDialect, tensor::TensorDialect>();
  }

  void runOnOperation() override;
};

} // namespace

void DecomposeBatchMmt4DOpsPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  auto funcOp = getOperation();

  // Convert linalg.batch_mmt4d with batch dim = 1 into linalg.mmt4d.
  RewritePatternSet patterns(ctx);
  patterns.add<ConvertBatchMmt4DtoMmt4DPattern>(ctx);
  // Canonicalize extract and insert slice ops created during the conversion.
  tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
    return signalPassFailure();
  }
}

std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeBatchMmt4DOpsPass() {
  return std::make_unique<DecomposeBatchMmt4DOpsPass>();
}

} // namespace iree_compiler
} // namespace mlir
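
For readers unfamiliar with the rank-reducing helpers used above, here is a minimal standalone sketch of the type computation at the heart of the pattern; `dropBatchDim` is a hypothetical helper written for illustration only, not part of this patch:

// Minimal sketch: how RankedTensorType::Builder::dropDim derives the
// rank-reduced types used by the pattern above. `dropBatchDim` is a
// hypothetical illustration helper, not part of this PR.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static RankedTensorType dropBatchDim(RankedTensorType batched) {
  // Drops dimension 0, e.g. tensor<1x10x80x8x4xf32> -> tensor<10x80x8x4xf32>.
  // The element type and the remaining dims are preserved.
  return RankedTensorType::Builder(batched).dropDim(0);
}

The extract_slice/insert_slice pair produced by createCanonicalRankReducingExtractSliceOp and createCanonicalRankReducingInsertSliceOp then moves values between the batched and rank-reduced types, as the FileCheck tests further down show.
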
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -79,6 +79,9 @@ createConvertToDestinationPassingStylePass(
 // hoisted in different loops.
 std::unique_ptr<Pass> createDecomposeAffineOpsPass();
 
+// Decomposes batch mmt4d op into mmt4d by tiling the batch dim to 1.
+std::unique_ptr<OperationPass<func::FuncOp>> createDecomposeBatchMmt4DOpsPass();
+
 // Decomposes high-D convolution ops into low-D ones.
 std::unique_ptr<Pass> createDecomposeConvolutionToLowerDimOpsPass();
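
Not shown in this diff: where the new pass gets scheduled. A hedged sketch of one way a caller could nest it on functions, assuming an existing module-level OpPassManager (the wrapper function name is illustrative only):

// Illustrative only; where IREE actually inserts the pass into its
// pipelines is not part of this PR.
#include "iree/compiler/Codegen/Common/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

void addDecomposeBatchMmt4DPass(mlir::OpPassManager &modulePassManager) {
  // The pass is declared on func::FuncOp, so nest it under the module.
  modulePassManager.addNestedPass<mlir::func::FuncOp>(
      mlir::iree_compiler::createDecomposeBatchMmt4DOpsPass());
}
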
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -119,6 +119,13 @@ def DecomposeAffineOps: Pass<"decompose-affine-ops"> {
   ];
 }
 
+def DecomposeBatchMmt4DOps
+    : Pass<"iree-codegen-decompose-batch-mmt4d-ops", "func::FuncOp"> {
+  let summary = "Decompose batch_mmt4d ops into mmt4d ops";
+  let constructor =
+      "mlir::iree_compiler::createDecomposeBatchMmt4DOpsPass()";
+}
+
 def DecomposeConvolutionToLowerDimOps :
     Pass<"iree-codegen-decompose-convolution-to-lower-dim-ops", ""> {
   let summary = "Decomposes linalg convolution ops to lower dim ops";
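
For context (not part of the diff): mlir-tblgen's -gen-pass-decls expands the def above into the DecomposeBatchMmt4DOpsBase template that the pass inherits from via PassDetail.h. An approximate, simplified sketch of that generated shape:

// Approximate, simplified shape of the tablegen-generated base class;
// the real declaration is produced by mlir-tblgen -gen-pass-decls and
// includes more accessors than shown here.
template <typename DerivedT>
class DecomposeBatchMmt4DOpsBase
    : public ::mlir::OperationPass<::mlir::func::FuncOp> {
public:
  DecomposeBatchMmt4DOpsBase()
      : ::mlir::OperationPass<::mlir::func::FuncOp>(
            ::mlir::TypeID::get<DerivedT>()) {}

  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("iree-codegen-decompose-batch-mmt4d-ops");
  }
};
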
@@ -30,6 +30,7 @@ iree_lit_test_suite(
         "convolutions.mlir",
         "erase_dead_alloc_and_stores.mlir",
         "decompose_affine_ops.mlir",
+        "decompose_batch_mmt4d_ops.mlir",
         "decompose_linalg_generic.mlir",
         "decompose_pack_unpack_ops.mlir",
         "eliminate_empty_tensors.mlir",
@@ -25,6 +25,7 @@ iree_lit_test_suite(
     "convert_to_destination_passing_style.mlir"
     "convolutions.mlir"
     "decompose_affine_ops.mlir"
+    "decompose_batch_mmt4d_ops.mlir"
     "decompose_linalg_generic.mlir"
     "decompose_pack_unpack_ops.mlir"
     "eliminate_empty_tensors.mlir"
@@ -0,0 +1,39 @@
// RUN: iree-opt --iree-codegen-decompose-batch-mmt4d-ops --split-input-file %s | FileCheck %s

func.func @batch_mmt4d_with_fill(%arg0: tensor<1x10x32x8x1xf32>, %arg1: tensor<1x80x32x4x1xf32>, %arg2: tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32>
  %1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<1x10x32x8x1xf32>, tensor<1x80x32x4x1xf32>) outs(%0 : tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32>
  return %1 : tensor<1x10x80x8x4xf32>
}

// CHECK: func.func @batch_mmt4d_with_fill
// CHECK-SAME: %[[LHS:.+]]: tensor<1x10x32x8x1xf32>,
// CHECK-SAME: %[[RHS:.+]]: tensor<1x80x32x4x1xf32>,
// CHECK-SAME: %[[OUT:.+]]: tensor<1x10x80x8x4xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[EXT_OUT:.+]] = tensor.extract_slice %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<1x10x80x8x4xf32> to tensor<10x80x8x4xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<1x10x32x8x1xf32> to tensor<10x32x8x1xf32>
// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<1x80x32x4x1xf32> to tensor<80x32x4x1xf32>
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[FILL]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<1x10x80x8x4xf32>
// CHECK: return %[[INS]] : tensor<1x10x80x8x4xf32>

// -----

func.func @batch_mmt4d_with_no_fill(%arg0: tensor<1x10x32x8x1xf32>, %arg1: tensor<1x80x32x4x1xf32>, %arg2: tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> {
  %1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<1x10x32x8x1xf32>, tensor<1x80x32x4x1xf32>) outs(%arg2 : tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32>
  return %1 : tensor<1x10x80x8x4xf32>
}

// CHECK: func.func @batch_mmt4d_with_no_fill
// CHECK-SAME: %[[LHS:.+]]: tensor<1x10x32x8x1xf32>,
// CHECK-SAME: %[[RHS:.+]]: tensor<1x80x32x4x1xf32>,
// CHECK-SAME: %[[OUT:.+]]: tensor<1x10x80x8x4xf32>
// CHECK-DAG: %[[EXT_OUT:.+]] = tensor.extract_slice %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<1x10x80x8x4xf32> to tensor<10x80x8x4xf32>
// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<1x10x32x8x1xf32> to tensor<10x32x8x1xf32>
// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<1x80x32x4x1xf32> to tensor<80x32x4x1xf32>
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<1x10x80x8x4xf32>
// CHECK: return %[[INS]] : tensor<1x10x80x8x4xf32>