diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp index 07e95e8c7d9a..26cab680c34c 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp @@ -978,6 +978,31 @@ struct ReshapeOpCanon final : OpRewritePattern { } }; +struct MergeConsecutiveReshapes final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op, + PatternRewriter &rewriter) const override { + // Fold noop reshape. + auto operand = op.getOperand(); + if (op.getType() == operand.getType()) { + rewriter.replaceOp(op, op.getOperand()); + return success(); + } + + // Fold reshape(reshape(x)). + auto reshapeOp = operand.getDefiningOp(); + if (!reshapeOp) { + return rewriter.notifyMatchFailure( + op, "requires defining op of operand to be Reshape"); + } + + op.setOperand(reshapeOp->getOperand(0)); + return success(); + } +}; + struct TransposeIsReshape final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1156,7 +1181,7 @@ void populateCanonicalizationPatterns(MLIRContext *context, NoopReduceOpCanon, EmptyReduceOpCanon, // Shape manipulation(-ish) ops. ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon, GatherOpCanon, - ReshapeOpCanon, TransposeIsReshape, + ReshapeOpCanon, MergeConsecutiveReshapes, TransposeIsReshape, // Types. ZeroExtentTensorCanon>(context, benefit); patterns->add(context); diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir index e8eb9825f321..af6df8e81fcb 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir @@ -401,6 +401,17 @@ func.func @reshape(%arg0: tensor<1xf32>) // ----- +// CHECK-LABEL: @merge_consecutive_reshapes +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]] +func.func @merge_consecutive_reshapes(%arg0: tensor<4x4xi32>) -> tensor<16xi32> { + %0 = stablehlo.reshape %arg0 : (tensor<4x4xi32>) -> tensor<2x8xi32> + %1 = stablehlo.reshape %0 : (tensor<2x8xi32>) -> tensor<16xi32> + // CHECK: [[R0:%.+]] = stablehlo.reshape %[[ARG0]] : (tensor<4x4xi32>) -> tensor<16xi32> + return %1 : tensor<16xi32> +} + +// ----- + // CHECK-LABEL: func.func @transpose // CHECK-SAME: ([[ARG0:%.+]]: tensor<2xf32>, [[ARG1:%.+]]: tensor<3x2xf32>, [[ARG2:%.+]]: tensor) func.func @transpose(%arg0: tensor<2xf32>, %arg1: tensor<3x2xf32>, %arg2: tensor)