[Flow] Raise batch_matmul(a, transpose(b)) to batch_matmul_transpose_b (#14847)

Adds a raising pattern similar to the one for matmul(a, transpose(b)).
qedawkins authored Aug 27, 2023
1 parent 87b920e commit 16e2931
Showing 2 changed files with 71 additions and 9 deletions.
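In IR terms, the raising performed here can be sketched as the rewrite below (a minimal sketch distilled from the test case added in this commit; the shapes and the value names %a, %b, %t, %init, and %r are illustrative):

    // Before: the RHS of the batch matmul is produced by a linalg.generic
    // that transposes the two minor dimensions of %b.
    %t = linalg.generic ... ins(%b : tensor<5x40x20xf32>) ... -> tensor<5x20x40xf32>
    %r = linalg.batch_matmul ins(%a, %t : tensor<5x10x20xf32>, tensor<5x20x40xf32>)
        outs(%init : tensor<5x10x40xf32>) -> tensor<5x10x40xf32>

    // After: the transpose is absorbed into the named op, which reads %b directly.
    %r = linalg.batch_matmul_transpose_b
        ins(%a, %b : tensor<5x10x20xf32>, tensor<5x40x20xf32>)
        outs(%init : tensor<5x10x40xf32>) -> tensor<5x10x40xf32>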
@@ -28,23 +28,30 @@ namespace Flow {
 
 namespace {
 
-// Method to match a transpose operation.
-static bool match2DTranspose(linalg::LinalgOp genericOp) {
+// Method to match a transpose operation on the two most minor dimensions of the
+// specified rank.
+static bool matchInner2DTranspose(linalg::LinalgOp genericOp, unsigned rank) {
+  // Only makes sense for minimum rank 2.
+  if (rank < 2) {
+    return false;
+  }
   if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) {
     return false;
   }
-  // Check only for 2D ops.
-  if (genericOp.getNumLoops() != 2 ||
+  // Check only for ops of the specified rank.
+  if (genericOp.getNumLoops() != rank ||
       genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
     return false;
   }
   // Check for transpose map.
-  AffineExpr d0, d1;
+  SmallVector<AffineExpr> exprList(rank);
   MLIRContext *context = genericOp.getContext();
-  bindDims(context, d0, d1);
+  bindDimsList(context, MutableArrayRef{exprList});
+  SmallVector<AffineExpr> transposeExprList(exprList);
+  std::swap(transposeExprList[rank - 1], transposeExprList[rank - 2]);
   SmallVector<AffineMap> expectedMaps = {
-      AffineMap::get(2, 0, {d0, d1}, context),
-      AffineMap::get(2, 0, {d1, d0}, context)};
+      AffineMap::get(rank, 0, exprList, context),
+      AffineMap::get(rank, 0, transposeExprList, context)};
   if (genericOp.getIndexingMapsArray() != expectedMaps) {
     return false;
   }
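For rank = 3 (the batch_matmul case), the expected maps built above are the identity map over the iteration space together with the same map with its two minor dimensions swapped, i.e. exactly the indexing maps the transpose generic in the new test carries:

    affine_map<(d0, d1, d2) -> (d0, d1, d2)>  // input map (identity)
    affine_map<(d0, d1, d2) -> (d0, d2, d1)>  // output map (minor dims swapped)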
@@ -70,7 +77,21 @@ std::optional<Value> matchATransposeBMatmul(linalg::LinalgOp matmulOp) {
   }
   auto rhs = matmulOp.getDpsInputOperand(1);
   auto genericOp = rhs->get().getDefiningOp<linalg::GenericOp>();
-  if (genericOp && match2DTranspose(genericOp)) {
+  if (genericOp && matchInner2DTranspose(genericOp, 2)) {
+    return genericOp.getDpsInputOperand(0)->get();
+  }
+  return std::nullopt;
+}
+
+// Method to match a linalg.batch_matmul(a, linalg.transpose(b)). Returns `b` on
+// success.
+std::optional<Value> matchATransposeBBatchMatmul(linalg::LinalgOp bmmOp) {
+  if (!isa<linalg::BatchMatmulOp>(bmmOp.getOperation())) {
+    return std::nullopt;
+  }
+  auto rhs = bmmOp.getDpsInputOperand(1);
+  auto genericOp = rhs->get().getDefiningOp<linalg::GenericOp>();
+  if (genericOp && matchInner2DTranspose(genericOp, 3)) {
     return genericOp.getDpsInputOperand(0)->get();
   }
   return std::nullopt;
@@ -361,6 +382,8 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
 
     SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
     SmallVector<std::pair<linalg::MatmulOp, Value>> transposeMatmulRoots;
+    SmallVector<std::pair<linalg::BatchMatmulOp, Value>>
+        transposeBatchMatmulRoots;
    SmallVector<std::pair<linalg::GenericOp, Value>> genericFills;
     getOperation()->walk([&](linalg::LinalgOp op) {
       {
@@ -376,6 +399,10 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
         transposeMatmulRoots.push_back(std::make_pair(
             cast<linalg::MatmulOp>(op.getOperation()), newRhs.value()));
       }
+      if (std::optional<Value> newRhs = matchATransposeBBatchMatmul(op)) {
+        transposeBatchMatmulRoots.push_back(std::make_pair(
+            cast<linalg::BatchMatmulOp>(op.getOperation()), newRhs.value()));
+      }
       if (std::optional<Value> fillInput = matchGenericFill(op)) {
         genericFills.push_back(
             std::make_pair(cast<linalg::GenericOp>(op), fillInput.value()));
@@ -402,6 +429,17 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
       rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
           matmulOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs);
     }
+    for (std::pair<linalg::BatchMatmulOp, Value> aTransposeBBatchMatmul :
+         transposeBatchMatmulRoots) {
+      auto bmmOp = aTransposeBBatchMatmul.first;
+      Value lhs = bmmOp.getDpsInputOperand(0)->get();
+      auto newRhs = aTransposeBBatchMatmul.second;
+      Value init = bmmOp.getDpsInitOperand(0)->get();
+      rewriter.setInsertionPoint(bmmOp);
+      SmallVector<NamedAttribute> attrs = getPrunedAttributeList(bmmOp);
+      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
+          bmmOp, ValueRange{lhs, newRhs}, ValueRange{init}, attrs);
+    }
     for (std::pair<linalg::GenericOp, Value> genericFill : genericFills) {
       auto genericOp = genericFill.first;
       Value fillInput = genericFill.second;
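The replacement keeps the original LHS and init operands and swaps in the pre-transpose RHS returned by the matcher. For the test case below, the loop rewrites the batch matmul to the following (a sketch; the SSA value names follow the test's pre-rewrite numbering):

    %5 = linalg.batch_matmul_transpose_b
        ins(%arg0, %arg1 : tensor<5x10x20xf32>, tensor<5x40x20xf32>)
        outs(%4 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32>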
@@ -187,6 +187,30 @@ func.func @aTransposeBMatmul(%arg0 : tensor<10x20xf32>,
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
 // CHECK: return %[[RESULT]]
 
+func.func @aTransposeBBatchMatmul(%arg0 : tensor<5x10x20xf32>,
+    %arg1 : tensor<5x40x20xf32>) -> tensor<5x10x40xf32> {
+  %0 = tensor.empty() : tensor<5x20x40xf32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg1 : tensor<5x40x20xf32>) outs(%0 : tensor<5x20x40xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      linalg.yield %b0 : f32
+  } -> tensor<5x20x40xf32>
+  %2 = tensor.empty() : tensor<5x10x40xf32>
+  %3 = arith.constant 0.0 : f32
+  %4 = linalg.fill ins(%3 : f32) outs(%2 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32>
+  %5 = linalg.batch_matmul ins(%arg0, %1 : tensor<5x10x20xf32>, tensor<5x20x40xf32>)
+      outs(%4 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32>
+  return %5 : tensor<5x10x40xf32>
+}
+// CHECK-LABEL: func @aTransposeBBatchMatmul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x10x20xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<5x40x20xf32>
+// CHECK: %[[RESULT:.+]] = linalg.batch_matmul_transpose_b
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK: return %[[RESULT]]
+
 func.func @generic_fill(%arg0: tensor<?x?xf32>) -> tensor<1x1x?x?xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
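For reference, this lit test is driven by iree-opt and FileCheck; a typical RUN line for the pass would look like the sketch below, though the exact pass flag is an assumption here since the RUN line sits outside this diff:

    // RUN: iree-opt --split-input-file --iree-flow-raise-special-ops %s | FileCheck %s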