From a62652ec40e7030ea816b3c86c02fa00f56256c4 Mon Sep 17 00:00:00 2001 From: NatashaKnk Date: Tue, 8 Aug 2023 14:13:12 -0700 Subject: [PATCH] Fix bug with stablehlo canonicalizer which reordered operations that had more than one use. (#14599) The previously-added canonicalizer would reorder unary elementwise operations even when the first operation in the swap had other uses. The swap should only be applied when the elementwise op is the only user of the shape operation (since it may change the value/type of the result). --- .../Preprocessing/Canonicalization.cpp | 11 ++++++++--- .../Preprocessing/test/canonicalization.mlir | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp index 744a7eeca833..78faae4dcf49 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp @@ -1055,20 +1055,25 @@ struct ReorderElementwiseAndShapeOp final LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (op->getOperands().size() != 1) { - return rewriter.notifyMatchFailure(op, "expected to be unary."); + return rewriter.notifyMatchFailure(op, "expected to be unary"); } auto definingOp = op->getOperand(0).getDefiningOp(); if (!definingOp) { return rewriter.notifyMatchFailure( - op, "expected to have an op before elementise op."); + op, "expected to have an op before elementise op"); } if (!isa(definingOp) && !isa(definingOp) && !isa(definingOp)) { return rewriter.notifyMatchFailure( - op, "defining operation of unexpected type."); + op, "defining operation of unexpected type"); + } + + // Only reorder if the defining op has no other uses. + if (!llvm::hasSingleElement(definingOp->getResult(0).getUses())) { + return rewriter.notifyMatchFailure(op, "operation has more than one use"); } Value input = definingOp->getOperand(0); diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir index bda6c3c9d0f5..a569716ff987 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir @@ -736,3 +736,20 @@ func.func @reorder_with_type_change(%arg0 : tensor<3x4xi32>) -> tensor<12xi64> { // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %[[CONVERT]] : (tensor<3x4xi64>) -> tensor<12xi64> // CHECK: return %[[RESHAPE]] +// ----- + +func.func @do_not_reorder_with_other_uses(%arg0: tensor<2x2xf64>, %arg1: tensor<4xf32>, %arg2: tensor) -> (tensor, tensor<4xf32>) { + %0 = stablehlo.reshape %arg0 : (tensor<2x2xf64>) -> tensor<4xf64> + %1 = stablehlo.convert %0 : (tensor<4xf64>) -> tensor<4xf32> + %2 = stablehlo.subtract %arg1, %1 : tensor<4xf32> + %3 = stablehlo.reduce(%0 init: %arg2) across dimensions = [0] : (tensor<4xf64>, tensor) -> tensor + reducer(%arg3: tensor, %arg4: tensor) { + %4 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %4 : tensor + } + return %3, %2 : tensor, tensor<4xf32> +} + +// CHECK-LABEL: do_not_reorder_with_other_uses +// CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<2x2xf64>) -> tensor<4xf64> +// CHECK: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<4xf64>) -> tensor<4xf32>