Skip to content

Commit

Permalink
[StableHLO] Merge reshape(reshape(x))
Browse files Browse the repository at this point in the history
  • Loading branch information
mariecwhite committed Aug 21, 2023
1 parent 836f6f1 commit 3d1e704
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,29 @@ struct ReshapeOpCanon final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
}
};

struct MergeConsecutiveReshapes final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op,
PatternRewriter &rewriter) const override {
// Fold noop reshape.
auto operand = op.getOperand();
if (op.getType() == operand.getType()) {
rewriter.replaceOp(op, op.getOperand());
return success();
}

// Fold reshape(reshape(x)).
auto reshapeOp = operand.getDefiningOp<mlir::stablehlo::ReshapeOp>();
if (!reshapeOp) {
return failure();
}

op.setOperand(reshapeOp->getOperand(0));
return success();
}
};

struct TransposeIsReshape final
: OpRewritePattern<mlir::stablehlo::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -1156,7 +1179,7 @@ void populateCanonicalizationPatterns(MLIRContext *context,
NoopReduceOpCanon, EmptyReduceOpCanon,
// Shape manipulation(-ish) ops.
ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon, GatherOpCanon,
ReshapeOpCanon, TransposeIsReshape,
ReshapeOpCanon, MergeConsecutiveReshapes, TransposeIsReshape,
// Types.
ZeroExtentTensorCanon>(context, benefit);
patterns->add<ReorderElementwiseAndShapeOp>(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,17 @@ func.func @reshape(%arg0: tensor<1xf32>)

// -----

// CHECK-LABEL: @merge_consecutive_reshapes
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]
func.func @merge_consecutive_reshapes(%arg0: tensor<4x4xi32>) -> tensor<16xi32> {
%0 = stablehlo.reshape %arg0 : (tensor<4x4xi32>) -> tensor<2x8xi32>
%1 = stablehlo.reshape %0 : (tensor<2x8xi32>) -> tensor<16xi32>
// CHECK: [[R0:%.+]] = stablehlo.reshape %[[ARG0]] : (tensor<4x4xi32>) -> tensor<16xi32>
return %1 : tensor<16xi32>
}

// -----

// CHECK-LABEL: func.func @transpose
// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xf32>, [[ARG1:%.+]]: tensor<3x2xf32>, [[ARG2:%.+]]: tensor<f32>)
func.func @transpose(%arg0: tensor<2xf32>, %arg1: tensor<3x2xf32>, %arg2: tensor<f32>)
Expand Down

0 comments on commit 3d1e704

Please sign in to comment.