From aef8d3ce065826e915a45e12135d6f791737d8b4 Mon Sep 17 00:00:00 2001
From: NatashaKnk
Date: Tue, 15 Aug 2023 16:30:29 -0700
Subject: [PATCH] Add a StableHLO canonicalizer that rewrites transpose as
 reshape. (#14682)

The new pattern also subsumes TransposeOpCanon for convenience, since the
identity-permutation check has to happen either way (and the no-op rewrite
should take priority over lowering to a different op).
---
 .../Preprocessing/Canonicalization.cpp        | 40 ++++++++++++++++---
 .../Preprocessing/test/canonicalization.mlir  | 28 ++++++++++---
 2 files changed, 56 insertions(+), 12 deletions(-)

diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
index 78faae4dcf49..07e95e8c7d9a 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
@@ -978,16 +978,44 @@ struct ReshapeOpCanon final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
   }
 };
 
-struct TransposeOpCanon final : OpRewritePattern<mlir::stablehlo::TransposeOp> {
+struct TransposeIsReshape final
+    : OpRewritePattern<mlir::stablehlo::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
-    // Check if this transpose is a noop and use the operand instead.
-    if (!isIotaRange(op.getPermutation()))
-      return failure();
+    auto input = op.getOperand();
+    auto permutation = op.getPermutation();
 
-    rewriter.replaceOp(op, op.getOperand());
+    if (isIotaRange(permutation)) {
+      rewriter.replaceOp(op, op.getOperand());
+      return success();
+    }
+
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasStaticShape() ||
+        !op.getType().hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "requires input/output to be of a statically-shaped ranked "
+              "tensor type");
+    }
+
+    SmallVector<int64_t> permValues(permutation.getValues<int64_t>());
+
+    SmallVector<int64_t> nonZeroPerms;
+    nonZeroPerms.reserve(permValues.size());
+    for (auto idx : permValues) {
+      auto sz = inputTy.getDimSize(idx);
+      if (sz != 1)
+        nonZeroPerms.push_back(idx);
+    }
+
+    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
+      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
+        return rewriter.notifyMatchFailure(op, "memory layout change");
+
+    rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, op.getType(),
+                                                            input);
     return success();
   }
 };
@@ -1128,7 +1156,7 @@ void populateCanonicalizationPatterns(MLIRContext *context,
       NoopReduceOpCanon, EmptyReduceOpCanon,
       // Shape manipulation(-ish) ops.
       ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon, GatherOpCanon,
-      ReshapeOpCanon, TransposeOpCanon,
+      ReshapeOpCanon, TransposeIsReshape,
       // Types.
       ZeroExtentTensorCanon>(context, benefit);
   patterns->add<ReorderElementwiseAndShapeOp>(context);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
index a569716ff987..e8eb9825f321 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
@@ -402,17 +402,17 @@ func.func @reshape(%arg0: tensor<1xf32>)
 // -----
 
 // CHECK-LABEL: func.func @transpose
-// CHECK-SAME:  ([[ARG0:%.+]]: tensor<2xf32>, [[ARG1:%.+]]: tensor<1x2xf32>, [[ARG2:%.+]]: tensor<f32>)
-func.func @transpose(%arg0: tensor<2xf32>, %arg1: tensor<1x2xf32>, %arg2: tensor<f32>)
-  -> (tensor<2xf32>, tensor<1x2xf32>, tensor<2x1xf32>, tensor<f32>) {
+// CHECK-SAME:  ([[ARG0:%.+]]: tensor<2xf32>, [[ARG1:%.+]]: tensor<3x2xf32>, [[ARG2:%.+]]: tensor<f32>)
+func.func @transpose(%arg0: tensor<2xf32>, %arg1: tensor<3x2xf32>, %arg2: tensor<f32>)
+  -> (tensor<2xf32>, tensor<3x2xf32>, tensor<2x3xf32>, tensor<f32>) {
   %a = stablehlo.transpose %arg0, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
-  %b = stablehlo.transpose %arg1, dims = [0, 1] : (tensor<1x2xf32>) -> tensor<1x2xf32>
-  %c = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
+  %b = stablehlo.transpose %arg1, dims = [0, 1] : (tensor<3x2xf32>) -> tensor<3x2xf32>
+  %c = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x2xf32>) -> tensor<2x3xf32>
   %d = stablehlo.transpose %arg2, dims = [] : (tensor<f32>) -> tensor<f32>
 
 // CHECK-NEXT:  [[X:%.+]] = stablehlo.transpose [[ARG1]], dims = [1, 0]
 // CHECK-NEXT:  return [[ARG0]], [[ARG1]], [[X]], [[ARG2]]
-  return %a, %b, %c, %d : tensor<2xf32>, tensor<1x2xf32>, tensor<2x1xf32>, tensor<f32>
+  return %a, %b, %c, %d : tensor<2xf32>, tensor<3x2xf32>, tensor<2x3xf32>, tensor<f32>
 }
 
 // -----
@@ -577,6 +577,22 @@ func.func @gather_to_slice_indices_clamp_lowerbound(%arg0 : tensor<4x2xui32>) ->
 
 // -----
 
+// CHECK-LABEL: @transpose_is_reshape
+func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> {
+  // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
+  %0 = stablehlo.transpose %arg0, dims = [3, 1, 0, 2] : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
+  return %0 : tensor<1x4x1x5xf32>
+}
+
+// CHECK-LABEL: @transpose_is_not_reshape
+func.func @transpose_is_not_reshape(%arg0: tensor<1x4x5x2xf32>) -> tensor<2x4x1x5xf32> {
+  // CHECK-NOT: stablehlo.reshape
+  %0 = stablehlo.transpose %arg0, dims = [3, 1, 0, 2] : (tensor<1x4x5x2xf32>) -> tensor<2x4x1x5xf32>
+  return %0 : tensor<2x4x1x5xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @reduce_noop_1
 // CHECK-SAME:  ([[ARG0:%.+]]: tensor<4x8xf32>)
 func.func @reduce_noop_1(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
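
Worked example (illustrative, using only the shapes from the new tests above):
in @transpose_is_reshape the input is tensor<1x4x5x1xf32>, so the only non-unit
dimensions are d1 (size 4) and d2 (size 5). The permutation [3, 1, 0, 2] keeps
d1 ahead of d2, so no element actually changes position and the pattern rewrites

  %0 = stablehlo.transpose %arg0, dims = [3, 1, 0, 2]
      : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>

into the pure metadata change

  %0 = stablehlo.reshape %arg0 : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>

In @transpose_is_not_reshape the non-unit dimensions (d1, d2, d3) appear in the
permutation in the order [3, 1, 2], which is not increasing, so the transpose is
a genuine memory-layout change and is left untouched.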