Add a StableHLO canonicalizer that rewrites transpose as reshape. (#14682)

The canonicalization also "swallows" TransposeOpCanon out of convenience, since the identity-permutation check has to happen either way (and replacing a no-op transpose with its operand should take priority over lowering it to a different op).
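
For illustration, a minimal sketch of the rewrite (the shapes here are hypothetical, not taken from the change below): a transpose whose permutation only moves size-1 dimensions leaves the row-major element order intact, so it can be replaced by a reshape.

  %0 = stablehlo.transpose %arg0, dims = [1, 0, 2] : (tensor<1x2x3xf32>) -> tensor<2x1x3xf32>
  // ... canonicalizes to:
  %0 = stablehlo.reshape %arg0 : (tensor<1x2x3xf32>) -> tensor<2x1x3xf32>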
NatashaKnk authored Aug 15, 2023
1 parent 4f77c30 commit aef8d3c
Showing 2 changed files with 56 additions and 12 deletions.
@@ -978,16 +978,44 @@ struct ReshapeOpCanon final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
   }
 };
 
-struct TransposeOpCanon final : OpRewritePattern<mlir::stablehlo::TransposeOp> {
+struct TransposeIsReshape final
+    : OpRewritePattern<mlir::stablehlo::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
-    // Check if this transpose is a noop and use the operand instead.
-    if (!isIotaRange(op.getPermutation()))
-      return failure();
+    auto input = op.getOperand();
+    auto permutation = op.getPermutation();
 
-    rewriter.replaceOp(op, op.getOperand());
+    if (isIotaRange(permutation)) {
+      rewriter.replaceOp(op, op.getOperand());
+      return success();
+    }
+
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasStaticShape() ||
+        !op.getType().hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "requires input/output to be of a statically-shaped ranked "
+              "tensor type");
+    }
+
+    SmallVector<int64_t> permValues(permutation.getValues<int64_t>());
+
+    SmallVector<int64_t> nonZeroPerms;
+    nonZeroPerms.reserve(permValues.size());
+    for (auto idx : permValues) {
+      auto sz = inputTy.getDimSize(idx);
+      if (sz != 1)
+        nonZeroPerms.push_back(idx);
+    }
+
+    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
+      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
+        return rewriter.notifyMatchFailure(op, "memory layout change");
+
+    rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, op.getType(),
+                                                            input);
     return success();
   }
 };
@@ -1128,7 +1156,7 @@ void populateCanonicalizationPatterns(MLIRContext *context,
       NoopReduceOpCanon, EmptyReduceOpCanon,
       // Shape manipulation(-ish) ops.
       ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon, GatherOpCanon,
-      ReshapeOpCanon, TransposeOpCanon,
+      ReshapeOpCanon, TransposeIsReshape,
       // Types.
       ZeroExtentTensorCanon>(context, benefit);
   patterns->add<ReorderElementwiseAndShapeOp>(context);
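
The heart of the new pattern is the ordering check: keep only the permutation entries that index a non-unit input dimension; the transpose preserves row-major element order, and is therefore a pure reshape, exactly when that subsequence is ascending. A worked sketch of both cases, mirroring the new tests below (%a and %b are hypothetical values):

  // dims = [3, 1, 0, 2] keeps [1, 2] (the size-4 and size-5 dimensions): ascending,
  // so the element order is unchanged and the transpose folds to a reshape.
  %r = stablehlo.transpose %a, dims = [3, 1, 0, 2] : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
  // The same permutation keeps [3, 1, 2] here (sizes 2, 4, 5): not ascending,
  // so dimension 3 genuinely moves and the transpose must stay.
  %s = stablehlo.transpose %b, dims = [3, 1, 0, 2] : (tensor<1x4x5x2xf32>) -> tensor<2x4x1x5xf32>

This is also why the existing @transpose test swaps its tensor<1x2xf32> operand for tensor<3x2xf32>: with a unit dimension present, even the dims = [1, 0] case would now fold to a reshape, while the test intends to check that a genuine transpose is left untouched.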
@@ -402,17 +402,17 @@ func.func @reshape(%arg0: tensor<1xf32>)
 // -----
 
 // CHECK-LABEL: func.func @transpose
-// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xf32>, [[ARG1:%.+]]: tensor<1x2xf32>, [[ARG2:%.+]]: tensor<f32>)
-func.func @transpose(%arg0: tensor<2xf32>, %arg1: tensor<1x2xf32>, %arg2: tensor<f32>)
-    -> (tensor<2xf32>, tensor<1x2xf32>, tensor<2x1xf32>, tensor<f32>) {
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xf32>, [[ARG1:%.+]]: tensor<3x2xf32>, [[ARG2:%.+]]: tensor<f32>)
+func.func @transpose(%arg0: tensor<2xf32>, %arg1: tensor<3x2xf32>, %arg2: tensor<f32>)
+    -> (tensor<2xf32>, tensor<3x2xf32>, tensor<2x3xf32>, tensor<f32>) {
   %a = stablehlo.transpose %arg0, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
-  %b = stablehlo.transpose %arg1, dims = [0, 1] : (tensor<1x2xf32>) -> tensor<1x2xf32>
-  %c = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
+  %b = stablehlo.transpose %arg1, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32>
+  %c = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x2xf32>) -> tensor<2x3xf32>
   %d = stablehlo.transpose %arg2, dims = [] : (tensor<f32>) -> tensor<f32>
 
   // CHECK-NEXT: [[X:%.+]] = stablehlo.transpose [[ARG1]], dims = [1, 0]
   // CHECK-NEXT: return [[ARG0]], [[ARG1]], [[X]], [[ARG2]]
-  return %a, %b, %c, %d : tensor<2xf32>, tensor<1x2xf32>, tensor<2x1xf32>, tensor<f32>
+  return %a, %b, %c, %d : tensor<2xf32>, tensor<3x2xf32>, tensor<2x3xf32>, tensor<f32>
 }
 
 // -----
@@ -577,6 +577,22 @@ func.func @gather_to_slice_indices_clamp_lowerbound(%arg0 : tensor<4x2xui32>) ->

 // -----
 
+// CHECK-LABEL: @transpose_is_reshape
+func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> {
+  // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
+  %0 = stablehlo.transpose %arg0, dims = [3, 1, 0, 2] : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
+  return %0 : tensor<1x4x1x5xf32>
+}
+
+// CHECK-LABEL: @transpose_is_not_reshape
+func.func @transpose_is_not_reshape(%arg0: tensor<1x4x5x2xf32>) -> tensor<2x4x1x5xf32> {
+  // CHECK-NOT: stablehlo.reshape
+  %0 = stablehlo.transpose %arg0, dims = [3, 1, 0, 2] : (tensor<1x4x5x2xf32>) -> tensor<2x4x1x5xf32>
+  return %0 : tensor<2x4x1x5xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @reduce_noop_1
 // CHECK-SAME: ([[ARG0:%.+]]: tensor<4x8xf32>)
 func.func @reduce_noop_1(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {