[stablehlo] Rewriter for stablehlo.dot to linalg vec operations (#14956)
rsuderman authored Sep 18, 2023
1 parent 6212fc0 commit 735a77e
Showing 3 changed files with 110 additions and 22 deletions.
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/StablehloOps.h"
@@ -316,7 +317,48 @@ struct GeneralDotConvert final
  }
};

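+// Rewrites stablehlo.dot ops whose rank-2 operands carry a unit dimension
+// (1xK lhs and/or Kx1 rhs) into a dot on rank-1 vectors, then reshapes the
+// reduced result back to the original type. For example (mirroring the
+// @vecmat test added below):
+//   dot : (tensor<1x256xf32>, tensor<256x40xf32>) -> tensor<1x40xf32>
+// becomes reshape -> tensor<256xf32>, dot -> tensor<40xf32>,
+// reshape -> tensor<1x40xf32>.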
+struct DotVectorOptimization final : OpRewritePattern<mlir::stablehlo::DotOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(mlir::stablehlo::DotOp op,
+                                PatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+
+    ShapedType lhsTy = lhs.getType().cast<ShapedType>();
+    ShapedType rhsTy = rhs.getType().cast<ShapedType>();
+    ShapedType resultTy = op.getType().cast<ShapedType>();
+
+    llvm::SmallVector<int64_t> dotShape;
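+    // dotShape collects the non-unit dimensions the reduced dot keeps. A
+    // 1xK lhs collapses to a K-element vector; any other rank-2 lhs keeps
+    // its row count.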
+    if (lhsTy.getRank() == 2 && lhsTy.getDimSize(0) == 1) {
+      lhs = b.create<mlir::stablehlo::ReshapeOp>(
+          lhsTy.clone({lhsTy.getDimSize(1)}), lhs);
+    } else if (lhsTy.getRank() == 2) {
+      dotShape.push_back(lhsTy.getDimSize(0));
+    }
+
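+    // Symmetrically, a Kx1 rhs collapses to a K-element vector; any other
+    // rank-2 rhs keeps its column count.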
+    if (rhsTy.getRank() == 2 && rhsTy.getDimSize(1) == 1) {
+      rhs = b.create<mlir::stablehlo::ReshapeOp>(
+          rhsTy.clone({rhsTy.getDimSize(0)}), rhs);
+    } else if (rhsTy.getRank() == 2) {
+      dotShape.push_back(rhsTy.getDimSize(1));
+    }
+
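+    // Neither operand collapsed, so there is no vector form to emit.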
+    if (lhs == op.getLhs() && rhs == op.getRhs()) {
+      return rewriter.notifyMatchFailure(op, "no vector reform available.");
+    }
+
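+    // Emit the reduced-rank dot, then reshape back to the original result
+    // type (dotShape may be empty, yielding a rank-0 tensor).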
+    auto newDot = b.create<mlir::stablehlo::DotOp>(
+        resultTy.clone(dotShape), lhs, rhs, op.getPrecisionConfigAttr());
+    auto resultReshape = b.create<mlir::stablehlo::ReshapeOp>(resultTy, newDot);
+
+    rewriter.replaceOp(op, resultReshape);
+    return success();
+  }
+};
+
struct DotGeneralToDot final : impl::DotGeneralToDotBase<DotGeneralToDot> {
+
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePreprocessingDotGeneralToDotPatterns(&getContext(), &patterns);
@@ -330,8 +372,9 @@ struct DotGeneralToDot final : impl::DotGeneralToDotBase<DotGeneralToDot> {
} // namespace

void populatePreprocessingDotGeneralToDotPatterns(mlir::MLIRContext *context,
-                                                  RewritePatternSet *patterns) {
-  patterns->add<GeneralDotConvert>(context);
+                                                  RewritePatternSet *patterns,
+                                                  PatternBenefit benefit) {
+  patterns->add<GeneralDotConvert, DotVectorOptimization>(context, benefit);
}

} // namespace mlir::iree_compiler::stablehlo
@@ -24,7 +24,8 @@ void populateCanonicalizationPatterns(MLIRContext *context,
/// Collection of rewrite patterns for lowering of StableHLO dot general
/// operations.
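+/// Patterns are registered with the supplied `benefit` (defaults to 1).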
void populatePreprocessingDotGeneralToDotPatterns(MLIRContext *context,
-                                                  RewritePatternSet *patterns);
+                                                  RewritePatternSet *patterns,
+                                                  PatternBenefit benefit = 1);

/// Collection of rewrite patterns for lowering of StableHLO einsum operations.
void populatePreprocessingEinsumToDotGeneralPatterns(
@@ -2,44 +2,44 @@
// RUN: --split-input-file %s | FileCheck %s

// CHECK-LABEL: @testDebatch1
-// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x1x2xf32>, [[ARG1:%.+]]: tensor<2x3xf32>)
-func.func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x5x2xf32>, [[ARG1:%.+]]: tensor<2x3xf32>)
+func.func @testDebatch1(%arg0: tensor<1x5x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x5x3xf32> {
// CHECK-DAG: [[T0:%.+]] = stablehlo.transpose [[ARG0]], dims = [0, 1, 2]
-// CHECK-DAG: [[R0:%.+]] = stablehlo.reshape [[T0]] : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
+// CHECK-DAG: [[R0:%.+]] = stablehlo.reshape [[T0]] : (tensor<1x5x2xf32>) -> tensor<5x2xf32>
// CHECK-DAG: [[T1:%.+]] = stablehlo.transpose [[ARG1]], dims = [0, 1]
-// CHECK: [[R1:%.+]] = stablehlo.dot [[R0]], [[T1]], precision = [DEFAULT, DEFAULT] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
-// CHECK: [[R2:%.+]] = stablehlo.reshape [[R1]] : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
+// CHECK: [[R1:%.+]] = stablehlo.dot [[R0]], [[T1]], precision = [DEFAULT, DEFAULT] : (tensor<5x2xf32>, tensor<2x3xf32>) -> tensor<5x3xf32>
+// CHECK: [[R2:%.+]] = stablehlo.reshape [[R1]] : (tensor<5x3xf32>) -> tensor<1x5x3xf32>
// CHECK: return [[R2]]
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [0]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>
+} : (tensor<1x5x2xf32>, tensor<2x3xf32>) -> tensor<1x5x3xf32>

-func.return %0 : tensor<1x1x3xf32>
+func.return %0 : tensor<1x5x3xf32>
}

// -----

// CHECK-LABEL: @testDebatch2
-// CHECK-SAME: ([[ARG0:%.+]]: tensor<2x3xf32>, [[ARG1:%.+]]: tensor<1x1x2xf32>)
-func.func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> {
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<2x3xf32>, [[ARG1:%.+]]: tensor<1x5x2xf32>)
+func.func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x5x2xf32>) -> tensor<3x1x5xf32> {
// CHECK-DAG: [[R0:%.+]] = stablehlo.transpose [[ARG0]], dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32>
-// CHECK-DAG: [[R1:%.+]] = stablehlo.transpose [[ARG1]], dims = [2, 0, 1] : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
-// CHECK-DAG: [[R2:%.+]] = stablehlo.reshape [[R1]] : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
-// CHECK: [[R3:%.+]] = stablehlo.dot [[R0]], [[R2]], precision = [DEFAULT, DEFAULT] : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
-// CHECK: [[R4:%.+]] = stablehlo.reshape [[R3]] : (tensor<3x1xf32>) -> tensor<3x1x1xf32>
+// CHECK-DAG: [[R1:%.+]] = stablehlo.transpose [[ARG1]], dims = [2, 0, 1] : (tensor<1x5x2xf32>) -> tensor<2x1x5xf32>
+// CHECK-DAG: [[R2:%.+]] = stablehlo.reshape [[R1]] : (tensor<2x1x5xf32>) -> tensor<2x5xf32>
+// CHECK: [[R3:%.+]] = stablehlo.dot [[R0]], [[R2]], precision = [DEFAULT, DEFAULT] : (tensor<3x2xf32>, tensor<2x5xf32>) -> tensor<3x5xf32>
+// CHECK: [[R4:%.+]] = stablehlo.reshape [[R3]] : (tensor<3x5xf32>) -> tensor<3x1x5xf32>

%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_contracting_dimensions = [0],
rhs_contracting_dimensions = [2]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
-} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
-func.return %0 : tensor<3x1x1xf32>
+} : (tensor<2x3xf32>, tensor<1x5x2xf32>) -> tensor<3x1x5xf32>
+func.return %0 : tensor<3x1x5xf32>
}

// -----
@@ -102,9 +102,11 @@ func.func @testMatVec(%arg0: tensor<32x20xf32>, %arg1: tensor<32xf32>) -> tensor
// CHECK-DAG: [[T0:%.+]] = stablehlo.transpose [[ARG0]], dims = [1, 0]
// CHECK-DAG: [[T1:%.+]] = stablehlo.transpose [[ARG1]], dims = [0]
// CHECK-DAG: [[R1:%.+]] = stablehlo.reshape [[T1]] : (tensor<32xf32>) -> tensor<32x1xf32>
-// CHECK-NEXT: [[M:%.+]] = stablehlo.dot [[T0]], [[R1]]
-// CHECK-NEXT: [[R:%.+]] = stablehlo.reshape [[M]]
-// CHECK-NEXT: return [[R]]
+// CHECK-DAG: [[R2:%.+]] = stablehlo.reshape [[R1]] : (tensor<32x1xf32>) -> tensor<32xf32>
+// CHECK-NEXT: [[M:%.+]] = stablehlo.dot [[T0]], [[R2]]
+// CHECK-NEXT: [[R1:%.+]] = stablehlo.reshape [[M]]
+// CHECK-NEXT: [[R2:%.+]] = stablehlo.reshape [[R1]]
+// CHECK-NEXT: return [[R2]]
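+// The reshape pairs (32 -> 32x1 -> 32 and 20 -> 20x1 -> 20) come from
+// GeneralDotConvert composing with DotVectorOptimization; presumably later
+// canonicalization folds them away.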
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_contracting_dimensions = [0],
@@ -169,7 +171,7 @@ func.func @dot_no_rhs_batch(%arg0: tensor<1x512x768xf32>, %arg1: tensor<768x12x6

// CHECK-LABEL: @testPrefElem
func.func @testPrefElem(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf64> {
-// CHECK: stablehlo.dot {{%.*}}, {{%.*}} precision = [DEFAULT, DEFAULT] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf64>
+// CHECK: stablehlo.dot {{%.*}}, {{%.*}} precision = [DEFAULT, DEFAULT] : (tensor<2xf32>, tensor<2x3xf32>) -> tensor<3xf64>
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_contracting_dimensions = [2],
@@ -180,3 +182,45 @@

func.return %0 : tensor<1x1x3xf64>
}
+
+// -----
+
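+// A 1xK lhs collapses to a vector, turning the dot into a vec-mat product.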
+// CHECK-LABEL: @vecmat
+func.func @vecmat(%arg0 : tensor<1x256xf32>, %arg1 : tensor<256x40xf32>) -> tensor<1x40xf32> {
+// CHECK: %[[R:.+]] = stablehlo.reshape %arg0 : (tensor<1x256xf32>) -> tensor<256xf32>
+// CHECK: %[[DOT:.+]] = stablehlo.dot %[[R]], %arg1, precision = [DEFAULT, DEFAULT] : (tensor<256xf32>, tensor<256x40xf32>) -> tensor<40xf32>
+// CHECK: %[[R:.+]] = stablehlo.reshape %[[DOT]] : (tensor<40xf32>) -> tensor<1x40xf32>
+%0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<1x256xf32>, tensor<256x40xf32>) -> tensor<1x40xf32>
+
+// CHECK: return %[[R]]
+return %0 : tensor<1x40xf32>
+}
+
+// -----
+
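+// GeneralDotConvert transposes the rhs to 144x1, which then collapses to a
+// vector, turning the dot into a mat-vec product.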
+// CHECK-LABEL: @matvec
+func.func @matvec(%arg0 : tensor<20x144xf32>, %arg1 : tensor<1x144xf32>) -> tensor<20x1xf32> {
+// CHECK: %[[T0:.+]] = stablehlo.transpose %arg0, dims = [0, 1] : (tensor<20x144xf32>) -> tensor<20x144xf32>
+// CHECK: %[[T1:.+]] = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<1x144xf32>) -> tensor<144x1xf32>
+// CHECK: %[[R0:.+]] = stablehlo.reshape %[[T1]] : (tensor<144x1xf32>) -> tensor<144xf32>
+// CHECK: %[[DOT:.+]] = stablehlo.dot %[[T0]], %[[R0]], precision = [DEFAULT, DEFAULT] : (tensor<20x144xf32>, tensor<144xf32>) -> tensor<20xf32>
+// CHECK: %[[R2:.+]] = stablehlo.reshape %[[DOT]] : (tensor<20xf32>) -> tensor<20x1xf32>
+%0 = "stablehlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<20x144xf32>, tensor<1x144xf32>) -> tensor<20x1xf32>
+
+// CHECK: return %[[R2]]
+return %0 : tensor<20x1xf32>
+}
+
+// -----
+
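+// Both operands collapse, reducing the dot to a rank-0 (scalar) result.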
+// CHECK-LABEL: @vecdot
+func.func @vecdot(%arg0 : tensor<1x32xf64>, %arg1 : tensor<32x1xf64>) -> tensor<1x1xf64> {
+// CHECK: %[[R0:.+]] = stablehlo.reshape %arg0 : (tensor<1x32xf64>) -> tensor<32xf64>
+// CHECK: %[[R1:.+]] = stablehlo.reshape %arg1 : (tensor<32x1xf64>) -> tensor<32xf64>
+// CHECK: %[[DOT:.+]] = stablehlo.dot %[[R0]], %[[R1]], precision = [DEFAULT, DEFAULT] : (tensor<32xf64>, tensor<32xf64>) -> tensor<f64>
+// CHECK: %[[R2:.+]] = stablehlo.reshape %[[DOT]] : (tensor<f64>) -> tensor<1x1xf64>
+%0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<1x32xf64>, tensor<32x1xf64>) -> tensor<1x1xf64>
+
+// CHECK: return %[[R2]]
+return %0 : tensor<1x1xf64>
+}
