From 20684a4abb410f68b3020c4c8ed01607aef679bc Mon Sep 17 00:00:00 2001
From: Ferdinand Lemaire
Date: Fri, 12 Jan 2024 12:19:15 +0100
Subject: [PATCH] Canonicalize tosa sqrt + reciprocal into rsqrt (#88)

* Add draft for canonicalization of sqrt + reciprocal in rsqrt

* Address comments: Add error message, handle tile case and reject non-float scales
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  2 +
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 61 +++++++++++++++++++
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 34 +++++++++++
 3 files changed, 97 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3331ca4cb8643f..522aa2b02f1af6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -804,6 +804,8 @@ def Tosa_PowOp : Tosa_ElemWiseBinaryOp<"pow"> {
   let results = (outs
     Tosa_Tensor:$z
   );
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 636e0deb18a0e8..b6ae029e7a76a0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
@@ -62,6 +63,66 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add(context);
 }
 
+/// Folds `reciprocal(pow(x, 0.5))` into `rsqrt(x)`.
+///
+/// Sqrt is represented in TOSA by a `tosa.pow` whose second input is a
+/// constant splat of 0.5, so the pattern looks for pow + reciprocal. It is
+/// rooted on the pow and replaces its single reciprocal user; the pow itself
+/// is left without uses afterwards.
+struct SqrtReciprocalOptimization : public OpRewritePattern<tosa::PowOp> {
+  using OpRewritePattern<tosa::PowOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::PowOp op,
+                                PatternRewriter &rewriter) const override {
+    // The pow's value must feed the reciprocal and nothing else.
+    if (!op->hasOneUse())
+      return rewriter.notifyMatchFailure(
+          op, "expected the pow to have exactly one use");
+
+    auto reciprocal = dyn_cast<tosa::ReciprocalOp>(*op->user_begin());
+    if (!reciprocal)
+      return rewriter.notifyMatchFailure(
+          op, "expected a pow + reciprocal pattern");
+
+    // The exponent (second input) must be a constant f32 splat of exactly 0.5.
+    auto exponentOp = op.getInput2().getDefiningOp<tosa::ConstOp>();
+    if (!exponentOp)
+      return rewriter.notifyMatchFailure(
+          op, "expected the pow to have a constant scale input");
+
+    // dyn_cast instead of cast: other attribute kinds must fail the match
+    // rather than trip an assertion.
+    auto exponent = dyn_cast<DenseElementsAttr>(exponentOp.getValue());
+    if (!exponent || !exponent.isSplat())
+      return rewriter.notifyMatchFailure(
+          op, "expected the pow scale to be a splat tensor");
+    if (!exponent.getElementType().isF32())
+      return rewriter.notifyMatchFailure(
+          op, "unexpected type for scale value of the pow op");
+    if (exponent.getSplatValue<float>() != 0.5f)
+      return rewriter.notifyMatchFailure(
+          op, "expected the pow to have a scale of 0.5 to be a sqrt");
+
+    // A broadcasting pow would need a tile in front of the rsqrt (a possible
+    // future improvement), and the rsqrt takes over the reciprocal's uses, so
+    // only rewrite when input, pow result and reciprocal result types agree.
+    Type powType = op.getType();
+    if (op.getInput1().getType() != powType || reciprocal.getType() != powType)
+      return rewriter.notifyMatchFailure(
+          op, "input type and output type are different, tiling is not "
+              "supported for this canonicalization");
+
+    rewriter.replaceOpWithNewOp<tosa::RsqrtOp>(reciprocal, powType,
+                                               op.getInput1());
+    return success();
+  }
+};
+
+void PowOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<SqrtReciprocalOptimization>(context);
+}
+
 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
   auto notOp = op.getPred().getDefiningOp();
   if (!notOp)
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 31227698e09bf4..129c52725a59de 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -584,6 +584,40 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 :
 
 // -----
 
+// CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal
+func.func @canonicalize_optimize_sqrt_reciprocal(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> {
+  // CHECK: %[[RSQRT:.*]] = "tosa.rsqrt"(%arg{{.*}}) : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>
+  // CHECK: return %[[RSQRT]] : tensor<1x5x1x1xf32>
+  %0 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
+  %1 = "tosa.pow"(%arg0, %0) : (tensor<1x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x5x1x1xf32>
+  %2 = "tosa.reciprocal"(%1) : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>
+  return %2 : tensor<1x5x1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal_no_match
+func.func @canonicalize_optimize_sqrt_reciprocal_no_match(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> {
+  // CHECK-NOT: tosa.rsqrt
+  %0 = "tosa.const"() <{value = dense<4.000000e-01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
+  %1 = "tosa.pow"(%arg0, %0) : (tensor<1x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x5x1x1xf32>
+  %2 = "tosa.reciprocal"(%1) : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>
+  return %2 : tensor<1x5x1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal_tile_no_match
+func.func @canonicalize_optimize_sqrt_reciprocal_tile_no_match(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x7x1xf32> {
+  // CHECK-NOT: tosa.rsqrt
+  %0 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x7x1xf32>}> : () -> tensor<1x1x7x1xf32>
+  %1 = "tosa.pow"(%arg0, %0) : (tensor<1x5x1x1xf32>, tensor<1x1x7x1xf32>) -> tensor<1x5x7x1xf32>
+  %2 = "tosa.reciprocal"(%1) : (tensor<1x5x7x1xf32>) -> tensor<1x5x7x1xf32>
+  return %2 : tensor<1x5x7x1xf32>
+}
+
+// -----
+
 // CHECK-LABEL
 func.func @fold_log_exp(%arg0: tensor) -> tensor {
   // CHECK: return %arg{{.*}} : tensor