Fold linalg.matmul(a, linalg.transpose(b)) into `linalg.matmul_transpose_b(a, b)` (#14645)
MaheshRavishankar authored Aug 14, 2023
1 parent b56ac23 commit 4a34723
Showing 2 changed files with 92 additions and 1 deletion.
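A condensed before/after sketch of the rewrite this commit performs (shapes taken from the test added below; the value names %a, %b, %bt, %empty, %init, and %mm are illustrative, not from the commit):

// Before raising: the matmul RHS comes from a linalg.generic that transposes %b.
%bt = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1, d0)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%b : tensor<40x20xf32>) outs(%empty : tensor<20x40xf32>) {
  ^bb0(%in : f32, %out : f32):
    linalg.yield %in : f32
} -> tensor<20x40xf32>
%mm = linalg.matmul ins(%a, %bt : tensor<10x20xf32>, tensor<20x40xf32>)
                    outs(%init : tensor<10x40xf32>) -> tensor<10x40xf32>

// After raising: the transpose is folded away and %b is consumed directly.
%mm = linalg.matmul_transpose_b ins(%a, %b : tensor<10x20xf32>, tensor<40x20xf32>)
                                outs(%init : tensor<10x40xf32>) -> tensor<10x40xf32>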
@@ -12,6 +12,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -27,13 +28,64 @@ namespace Flow {

 namespace {

+// Method to match a transpose operation.
+static bool matchTranspose(linalg::LinalgOp genericOp) {
+  if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) {
+    return false;
+  }
+  // Check only for 2D ops.
+  if (genericOp.getNumLoops() != 2 ||
+      genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
+    return false;
+  }
+  // Check for transpose map.
+  AffineExpr d0, d1;
+  MLIRContext *context = genericOp.getContext();
+  bindDims(context, d0, d1);
+  SmallVector<AffineMap> expectedMaps = {
+      AffineMap::get(2, 0, {d0, d1}, context),
+      AffineMap::get(2, 0, {d1, d0}, context)};
+  if (genericOp.getIndexingMapsArray() != expectedMaps) {
+    return false;
+  }
+
+  // Check the body.
+  Block *body = genericOp.getBlock();
+  if (!llvm::hasSingleElement(*body)) {
+    return false;
+  }
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  auto blockArg = yieldOp.getOperand(0).dyn_cast<BlockArgument>();
+  if (!blockArg || blockArg.getOwner() != body ||
+      blockArg.getArgNumber() != 0) {
+    return false;
+  }
+  return true;
+}
+
+// Method to match a linalg.matmul(a, linalg.transpose(b)). Returns `b` on
+// success.
+std::optional<Value> matchATransposeBMatmul(linalg::LinalgOp matmulOp) {
+  if (!isa<linalg::MatmulOp>(matmulOp.getOperation())) {
+    return std::nullopt;
+  }
+  // Get the RHS
+  auto rhs = matmulOp.getDpsInputOperand(1);
+  auto genericOp = rhs->get().getDefiningOp<linalg::GenericOp>();
+  if (genericOp && matchTranspose(genericOp)) {
+    return genericOp.getDpsInputOperand(0)->get();
+  }
+  return std::nullopt;
+}
+
 struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<IREE::LinalgExt::IREELinalgExtDialect>();
   }

   void runOnOperation() override {
     SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
+    SmallVector<std::pair<linalg::MatmulOp, Value>> transposeMatmulRoots;
     getOperation()->walk([&](linalg::LinalgOp op) {
       {
         transform_ext::MatcherContext matcherContext;
@@ -44,16 +96,31 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
           Value src = maxReduction->getCaptured()->getOperand(0);
           softmaxRoots.push_back(std::make_pair(op, src));
         }
+        if (std::optional<Value> newRhs = matchATransposeBMatmul(op)) {
+          transposeMatmulRoots.push_back(std::make_pair(
+              cast<linalg::MatmulOp>(op.getOperation()), newRhs.value()));
+        }
       }
     });
+    IRRewriter rewriter(&getContext());
     for (std::pair<linalg::LinalgOp, Value> softmax : softmaxRoots) {
       linalg::LinalgOp op = softmax.first;
       Value src = softmax.second;
-      IRRewriter rewriter(op.getContext());
       rewriter.setInsertionPoint(softmax.first);
       rewriter.replaceOpWithNewOp<IREE::LinalgExt::SoftmaxOp>(
           op, src, op.getDpsInitOperand(0)->get(), op.getNumLoops() - 1);
     }
+    for (std::pair<linalg::MatmulOp, Value> aTransposeBMatmul :
+         transposeMatmulRoots) {
+      auto matmulOp = aTransposeBMatmul.first;
+      Value lhs = matmulOp.getDpsInputOperand(0)->get();
+      auto newRhs = aTransposeBMatmul.second;
+      Value init = matmulOp.getDpsInitOperand(0)->get();
+      rewriter.setInsertionPoint(matmulOp);
+      SmallVector<NamedAttribute> attrs = getPrunedAttributeList(matmulOp);
+      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
+          matmulOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs);
+    }
   }
 };

@@ -162,3 +162,27 @@ func.func @softmax_broadcast(%93 : tensor<12x128x128xf32>) -> (tensor<12x128x128
   } -> tensor<12x128x128xf32>
   return %109 : tensor<12x128x128xf32>
 }
+
+func.func @aTransposeBMatmul(%arg0 : tensor<10x20xf32>,
+    %arg1 : tensor<40x20xf32>) -> tensor<10x40xf32> {
+  %0 = tensor.empty() : tensor<20x40xf32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg1 : tensor<40x20xf32>) outs(%0 : tensor<20x40xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      linalg.yield %b0 : f32
+  } -> tensor<20x40xf32>
+  %2 = tensor.empty() : tensor<10x40xf32>
+  %3 = arith.constant 0.0 : f32
+  %4 = linalg.fill ins(%3 : f32) outs(%2 : tensor<10x40xf32>) -> tensor<10x40xf32>
+  %5 = linalg.matmul ins(%arg0, %1 : tensor<10x20xf32>, tensor<20x40xf32>)
+      outs(%4 : tensor<10x40xf32>) -> tensor<10x40xf32>
+  return %5 : tensor<10x40xf32>
+}
+// CHECK-LABEL: func @aTransposeBMatmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<40x20xf32>
+// CHECK: %[[RESULT:.+]] = linalg.matmul_transpose_b
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK: return %[[RESULT]]
