Skip to content

Commit

Permalink
Fix bug with stablehlo canonicalizer which reordered operations that …
Browse files Browse the repository at this point in the history
…had more than one use. (#14599)

The previously-added canonicalizer would reorder unary elementwise
operations even when the first operation in the swap had other uses. The
swap should only be applied when the elementwise op is the only user of
the shape operation, since the swap changes the type of the shape op's
result and would therefore break any other users.
  • Loading branch information
NatashaKnk authored Aug 8, 2023
1 parent 4efb894 commit a62652e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1055,20 +1055,25 @@ struct ReorderElementwiseAndShapeOp final
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getOperands().size() != 1) {
return rewriter.notifyMatchFailure(op, "expected to be unary.");
return rewriter.notifyMatchFailure(op, "expected to be unary");
}

auto definingOp = op->getOperand(0).getDefiningOp();
if (!definingOp) {
return rewriter.notifyMatchFailure(
op, "expected to have an op before elementwise op.");
op, "expected to have an op before elementwise op");
}

if (!isa<mlir::stablehlo::ReshapeOp>(definingOp) &&
!isa<mlir::stablehlo::TransposeOp>(definingOp) &&
!isa<mlir::stablehlo::BroadcastOp>(definingOp)) {
return rewriter.notifyMatchFailure(
op, "defining operation of unexpected type.");
op, "defining operation of unexpected type");
}

// Only reorder if the defining op has no other uses.
if (!llvm::hasSingleElement(definingOp->getResult(0).getUses())) {
return rewriter.notifyMatchFailure(op, "operation has more than one use");
}

Value input = definingOp->getOperand(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -736,3 +736,20 @@ func.func @reorder_with_type_change(%arg0 : tensor<3x4xi32>) -> tensor<12xi64> {
// CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %[[CONVERT]] : (tensor<3x4xi64>) -> tensor<12xi64>
// CHECK: return %[[RESHAPE]]

// -----

// Regression test for the reorder-elementwise-and-shape-op canonicalization:
// %0 (the reshape result) has TWO users below — the convert (%1) and the
// reduce (%3) — so the pattern must not swap the reshape past the convert.
// The CHECK lines following this function verify the reshape/convert order
// is preserved.
func.func @do_not_reorder_with_other_uses(%arg0: tensor<2x2xf64>, %arg1: tensor<4xf32>, %arg2: tensor<f64>) -> (tensor<f64>, tensor<4xf32>) {
// Shape op whose result is used twice (by %1 and by %3).
%0 = stablehlo.reshape %arg0 : (tensor<2x2xf64>) -> tensor<4xf64>
// Unary elementwise op consuming %0 — the candidate for the (rejected) swap.
%1 = stablehlo.convert %0 : (tensor<4xf64>) -> tensor<4xf32>
%2 = stablehlo.subtract %arg1, %1 : tensor<4xf32>
// Second use of %0: reducing it would see a different type if the swap fired.
%3 = stablehlo.reduce(%0 init: %arg2) across dimensions = [0] : (tensor<4xf64>, tensor<f64>) -> tensor<f64>
reducer(%arg3: tensor<f64>, %arg4: tensor<f64>) {
%4 = stablehlo.add %arg3, %arg4 : tensor<f64>
stablehlo.return %4 : tensor<f64>
}
return %3, %2 : tensor<f64>, tensor<4xf32>
}

// CHECK-LABEL: do_not_reorder_with_other_uses
// CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<2x2xf64>) -> tensor<4xf64>
// CHECK: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<4xf64>) -> tensor<4xf32>

0 comments on commit a62652e

Please sign in to comment.