Skip to content

Commit

Permalink
Canonicalize tosa sqrt + reciprocal into rsqrt (#88)
Browse files Browse the repository at this point in the history
* Add draft for canonicalization of sqrt + reciprocal in rsqrt

* Address comments: Add error message, handle tile case and reject non-float scales
  • Loading branch information
flemairen6 authored Jan 12, 2024
1 parent d0868e8 commit 20684a4
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,8 @@ def Tosa_PowOp : Tosa_ElemWiseBinaryOp<"pow"> {
let results = (outs
Tosa_Tensor:$z
);

let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
52 changes: 52 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -62,6 +63,57 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ConcatOptimization>(context);
}

struct SqrtReciprocalOptimization : public OpRewritePattern<tosa::PowOp> {
using OpRewritePattern<tosa::PowOp>::OpRewritePattern;
// Pattern that matches a Sqrt + Reciprocal to replace them by a rsqrt.
// Sqrt is represented in tosa by a Pow so we check for Pow + reciprocal.
LogicalResult matchAndRewrite(tosa::PowOp op,
PatternRewriter &rewriter) const override {
// Check that the PowOp has a single user
if (!op->hasOneUse())
return rewriter.notifyMatchFailure(op, "pow operator has more than one user");

Operation* user = *op->user_begin();
// Check that this user is a reciprocal
if (!isa<tosa::ReciprocalOp>(user))
return rewriter.notifyMatchFailure(op, "expected a pow + reciprocal pattern");

// Check that the Pow op is an Sqrt - its second input should be the scale, 0.5 for Sqrt.
Operation* powScale = op.getInput2().getDefiningOp();
if (!powScale || !isa<tosa::ConstOp>(powScale))
return rewriter.notifyMatchFailure(op, "expected the pow to have a constant scale input");

auto scale = cast<DenseElementsAttr>(cast<tosa::ConstOp>(powScale).getValue());
if (!scale.isSplat())
return rewriter.notifyMatchFailure(op, "expected the pow scale to be a splat tensor");

auto constantType = scale.getElementType();
float scaleValue = 0.;
if (constantType.isF32())
scaleValue = scale.getSplatValue<float>();
else
return rewriter.notifyMatchFailure(op, "unexpected type for scale value of the pow op");
if(scaleValue != 0.5)
return rewriter.notifyMatchFailure(op, "expected the pow to have a scale of 0.5 to be a sqrt");

auto inputType = cast<ShapedType>(op.getOperand(0).getType());
auto outputType = cast<ShapedType>(op.getType());
// If the operator needs tiling, fail to match
// An improvement for the future would be to generate a tile operator here instead
if (inputType != outputType)
return rewriter.notifyMatchFailure(op, "input type and output type are different, tiling is not supported for this canonicalization");

rewriter.replaceOpWithNewOp<tosa::RsqrtOp>(user, outputType, op.getInput1());

return success();
}
};

void PowOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SqrtReciprocalOptimization>(context);
}

LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
if (!notOp)
Expand Down
34 changes: 34 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,40 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 :

// -----

// CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal
func.func @canonicalize_optimize_sqrt_reciprocal(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> {
// CHECK: %[[RSQRT:.*]] = "tosa.rsqrt"(%arg{{.*}}) : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>
// CHECK: return %[[RSQRT]] : tensor<1x5x1x1xf32>
%0 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
%1 = "tosa.pow"(%arg0, %0) : (tensor<1x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x5x1x1xf32>
%2 = "tosa.reciprocal"(%1) : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>
return %2 : tensor<1x5x1x1xf32>
}

// -----

// CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal_no_match
func.func @canonicalize_optimize_sqrt_reciprocal_no_match(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> {
// CHECK-NOT: tosa.rsqrt"(%arg{{.*}})
%0 = "tosa.const"() <{value = dense<4.000000e-01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
%1 = "tosa.pow"(%arg0, %0) : (tensor<1x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x5x1x1xf32>
%2 = "tosa.reciprocal"(%1) : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>
return %2 : tensor<1x5x1x1xf32>
}

// -----

// CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal_tile_no_match
func.func @canonicalize_optimize_sqrt_reciprocal_tile_no_match(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x7x1xf32> {
// CHECK-NOT: tosa.rsqrt"(%arg{{.*}})
%0 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x7x1xf32>}> : () -> tensor<1x1x7x1xf32>
%1 = "tosa.pow"(%arg0, %0) : (tensor<1x5x1x1xf32>, tensor<1x1x7x1xf32>) -> tensor<1x5x7x1xf32>
%2 = "tosa.reciprocal"(%1) : (tensor<1x5x7x1xf32>) -> tensor<1x5x7x1xf32>
return %2 : tensor<1x5x7x1xf32>
}

// -----

// CHECK-LABEL
func.func @fold_log_exp(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg{{.*}} : tensor<?x1xf32>
Expand Down

0 comments on commit 20684a4

Please sign in to comment.