diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp
index 34592975af26..6e7ecde0176e 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "stablehlo/dialect/StablehloOps.h"
@@ -316,7 +317,48 @@ struct GeneralDotConvert final
   }
 };
 
+struct DotVectorOptimization final : OpRewritePattern<mlir::stablehlo::DotOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(mlir::stablehlo::DotOp op,
+                                PatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+
+    ShapedType lhsTy = lhs.getType().cast<ShapedType>();
+    ShapedType rhsTy = rhs.getType().cast<ShapedType>();
+    ShapedType resultTy = op.getType().cast<ShapedType>();
+
+    llvm::SmallVector<int64_t> dotShape;
+    if (lhsTy.getRank() == 2 && lhsTy.getDimSize(0) == 1) {
+      lhs = b.create<mlir::stablehlo::ReshapeOp>(
+          lhsTy.clone({lhsTy.getDimSize(1)}), lhs);
+    } else if (lhsTy.getRank() == 2) {
+      dotShape.push_back(lhsTy.getDimSize(0));
+    }
+
+    if (rhsTy.getRank() == 2 && rhsTy.getDimSize(1) == 1) {
+      rhs = b.create<mlir::stablehlo::ReshapeOp>(
+          rhsTy.clone({rhsTy.getDimSize(0)}), rhs);
+    } else if (rhsTy.getRank() == 2) {
+      dotShape.push_back(rhsTy.getDimSize(1));
+    }
+
+    if (lhs == op.getLhs() && rhs == op.getRhs()) {
+      return rewriter.notifyMatchFailure(op, "no vector reform available.");
+    }
+
+    auto newDot = b.create<mlir::stablehlo::DotOp>(
+        resultTy.clone(dotShape), lhs, rhs, op.getPrecisionConfigAttr());
+    auto resultReshape = b.create<mlir::stablehlo::ReshapeOp>(resultTy, newDot);
+
+    rewriter.replaceOp(op, resultReshape);
+
+    return success();
+  }
+};
+
 struct DotGeneralToDot final : impl::DotGeneralToDotBase<DotGeneralToDot> {
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populatePreprocessingDotGeneralToDotPatterns(&getContext(), &patterns);
@@ -330,8 +372,9 @@ struct DotGeneralToDot final : impl::DotGeneralToDotBase<DotGeneralToDot> {
 }  // namespace
 
 void populatePreprocessingDotGeneralToDotPatterns(mlir::MLIRContext *context,
-                                                  RewritePatternSet *patterns) {
-  patterns->add<GeneralDotConvert>(context);
+                                                  RewritePatternSet *patterns,
+                                                  PatternBenefit benefit) {
+  patterns->add<GeneralDotConvert, DotVectorOptimization>(context, benefit);
 }
 
 }  // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h
index 280e8dbca6e1..c2bf829fcb9d 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h
@@ -24,7 +24,8 @@ void populateCanonicalizationPatterns(MLIRContext *context,
 /// Collection of rewrite patterns for lowering of StableHLO dot general
 /// operations.
 void populatePreprocessingDotGeneralToDotPatterns(MLIRContext *context,
-                                                  RewritePatternSet *patterns);
+                                                  RewritePatternSet *patterns,
+                                                  PatternBenefit benefit = 1);
 
 /// Collection of rewrite patterns for lowering of StableHLO einsum operations.
void populatePreprocessingEinsumToDotGeneralPatterns( diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/dot_general_to_dot.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/dot_general_to_dot.mlir index 932f1dabe46f..e72fe1963a36 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/dot_general_to_dot.mlir +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/dot_general_to_dot.mlir @@ -2,13 +2,13 @@ // RUN: --split-input-file %s | FileCheck %s // CHECK-LABEL: @testDebatch1 -// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x1x2xf32>, [[ARG1:%.+]]: tensor<2x3xf32>) -func.func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> { +// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x5x2xf32>, [[ARG1:%.+]]: tensor<2x3xf32>) +func.func @testDebatch1(%arg0: tensor<1x5x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x5x3xf32> { // CHECK-DAG: [[T0:%.+]] = stablehlo.transpose [[ARG0]], dims = [0, 1, 2] - // CHECK-DAG: [[R0:%.+]] = stablehlo.reshape [[T0]] : (tensor<1x1x2xf32>) -> tensor<1x2xf32> + // CHECK-DAG: [[R0:%.+]] = stablehlo.reshape [[T0]] : (tensor<1x5x2xf32>) -> tensor<5x2xf32> // CHECK-DAG: [[T1:%.+]] = stablehlo.transpose [[ARG1]], dims = [0, 1] - // CHECK: [[R1:%.+]] = stablehlo.dot [[R0]], [[T1]], precision = [DEFAULT, DEFAULT] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> - // CHECK: [[R2:%.+]] = stablehlo.reshape [[R1]] : (tensor<1x3xf32>) -> tensor<1x1x3xf32> + // CHECK: [[R1:%.+]] = stablehlo.dot [[R0]], [[T1]], precision = [DEFAULT, DEFAULT] : (tensor<5x2xf32>, tensor<2x3xf32>) -> tensor<5x3xf32> + // CHECK: [[R2:%.+]] = stablehlo.reshape [[R1]] : (tensor<5x3xf32>) -> tensor<1x5x3xf32> // CHECK: return [[R2]] %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< @@ -16,21 +16,21 @@ func.func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> ten rhs_contracting_dimensions = [0] 
>, precision_config = [#stablehlo, #stablehlo] - } : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32> + } : (tensor<1x5x2xf32>, tensor<2x3xf32>) -> tensor<1x5x3xf32> - func.return %0 : tensor<1x1x3xf32> + func.return %0 : tensor<1x5x3xf32> } // ----- // CHECK-LABEL: @testDebatch2 -// CHECK-SAME: ([[ARG0:%.+]]: tensor<2x3xf32>, [[ARG1:%.+]]: tensor<1x1x2xf32>) -func.func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> { +// CHECK-SAME: ([[ARG0:%.+]]: tensor<2x3xf32>, [[ARG1:%.+]]: tensor<1x5x2xf32>) +func.func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x5x2xf32>) -> tensor<3x1x5xf32> { // CHECK-DAG: [[R0:%.+]] = stablehlo.transpose [[ARG0]], dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32> - // CHECK-DAG: [[R1:%.+]] = stablehlo.transpose [[ARG1]], dims = [2, 0, 1] : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32> - // CHECK-DAG: [[R2:%.+]] = stablehlo.reshape [[R1]] : (tensor<2x1x1xf32>) -> tensor<2x1xf32> - // CHECK: [[R3:%.+]] = stablehlo.dot [[R0]], [[R2]], precision = [DEFAULT, DEFAULT] : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32> - // CHECK: [[R4:%.+]] = stablehlo.reshape [[R3]] : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + // CHECK-DAG: [[R1:%.+]] = stablehlo.transpose [[ARG1]], dims = [2, 0, 1] : (tensor<1x5x2xf32>) -> tensor<2x1x5xf32> + // CHECK-DAG: [[R2:%.+]] = stablehlo.reshape [[R1]] : (tensor<2x1x5xf32>) -> tensor<2x5xf32> + // CHECK: [[R3:%.+]] = stablehlo.dot [[R0]], [[R2]], precision = [DEFAULT, DEFAULT] : (tensor<3x2xf32>, tensor<2x5xf32>) -> tensor<3x5xf32> + // CHECK: [[R4:%.+]] = stablehlo.reshape [[R3]] : (tensor<3x5xf32>) -> tensor<3x1x5xf32> %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< @@ -38,8 +38,8 @@ func.func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> ten rhs_contracting_dimensions = [2] >, precision_config = [#stablehlo, #stablehlo] - } : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32> - func.return %0 : 
tensor<3x1x1xf32> + } : (tensor<2x3xf32>, tensor<1x5x2xf32>) -> tensor<3x1x5xf32> + func.return %0 : tensor<3x1x5xf32> } // ----- @@ -102,9 +102,11 @@ func.func @testMatVec(%arg0: tensor<32x20xf32>, %arg1: tensor<32xf32>) -> tensor // CHECK-DAG: [[T0:%.+]] = stablehlo.transpose [[ARG0]], dims = [1, 0] // CHECK-DAG: [[T1:%.+]] = stablehlo.transpose [[ARG1]], dims = [0] // CHECK-DAG: [[R1:%.+]] = stablehlo.reshape [[T1]] : (tensor<32xf32>) -> tensor<32x1xf32> - // CHECK-NEXT: [[M:%.+]] = stablehlo.dot [[T0]], [[R1]] - // CHECK-NEXT: [[R:%.+]] = stablehlo.reshape [[M]] - // CHECK-NEXT: return [[R]] + // CHECK-DAG: [[R2:%.+]] = stablehlo.reshape [[R1]] : (tensor<32x1xf32>) -> tensor<32xf32> + // CHECK-NEXT: [[M:%.+]] = stablehlo.dot [[T0]], [[R2]] + // CHECK-NEXT: [[R1:%.+]] = stablehlo.reshape [[M]] + // CHECK-NEXT: [[R2:%.+]] = stablehlo.reshape [[R1]] + // CHECK-NEXT: return [[R2]] %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_contracting_dimensions = [0], @@ -169,7 +171,7 @@ func.func @dot_no_rhs_batch(%arg0: tensor<1x512x768xf32>, %arg1: tensor<768x12x6 // CHECK-LABEL: @testPrefElem func.func @testPrefElem(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf64> { - // CHECK: stablehlo.dot {{%.*}}, {{%.*}} precision = [DEFAULT, DEFAULT] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf64> + // CHECK: stablehlo.dot {{%.*}}, {{%.*}} precision = [DEFAULT, DEFAULT] : (tensor<2xf32>, tensor<2x3xf32>) -> tensor<3xf64> %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_contracting_dimensions = [2], @@ -180,3 +182,45 @@ func.func @testPrefElem(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> ten func.return %0 : tensor<1x1x3xf64> } + +// ----- + +// CHECK-LABEL: @vecmat +func.func @vecmat(%arg0 : tensor<1x256xf32>, %arg1 : tensor<256x40xf32>) -> tensor<1x40xf32> { + // CHECK: %[[R:.+]] = stablehlo.reshape %arg0 : (tensor<1x256xf32>) -> tensor<256xf32> + // CHECK: 
%[[DOT:.+]] = stablehlo.dot %[[R]], %arg1, precision = [DEFAULT, DEFAULT] : (tensor<256xf32>, tensor<256x40xf32>) -> tensor<40xf32>
+  // CHECK: %[[R:.+]] = stablehlo.reshape %[[DOT]] : (tensor<40xf32>) -> tensor<1x40xf32>
+  %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<1x256xf32>, tensor<256x40xf32>) -> tensor<1x40xf32>
+
+  // CHECK: return %[[R]]
+  return %0 : tensor<1x40xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @matvec
+func.func @matvec(%arg0 : tensor<20x144xf32>, %arg1 : tensor<1x144xf32>) -> tensor<20x1xf32> {
+  // CHECK: %[[T0:.+]] = stablehlo.transpose %arg0, dims = [0, 1] : (tensor<20x144xf32>) -> tensor<20x144xf32>
+  // CHECK: %[[T1:.+]] = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<1x144xf32>) -> tensor<144x1xf32>
+  // CHECK: %[[R0:.+]] = stablehlo.reshape %[[T1]] : (tensor<144x1xf32>) -> tensor<144xf32>
+  // CHECK: %[[DOT:.+]] = stablehlo.dot %[[T0]], %[[R0]], precision = [DEFAULT, DEFAULT] : (tensor<20x144xf32>, tensor<144xf32>) -> tensor<20xf32>
+  // CHECK: %[[R2:.+]] = stablehlo.reshape %[[DOT]] : (tensor<20xf32>) -> tensor<20x1xf32>
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<20x144xf32>, tensor<1x144xf32>) -> tensor<20x1xf32>
+
+  // CHECK: return %[[R2]]
+  return %0 : tensor<20x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vecdot
+func.func @vecdot(%arg0 : tensor<1x32xf64>, %arg1 : tensor<32x1xf64>) -> tensor<1x1xf64> {
+  // CHECK: %[[R0:.+]] = stablehlo.reshape %arg0 : (tensor<1x32xf64>) -> tensor<32xf64>
+  // CHECK: %[[R1:.+]] = stablehlo.reshape %arg1 : (tensor<32x1xf64>) -> tensor<32xf64>
+  // CHECK: %[[DOT:.+]] = stablehlo.dot %[[R0]], %[[R1]], precision = [DEFAULT, DEFAULT] : (tensor<32xf64>, tensor<32xf64>) -> tensor<f64>
+  // CHECK: %[[R2:.+]] = stablehlo.reshape %[[DOT]] : (tensor<f64>) -> tensor<1x1xf64>
+  %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<1x32xf64>,
tensor<32x1xf64>) -> tensor<1x1xf64> + + // CHECK: %[[R2]] + return %0 : tensor<1x1xf64> +}