From 12cb042b3ef4d4c16aab9fe232d1ff6c5a9e9888 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Wed, 30 Oct 2024 18:24:52 -0400 Subject: [PATCH] [Flow] Add pattern to canonicalize away full tensor.insert_slice ops (#18941) Additionally drops the pad combining pattern from this pass because it was upstreamed. --- .../Dialect/Flow/Transforms/Canonicalizer.cpp | 119 ++++++++---------- .../Transforms/test/flow_canonicalize.mlir | 106 ++++++---------- 2 files changed, 92 insertions(+), 133 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp index b19bccc14da3..1978d77b2d8d 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -18,80 +19,66 @@ namespace mlir::iree_compiler::IREE::Flow { namespace { -/// Folds a chain of `tensor.pad` ops with the same constant padding value. -/// -/// Example: -/// -/// ```mlir -/// %1 = tensor.pad %0 low[0, 1] high[0, 2] { -/// tensor.yield %val -/// } : tensor<1x2xf32> to tensor<2x5xf32> -/// %res = tensor.pad %1 low[0, 2] high[3, 0] { -/// tensor.yield %val -/// } : tensor<1x5xf32> to tensor<5x7xf32> -/// ``` -/// -/// folds into: -/// -/// ```mlir -/// %res = tensor.pad %0 low[0, 3] high[3, 2] { -/// tensor.yield %val -/// } : tensor<1x2xf32> to tensor<5x7xf32> -/// ``` -/// -/// NOTE: This wasn't sent upstream as a canonicalization due to the use of -/// the Affine dialect. -struct FoldConsecutiveConstantPadding : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::PadOp padOp, +static std::optional> getDefiningMixedSizes(Value v) { + if (auto empty = v.getDefiningOp()) { + return empty.getMixedSizes(); + } else if (auto extract = v.getDefiningOp()) { + // TODO: Support rank reducing cases. + if (extract.getSourceType().getRank() != + extract.getResultType().getRank()) { + return {}; + } + return extract.getMixedSizes(); + } + return {}; +} + +struct FoldFullInsertSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const override { - if (padOp.getNofold()) { - return failure(); + if (!insertSliceOp.hasUnitStride() || !insertSliceOp.hasZeroOffset()) { + return rewriter.notifyMatchFailure(insertSliceOp, + "non-unit stride or non-zero offset."); } - auto producerPad = padOp.getSource().getDefiningOp(); - if (!producerPad || producerPad.getNofold()) { + + RankedTensorType sourceType = insertSliceOp.getSourceType(); + RankedTensorType resultType = insertSliceOp.getResultType(); + if (sourceType != resultType) { return rewriter.notifyMatchFailure( - padOp, "producer is not a foldable tensor.pad op"); + insertSliceOp, + "unimplemented: Cast-like or reshape-like insert ops."); } - // Fail if the tensor::PadOps padding values do not match. - Value consumerPadValue = padOp.getConstantPaddingValue(); - Value producerPadValue = producerPad.getConstantPaddingValue(); - if (!consumerPadValue || !producerPadValue || - consumerPadValue != producerPadValue) { + std::optional> mixedSizes = + getDefiningMixedSizes(insertSliceOp.getDest()); + if (!mixedSizes) { return rewriter.notifyMatchFailure( - padOp, "cannot fold PadOps with different padding values"); + insertSliceOp, "Could not find producer with list of tensor sizes."); } - Location loc = padOp.getLoc(); - AffineExpr d0, d1; - bindDims(rewriter.getContext(), d0, d1); - - // Combine the low/high paddings of the two tensor::PadOps. - auto addPaddings = [&](ArrayRef consumerPaddings, - ArrayRef producerPaddings) { - SmallVector sumPaddings; - for (auto [consumerIndex, producerIndex] : - llvm::zip_equal(consumerPaddings, producerPaddings)) { - sumPaddings.push_back(affine::makeComposedFoldedAffineApply( - rewriter, loc, d0 + d1, {consumerIndex, producerIndex})); + for (auto [insertSize, destSize] : + llvm::zip_equal(insertSliceOp.getMixedSizes(), mixedSizes.value())) { + if (isa(insertSize) || isa(destSize)) { + if (insertSize != destSize) { + return rewriter.notifyMatchFailure(insertSliceOp, + "dynamic size mismatch"); + } + continue; } - return sumPaddings; - }; - - SmallVector newHighPad = - addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad()); - SmallVector newLowPad = - addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad()); - - auto newPadOp = rewriter.create( - padOp.getLoc(), padOp.getResultType(), producerPad.getSource(), - newLowPad, newHighPad, padOp.getNofold(), - getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames())); - rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(), - newPadOp.getRegion().begin()); - rewriter.replaceOp(padOp, newPadOp.getResult()); + + // `getMixedSizes` for different ops returns different attribute types + // (`index` or `i64`) so we compare the values of the ints directly here. + int64_t staticInsertSize = getConstantIntValue(insertSize).value(); + int64_t staticDestSize = getConstantIntValue(insertSize).value(); + if (staticInsertSize != staticDestSize) { + return rewriter.notifyMatchFailure(insertSliceOp, + "static size mismatch"); + } + } + + rewriter.replaceOp(insertSliceOp, insertSliceOp.getSource()); return success(); } }; @@ -117,7 +104,7 @@ struct CanonicalizerPass // Pull in some borderline/downstream canonicalizations for the Flow // compilation phase. tensor::populateMergeConsecutiveInsertExtractSlicePatterns(owningPatterns); - owningPatterns.add(context); + owningPatterns.add(context); patterns = std::make_shared(std::move(owningPatterns)); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir index 81203a5db24c..8734b8591ce6 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/flow_canonicalize.mlir @@ -1,84 +1,56 @@ // RUN: iree-opt --iree-flow-canonicalize %s --split-input-file --mlir-print-local-scope | FileCheck %s -util.func public @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> { - %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] { - ^bb0(%b0: index, %b1 : index): - tensor.yield %pad_value : f32 - } : tensor<2x3xf32> to tensor<4x4xf32> - %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] { - ^bb0(%b2: index, %b3 : index): - tensor.yield %pad_value : f32 - } : tensor<4x4xf32> to tensor<7x8xf32> - util.return %pad1 : tensor<7x8xf32> +util.func public @fold_full_insert_into_extract( + %source: tensor<8x?xf32>, + %dest: tensor<10x?xf32>, + %size: index) -> tensor<8x?xf32> { + %extract = tensor.extract_slice %dest [1, 1] [8, %size] [1, 1] : tensor<10x?xf32> to tensor<8x?xf32> + %insert = tensor.insert_slice %source into %extract [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32> + util.return %insert : tensor<8x?xf32> } -// CHECK-LABEL: util.func public @merge_constant_padding -// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32> -// CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32 -// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2] -// CHECK: tensor.yield %[[PADVAL]] -// CHECK: util.return %[[PAD]] + +// CHECK-LABEL: util.func public @fold_full_insert_into_extract +// CHECK-SAME: %[[SOURCE:.+]]: tensor<8x?xf32> +// CHECK: util.return %[[SOURCE]] // ----- -util.func public @merge_constant_padding_dynamic(%arg0: tensor, %idx: index, %pad_value: f32) -> tensor { - %pad0 = tensor.pad %arg0 low[%idx, 1] high[1, 0] { - ^bb0(%b0: index, %b1 : index): - tensor.yield %pad_value : f32 - } : tensor to tensor - %pad1 = tensor.pad %pad0 low[0, 2] high[%idx, 2] { - ^bb0(%b2: index, %b3 : index): - tensor.yield %pad_value : f32 - } : tensor to tensor - util.return %pad1 : tensor +util.func public @fold_full_insert_into_empty( + %source: tensor<8x?xf32>, + %size: index) -> tensor<8x?xf32> { + %empty = tensor.empty(%size) : tensor<8x?xf32> + %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32> + util.return %insert : tensor<8x?xf32> } -// CHECK-LABEL: util.func public @merge_constant_padding_dynamic -// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor -// CHECK-SAME: %[[IDX:[A-Za-z0-9]+]]: index -// CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32 -// CHECK: %[[HIGH:.+]] = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%[[IDX]]] -// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2] -// CHECK: tensor.yield %[[PADVAL]] -// CHECK: util.return %[[PAD]] + +// CHECK-LABEL: util.func public @fold_full_insert_into_empty +// CHECK-SAME: %[[SOURCE:.+]]: tensor<8x?xf32> +// CHECK: util.return %[[SOURCE]] // ----- -util.func public @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> { - %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] { - ^bb0(%b0: index, %b1 : index): - tensor.yield %pad_value : f32 - } : tensor<2x3xf32> to tensor<4x4xf32> - %pad1 = tensor.pad %pad0 nofold low[0, 2] high[3, 2] { - ^bb0(%b2: index, %b3 : index): - tensor.yield %pad_value : f32 - } : tensor<4x4xf32> to tensor<7x8xf32> - util.return %pad1 : tensor<7x8xf32> +util.func public @dont_fold_not_full_insert_into_empty( + %source: tensor<8x?xf32>, + %size1: index, %size2: index) -> tensor<8x?xf32> { + %empty = tensor.empty(%size1) : tensor<8x?xf32> + %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size2] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32> + util.return %insert : tensor<8x?xf32> } -// Verify that folding does not happen if it would drop a nofold attribute - -// CHECK-LABEL: util.func public @dont_merge_constant_padding_nofold -// CHECK: tensor.pad -// CHECK: tensor.pad {{.*}} nofold +// CHECK-LABEL: util.func public @dont_fold_not_full_insert_into_empty +// CHECK: %[[INSERT:.+]] = tensor.insert_slice +// CHECK: util.return %[[INSERT]] // ----- -util.func public @dont_merge_constant_padding_different_vals( - %arg0: tensor<2x3xf32>, - %pad_value0: f32, - %pad_value1: f32) -> tensor<7x8xf32> { - %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] { - ^bb0(%b0: index, %b1 : index): - tensor.yield %pad_value0 : f32 - } : tensor<2x3xf32> to tensor<4x4xf32> - %pad1 = tensor.pad %pad0 nofold low[0, 2] high[3, 2] { - ^bb0(%b2: index, %b3 : index): - tensor.yield %pad_value1 : f32 - } : tensor<4x4xf32> to tensor<7x8xf32> - util.return %pad1 : tensor<7x8xf32> +util.func public @dont_fold_not_full_static_insert_into_empty( + %source: tensor<8x?xf32>, + %size: index) -> tensor<10x?xf32> { + %empty = tensor.empty(%size) : tensor<10x?xf32> + %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<10x?xf32> + util.return %insert : tensor<10x?xf32> } -// Verify that folding does not happen if it would drop a nofold attribute - -// CHECK-LABEL: util.func public @dont_merge_constant_padding_different_vals -// CHECK: tensor.pad -// CHECK: tensor.pad +// CHECK-LABEL: util.func public @dont_fold_not_full_static_insert_into_empty +// CHECK: %[[INSERT:.+]] = tensor.insert_slice +// CHECK: util.return %[[INSERT]]