Skip to content

Commit

Permalink
Remove tiling in pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry Wu committed Aug 11, 2023
1 parent cab05b7 commit 8eb9a82
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-decompose-batch-mmt4d-ops"
Expand All @@ -34,9 +32,9 @@ struct ConvertBatchMmt4DtoMmt4DPattern
auto out = op.getDpsInitOperand(0)->get();

auto outType = out.getType().cast<RankedTensorType>();
// Skip if the batch dim isn't tiled to 1.
// Batch dim needs to be tiled to 1 first.
if (outType.getShape()[0] != 1) {
return failure();
return rewriter.notifyMatchFailure(op, "batch dim needs to be 1");
}
RankedTensorType reducedOutType =
RankedTensorType::Builder(outType).dropDim(0);
Expand Down Expand Up @@ -85,9 +83,8 @@ struct ConvertBatchMmt4DtoMmt4DPattern
struct DecomposeBatchMmt4DOpsPass
: public DecomposeBatchMmt4DOpsBase<DecomposeBatchMmt4DOpsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<linalg::LinalgDialect, func::FuncDialect, arith::ArithDialect,
scf::SCFDialect, tensor::TensorDialect>();
registry.insert<linalg::LinalgDialect, func::FuncDialect,
arith::ArithDialect, tensor::TensorDialect>();
}

void runOnOperation() override;
Expand All @@ -99,64 +96,15 @@ void DecomposeBatchMmt4DOpsPass::runOnOperation() {
MLIRContext *ctx = &getContext();
auto funcOp = getOperation();

// First tile the batch dim of linalg.batch_mmt4d into 1.
{
SmallVector<int64_t> tileSizes({1});
auto tileAndFuseOptions = scf::SCFTileAndFuseOptions().setTilingOptions(
scf::SCFTilingOptions().setTileSizes(tileSizes));
IRRewriter rewriter(ctx);
funcOp->walk([&](linalg::BatchMmt4DOp op) {
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, cast<TilingInterface>(op.getOperation()),
tileAndFuseOptions);
if (failed(tileAndFuseResult)) {
return signalPassFailure();
}

SmallVector<Value> replacements;
replacements.resize(op->getNumResults());
for (const auto &[index, result] : llvm::enumerate(op->getResults())) {
replacements[index] = tileAndFuseResult->replacements[result];
}
op->replaceAllUsesWith(replacements);
});

LLVM_DEBUG({
llvm::dbgs() << "--- After tiling batch dim to 1 ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}

// Canonicalize tiled ops.
{
RewritePatternSet patterns(ctx);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

// Convert linalg.batch_mmt4d with batch dim = 1 into linalg.mmt4d.
{
RewritePatternSet patterns(ctx);
patterns.add<ConvertBatchMmt4DtoMmt4DPattern>(ctx);
// Canonicalize extract and insert slice ops created during the conversion.
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}

LLVM_DEBUG({
llvm::dbgs() << "--- After converting batch_mmt4d into mmt4d ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
RewritePatternSet patterns(ctx);
patterns.add<ConvertBatchMmt4DtoMmt4DPattern>(ctx);
// Canonicalize extract and insert slice ops created during the conversion.
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,63 +1,33 @@
// RUN: iree-opt --iree-codegen-decompose-batch-mmt4d-ops --split-input-file %s | FileCheck %s

func.func @batch_mmt4d_with_fill(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
func.func @batch_mmt4d_with_fill(%arg0: tensor<1x10x32x8x1xf32>, %arg1: tensor<1x80x32x4x1xf32>, %arg2: tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32>
%1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<128x10x32x8x1xf32>, tensor<128x80x32x4x1xf32>) outs(%0 : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32>
return %1 : tensor<128x10x80x8x4xf32>
%0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32>
%1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<1x10x32x8x1xf32>, tensor<1x80x32x4x1xf32>) outs(%0 : tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32>
return %1 : tensor<1x10x80x8x4xf32>
}

// CHECK: func.func @batch_mmt4d_with_fill
// CHECK-SAME: %[[LHS:.+]]: tensor<128x10x32x8x1xf32>,
// CHECK-SAME: %[[RHS:.+]]: tensor<128x80x32x4x1xf32>,
// CHECK-SAME: %[[OUT:.+]]: tensor<128x10x80x8x4xf32>
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-SAME: %[[LHS:.+]]: tensor<1x10x32x8x1xf32>,
// CHECK-SAME: %[[RHS:.+]]: tensor<1x80x32x4x1xf32>,
// CHECK-SAME: %[[OUT:.+]]: tensor<1x10x80x8x4xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[OUT]])
// CHECK: %[[EXT_OUT:.+]] = tensor.extract_slice %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<128x10x80x8x4xf32> to tensor<10x80x8x4xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][%[[I]], 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<128x10x32x8x1xf32> to tensor<10x32x8x1xf32>
// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][%[[I]], 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<128x80x32x4x1xf32> to tensor<80x32x4x1xf32>
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[FILL]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<128x10x80x8x4xf32>
// CHECK: scf.yield %[[INS]] : tensor<128x10x80x8x4xf32>
// CHECK: }
// CHECK: return %[[RES]] : tensor<128x10x80x8x4xf32>

// -----

func.func @batch_mmt4d_with_no_fill(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
%1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<128x10x32x8x1xf32>, tensor<128x80x32x4x1xf32>) outs(%arg2 : tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32>
return %1 : tensor<128x10x80x8x4xf32>
}

// CHECK: func.func @batch_mmt4d_with_no_fill
// CHECK-SAME: %[[LHS:.+]]: tensor<128x10x32x8x1xf32>,
// CHECK-SAME: %[[RHS:.+]]: tensor<128x80x32x4x1xf32>,
// CHECK-SAME: %[[OUT:.+]]: tensor<128x10x80x8x4xf32>
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[OUT]])
// CHECK: %[[EXT_OUT:.+]] = tensor.extract_slice %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<128x10x80x8x4xf32> to tensor<10x80x8x4xf32>
// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][%[[I]], 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<128x10x32x8x1xf32> to tensor<10x32x8x1xf32>
// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][%[[I]], 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<128x80x32x4x1xf32> to tensor<80x32x4x1xf32>
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<128x10x80x8x4xf32>
// CHECK: scf.yield %[[INS]] : tensor<128x10x80x8x4xf32>
// CHECK: }
// CHECK: return %[[RES]] : tensor<128x10x80x8x4xf32>
// CHECK-DAG: %[[EXT_OUT:.+]] = tensor.extract_slice %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<1x10x80x8x4xf32> to tensor<10x80x8x4xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<1x10x32x8x1xf32> to tensor<10x32x8x1xf32>
// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<1x80x32x4x1xf32> to tensor<80x32x4x1xf32>
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[FILL]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<1x10x80x8x4xf32>
// CHECK: return %[[INS]] : tensor<1x10x80x8x4xf32>

// -----

func.func @batch_mmt4d_with_unit_batch(%arg0: tensor<1x10x32x8x1xf32>, %arg1: tensor<1x80x32x4x1xf32>, %arg2: tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> {
func.func @batch_mmt4d_with_no_fill(%arg0: tensor<1x10x32x8x1xf32>, %arg1: tensor<1x80x32x4x1xf32>, %arg2: tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32> {
%1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<1x10x32x8x1xf32>, tensor<1x80x32x4x1xf32>) outs(%arg2 : tensor<1x10x80x8x4xf32>) -> tensor<1x10x80x8x4xf32>
return %1 : tensor<1x10x80x8x4xf32>
}

// CHECK: func.func @batch_mmt4d_with_unit_batch
// CHECK: func.func @batch_mmt4d_with_no_fill
// CHECK-SAME: %[[LHS:.+]]: tensor<1x10x32x8x1xf32>,
// CHECK-SAME: %[[RHS:.+]]: tensor<1x80x32x4x1xf32>,
// CHECK-SAME: %[[OUT:.+]]: tensor<1x10x80x8x4xf32>
Expand All @@ -67,31 +37,3 @@ func.func @batch_mmt4d_with_unit_batch(%arg0: tensor<1x10x32x8x1xf32>, %arg1: te
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[OUT]][0, 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<1x10x80x8x4xf32>
// CHECK: return %[[INS]] : tensor<1x10x80x8x4xf32>

// -----

func.func @batch_mmt4d_with_dynamic_batch(%arg0: tensor<?x10x32x8x1xf32>, %arg1: tensor<?x80x32x4x1xf32>, %arg2: tensor<?x10x80x8x4xf32>) -> tensor<?x10x80x8x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?x10x80x8x4xf32>) -> tensor<?x10x80x8x4xf32>
%1 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<?x10x32x8x1xf32>, tensor<?x80x32x4x1xf32>) outs(%0 : tensor<?x10x80x8x4xf32>) -> tensor<?x10x80x8x4xf32>
return %1 : tensor<?x10x80x8x4xf32>
}

// CHECK: func.func @batch_mmt4d_with_dynamic_batch
// CHECK-SAME: %[[LHS:.+]]: tensor<?x10x32x8x1xf32>,
// CHECK-SAME: %[[RHS:.+]]: tensor<?x80x32x4x1xf32>,
// CHECK-SAME: %[[OUT:.+]]: tensor<?x10x80x8x4xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[LHS]], %[[C0]] : tensor<?x10x32x8x1xf32>
// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[DIM]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[OUT]])
// CHECK: %[[EXT_OUT:.+]] = tensor.extract_slice %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<?x10x80x8x4xf32> to tensor<10x80x8x4xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXT_OUT]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK-DAG: %[[EXT_LHS:.+]] = tensor.extract_slice %[[LHS]][%[[I]], 0, 0, 0, 0] [1, 10, 32, 8, 1] [1, 1, 1, 1, 1] : tensor<?x10x32x8x1xf32> to tensor<10x32x8x1xf32>
// CHECK-DAG: %[[EXT_RHS:.+]] = tensor.extract_slice %[[RHS]][%[[I]], 0, 0, 0, 0] [1, 80, 32, 4, 1] [1, 1, 1, 1, 1] : tensor<?x80x32x4x1xf32> to tensor<80x32x4x1xf32>
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXT_LHS]], %[[EXT_RHS]] : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%[[FILL]] : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[MMT4D]] into %[[ITER_ARG]][%[[I]], 0, 0, 0, 0] [1, 10, 80, 8, 4] [1, 1, 1, 1, 1] : tensor<10x80x8x4xf32> into tensor<?x10x80x8x4xf32>
// CHECK: scf.yield %[[INS]] : tensor<?x10x80x8x4xf32>
// CHECK: }
// CHECK: return %[[RES]] : tensor<?x10x80x8x4xf32>

0 comments on commit 8eb9a82

Please sign in to comment.