[Flow] Add pattern to canonicalize away full tensor.insert_slice ops (iree-org#18941)

Additionally drops the pad combining pattern from this pass because it
was upstreamed.
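
To illustrate what the new canonicalization does (a hedged sketch of my own, not taken from the diff; `%src`, `%empty`, `%result`, and `%d` are made-up names): an insert_slice with zero offsets, unit strides, and the same sizes as its freshly created destination overwrites the whole destination, so the op can simply be replaced by its source.

```mlir
// Hypothetical input: %src fully covers %empty (zero offsets, unit strides,
// and matching static/dynamic sizes), so the insert is effectively a copy.
%empty  = tensor.empty(%d) : tensor<4x?xf32>
%result = tensor.insert_slice %src into %empty [0, 0] [4, %d] [1, 1]
    : tensor<4x?xf32> into tensor<4x?xf32>

// After --iree-flow-canonicalize, uses of %result are rewritten to use %src
// directly, leaving the tensor.empty dead and removable.
```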
qedawkins authored Oct 30, 2024
1 parent 53813e8 commit 12cb042
Showing 2 changed files with 92 additions and 133 deletions.
119 changes: 53 additions & 66 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalizer.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -18,80 +19,66 @@ namespace mlir::iree_compiler::IREE::Flow {
 
 namespace {
 
-/// Folds a chain of `tensor.pad` ops with the same constant padding value.
-///
-/// Example:
-///
-/// ```mlir
-///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
-///       tensor.yield %val
-///     } : tensor<1x2xf32> to tensor<2x5xf32>
-///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
-///       tensor.yield %val
-///     } : tensor<1x5xf32> to tensor<5x7xf32>
-/// ```
-///
-/// folds into:
-///
-/// ```mlir
-///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
-///       tensor.yield %val
-///     } : tensor<1x2xf32> to tensor<5x7xf32>
-/// ```
-///
-/// NOTE: This wasn't sent upstream as a canonicalization due to the use of
-/// the Affine dialect.
-struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
-  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::PadOp padOp,
-                                PatternRewriter &rewriter) const override {
-    if (padOp.getNofold()) {
-      return failure();
-    }
-    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
-    if (!producerPad || producerPad.getNofold()) {
-      return rewriter.notifyMatchFailure(
-          padOp, "producer is not a foldable tensor.pad op");
-    }
-
-    // Fail if the tensor::PadOps padding values do not match.
-    Value consumerPadValue = padOp.getConstantPaddingValue();
-    Value producerPadValue = producerPad.getConstantPaddingValue();
-    if (!consumerPadValue || !producerPadValue ||
-        consumerPadValue != producerPadValue) {
-      return rewriter.notifyMatchFailure(
-          padOp, "cannot fold PadOps with different padding values");
-    }
-
-    Location loc = padOp.getLoc();
-    AffineExpr d0, d1;
-    bindDims(rewriter.getContext(), d0, d1);
-
-    // Combine the low/high paddings of the two tensor::PadOps.
-    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
-                           ArrayRef<OpFoldResult> producerPaddings) {
-      SmallVector<OpFoldResult> sumPaddings;
-      for (auto [consumerIndex, producerIndex] :
-           llvm::zip_equal(consumerPaddings, producerPaddings)) {
-        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
-            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
-      }
-      return sumPaddings;
-    };
-
-    SmallVector<OpFoldResult> newHighPad =
-        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
-    SmallVector<OpFoldResult> newLowPad =
-        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
-
-    auto newPadOp = rewriter.create<tensor::PadOp>(
-        padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
-        newLowPad, newHighPad, padOp.getNofold(),
-        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
-    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
-                                newPadOp.getRegion().begin());
-    rewriter.replaceOp(padOp, newPadOp.getResult());
-
+static std::optional<SmallVector<OpFoldResult>> getDefiningMixedSizes(Value v) {
+  if (auto empty = v.getDefiningOp<tensor::EmptyOp>()) {
+    return empty.getMixedSizes();
+  } else if (auto extract = v.getDefiningOp<tensor::ExtractSliceOp>()) {
+    // TODO: Support rank reducing cases.
+    if (extract.getSourceType().getRank() !=
+        extract.getResultType().getRank()) {
+      return {};
+    }
+    return extract.getMixedSizes();
+  }
+  return {};
+}
+
+struct FoldFullInsertSlice : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    if (!insertSliceOp.hasUnitStride() || !insertSliceOp.hasZeroOffset()) {
+      return rewriter.notifyMatchFailure(insertSliceOp,
+                                         "non-unit stride or non-zero offset.");
+    }
+
+    RankedTensorType sourceType = insertSliceOp.getSourceType();
+    RankedTensorType resultType = insertSliceOp.getResultType();
+    if (sourceType != resultType) {
+      return rewriter.notifyMatchFailure(
+          insertSliceOp,
+          "unimplemented: Cast-like or reshape-like insert ops.");
+    }
+
+    std::optional<SmallVector<OpFoldResult>> mixedSizes =
+        getDefiningMixedSizes(insertSliceOp.getDest());
+    if (!mixedSizes) {
+      return rewriter.notifyMatchFailure(
+          insertSliceOp, "Could not find producer with list of tensor sizes.");
+    }
+
+    for (auto [insertSize, destSize] :
+         llvm::zip_equal(insertSliceOp.getMixedSizes(), mixedSizes.value())) {
+      if (isa<Value>(insertSize) || isa<Value>(destSize)) {
+        if (insertSize != destSize) {
+          return rewriter.notifyMatchFailure(insertSliceOp,
+                                             "dynamic size mismatch");
+        }
+        continue;
+      }
+
+      // `getMixedSizes` for different ops returns different attribute types
+      // (`index` or `i64`) so we compare the values of the ints directly here.
+      int64_t staticInsertSize = getConstantIntValue(insertSize).value();
+      int64_t staticDestSize = getConstantIntValue(destSize).value();
+      if (staticInsertSize != staticDestSize) {
+        return rewriter.notifyMatchFailure(insertSliceOp,
+                                           "static size mismatch");
+      }
+    }
+
+    rewriter.replaceOp(insertSliceOp, insertSliceOp.getSource());
     return success();
   }
 };
@@ -117,7 +104,7 @@ struct CanonicalizerPass
     // Pull in some borderline/downstream canonicalizations for the Flow
    // compilation phase.
    tensor::populateMergeConsecutiveInsertExtractSlicePatterns(owningPatterns);
-    owningPatterns.add<FoldConsecutiveConstantPadding>(context);
+    owningPatterns.add<FoldFullInsertSlice>(context);
 
    patterns =
        std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
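One boundary of the pattern worth spelling out (again an illustrative sketch of mine, not a test from this commit): the early bail-out on `sourceType != resultType` means an insert that also changes the amount of static shape information is left untouched, even if at runtime it covers the whole destination.

```mlir
// Hypothetical case the pattern skips: the slice spans the destination only
// if %d happens to be 8, and the source type (tensor<8x16xf32>) differs from
// the result type (tensor<?x16xf32>), so the rewrite bails out with the
// "Cast-like or reshape-like insert ops" message above.
%empty  = tensor.empty(%d) : tensor<?x16xf32>
%result = tensor.insert_slice %src into %empty [0, 0] [8, 16] [1, 1]
    : tensor<8x16xf32> into tensor<?x16xf32>
```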
106 changes: 39 additions & 67 deletions in the --iree-flow-canonicalize FileCheck test
@@ -1,84 +1,56 @@
 // RUN: iree-opt --iree-flow-canonicalize %s --split-input-file --mlir-print-local-scope | FileCheck %s
 
-util.func public @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
-  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
-    ^bb0(%b0: index, %b1 : index):
-      tensor.yield %pad_value : f32
-  } : tensor<2x3xf32> to tensor<4x4xf32>
-  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
-    ^bb0(%b2: index, %b3 : index):
-      tensor.yield %pad_value : f32
-  } : tensor<4x4xf32> to tensor<7x8xf32>
-  util.return %pad1 : tensor<7x8xf32>
+util.func public @fold_full_insert_into_extract(
+    %source: tensor<8x?xf32>,
+    %dest: tensor<10x?xf32>,
+    %size: index) -> tensor<8x?xf32> {
+  %extract = tensor.extract_slice %dest [1, 1] [8, %size] [1, 1] : tensor<10x?xf32> to tensor<8x?xf32>
+  %insert = tensor.insert_slice %source into %extract [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32>
+  util.return %insert : tensor<8x?xf32>
 }
-// CHECK-LABEL: util.func public @merge_constant_padding
-//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
-//  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
-//       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
-//       CHECK:     tensor.yield %[[PADVAL]]
-//       CHECK:   util.return %[[PAD]]
 
+// CHECK-LABEL: util.func public @fold_full_insert_into_extract
+//  CHECK-SAME:   %[[SOURCE:.+]]: tensor<8x?xf32>
+//       CHECK:   util.return %[[SOURCE]]
 
 // -----
 
-util.func public @merge_constant_padding_dynamic(%arg0: tensor<?x?xf32>, %idx: index, %pad_value: f32) -> tensor<?x?xf32> {
-  %pad0 = tensor.pad %arg0 low[%idx, 1] high[1, 0] {
-    ^bb0(%b0: index, %b1 : index):
-      tensor.yield %pad_value : f32
-  } : tensor<?x?xf32> to tensor<?x?xf32>
-  %pad1 = tensor.pad %pad0 low[0, 2] high[%idx, 2] {
-    ^bb0(%b2: index, %b3 : index):
-      tensor.yield %pad_value : f32
-  } : tensor<?x?xf32> to tensor<?x?xf32>
-  util.return %pad1 : tensor<?x?xf32>
+util.func public @fold_full_insert_into_empty(
+    %source: tensor<8x?xf32>,
+    %size: index) -> tensor<8x?xf32> {
+  %empty = tensor.empty(%size) : tensor<8x?xf32>
+  %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32>
+  util.return %insert : tensor<8x?xf32>
 }
-// CHECK-LABEL: util.func public @merge_constant_padding_dynamic
-//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
-//  CHECK-SAME:   %[[IDX:[A-Za-z0-9]+]]: index
-//  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
-//       CHECK:   %[[HIGH:.+]] = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%[[IDX]]]
-//       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
-//       CHECK:     tensor.yield %[[PADVAL]]
-//       CHECK:   util.return %[[PAD]]
 
+// CHECK-LABEL: util.func public @fold_full_insert_into_empty
+//  CHECK-SAME:   %[[SOURCE:.+]]: tensor<8x?xf32>
+//       CHECK:   util.return %[[SOURCE]]
 
 // -----
 
-util.func public @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
-  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
-    ^bb0(%b0: index, %b1 : index):
-      tensor.yield %pad_value : f32
-  } : tensor<2x3xf32> to tensor<4x4xf32>
-  %pad1 = tensor.pad %pad0 nofold low[0, 2] high[3, 2] {
-    ^bb0(%b2: index, %b3 : index):
-      tensor.yield %pad_value : f32
-  } : tensor<4x4xf32> to tensor<7x8xf32>
-  util.return %pad1 : tensor<7x8xf32>
+util.func public @dont_fold_not_full_insert_into_empty(
+    %source: tensor<8x?xf32>,
+    %size1: index, %size2: index) -> tensor<8x?xf32> {
+  %empty = tensor.empty(%size1) : tensor<8x?xf32>
+  %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size2] [1, 1] : tensor<8x?xf32> into tensor<8x?xf32>
+  util.return %insert : tensor<8x?xf32>
 }
 
-// Verify that folding does not happen if it would drop a nofold attribute
-
-// CHECK-LABEL: util.func public @dont_merge_constant_padding_nofold
-//       CHECK:   tensor.pad
-//       CHECK:   tensor.pad {{.*}} nofold
+// CHECK-LABEL: util.func public @dont_fold_not_full_insert_into_empty
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   util.return %[[INSERT]]
 
 // -----
 
-util.func public @dont_merge_constant_padding_different_vals(
-    %arg0: tensor<2x3xf32>,
-    %pad_value0: f32,
-    %pad_value1: f32) -> tensor<7x8xf32> {
-  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
-    ^bb0(%b0: index, %b1 : index):
-      tensor.yield %pad_value0 : f32
-  } : tensor<2x3xf32> to tensor<4x4xf32>
-  %pad1 = tensor.pad %pad0 nofold low[0, 2] high[3, 2] {
-    ^bb0(%b2: index, %b3 : index):
-      tensor.yield %pad_value1 : f32
-  } : tensor<4x4xf32> to tensor<7x8xf32>
-  util.return %pad1 : tensor<7x8xf32>
+util.func public @dont_fold_not_full_static_insert_into_empty(
+    %source: tensor<8x?xf32>,
+    %size: index) -> tensor<10x?xf32> {
+  %empty = tensor.empty(%size) : tensor<10x?xf32>
+  %insert = tensor.insert_slice %source into %empty [0, 0] [8, %size] [1, 1] : tensor<8x?xf32> into tensor<10x?xf32>
+  util.return %insert : tensor<10x?xf32>
 }
 
-// Verify that folding does not happen if it would drop a nofold attribute
-
-// CHECK-LABEL: util.func public @dont_merge_constant_padding_different_vals
-//       CHECK:   tensor.pad
-//       CHECK:   tensor.pad
+// CHECK-LABEL: util.func public @dont_fold_not_full_static_insert_into_empty
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   util.return %[[INSERT]]
