diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp index 744a7eeca833..78faae4dcf49 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp @@ -1055,20 +1055,25 @@ struct ReorderElementwiseAndShapeOp final LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (op->getOperands().size() != 1) { - return rewriter.notifyMatchFailure(op, "expected to be unary."); + return rewriter.notifyMatchFailure(op, "expected to be unary"); } auto definingOp = op->getOperand(0).getDefiningOp(); if (!definingOp) { return rewriter.notifyMatchFailure( - op, "expected to have an op before elementise op."); + op, "expected to have an op before elementise op"); } if (!isa(definingOp) && !isa(definingOp) && !isa(definingOp)) { return rewriter.notifyMatchFailure( - op, "defining operation of unexpected type."); + op, "defining operation of unexpected type"); + } + + // Only reorder if the defining op has no other uses. + if (!llvm::hasSingleElement(definingOp->getResult(0).getUses())) { + return rewriter.notifyMatchFailure(op, "operation has more than one use"); } Value input = definingOp->getOperand(0); diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir index bda6c3c9d0f5..a569716ff987 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir @@ -736,3 +736,20 @@ func.func @reorder_with_type_change(%arg0 : tensor<3x4xi32>) -> tensor<12xi64> { // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %[[CONVERT]] : (tensor<3x4xi64>) -> tensor<12xi64> // CHECK: return %[[RESHAPE]] +// ----- + +func.func @do_not_reorder_with_other_uses(%arg0: tensor<2x2xf64>, %arg1: tensor<4xf32>, %arg2: tensor) -> (tensor, tensor<4xf32>) { + %0 = stablehlo.reshape %arg0 : (tensor<2x2xf64>) -> tensor<4xf64> + %1 = stablehlo.convert %0 : (tensor<4xf64>) -> tensor<4xf32> + %2 = stablehlo.subtract %arg1, %1 : tensor<4xf32> + %3 = stablehlo.reduce(%0 init: %arg2) across dimensions = [0] : (tensor<4xf64>, tensor) -> tensor + reducer(%arg3: tensor, %arg4: tensor) { + %4 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %4 : tensor + } + return %3, %2 : tensor, tensor<4xf32> +} + +// CHECK-LABEL: do_not_reorder_with_other_uses +// CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<2x2xf64>) -> tensor<4xf64> +// CHECK: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<4xf64>) -> tensor<4xf32>