Skip to content

Commit

Permalink
[Codegen] Don't require full slice to decompose boundary pack and unp…
Browse files Browse the repository at this point in the history
…ack ops (#18906)

This PR loosens the restrictions on decomposing boundary pack and unpack
ops. The current restriction is that the dispatch.tensor.load/store ops
are full slices, but this is not necessary for the current use case in
the TileAndFuse pipeline.

Instead, it is better for the time being to decompose non-padded
pack/unpack ops at function boundaries regardless of the
dispatch.tensor.load/store ops being full slices, because decomposing
such ops later on can cause issues with DPS. The DPS issues are tracked
in #18902, but we can loosen the
restrictions regardless, since it does not pose any issues to decompose
in such cases.

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
Max191 authored Oct 28, 2024
1 parent e66171a commit 9d36cfa
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -314,51 +314,31 @@ static bool hasPadding(Operation *op) {
}

/// Control function for decomposing pack and unpack ops. Returns true if the
/// op is a pack or unpack op, and its reshapes can be folded with a producer
/// or consumer interface tensor op. To be foldable, the following conditions
/// must be met:
///
/// op is an unpadded pack or unpack op, and it is at the boundary of a
/// dispatch. The following conditions need to be met:
/// 1. The PackOp or UnPackOp must have no padding.
/// 2. If the op is a PackOp, then its producer must be a dispatch tensor load.
/// 3. If the op is an UnPackOp, then all of its consumers must be dispatch
/// tensor stores.
/// 4. Any dispatch tensor load producers or dispatch tensor store consumers
/// must be full slices.
static LogicalResult isFoldableIntoInterfaceTensor(Operation *op) {
// Full slice means zero offsets, unit strides, and sizes match full tensor
// shape.
auto isFullSlice =
[](ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides, ArrayRef<int64_t> fullTensorShape) {
return areAllConstantIntValue(offsets, 0) &&
areAllConstantIntValue(strides, 1) &&
areConstantIntValues(sizes, fullTensorShape);
};
if (!isa<tensor::PackOp, tensor::UnPackOp>(op)) {
static LogicalResult isUnpaddedAndAtBoundary(Operation *op) {
if (!isa<tensor::PackOp>(op) && !isa<tensor::UnPackOp>(op)) {
return failure();
}
if (hasPadding(op)) {
return failure();
}

// If the producer is a full slice dispatch tensor load, then the `op` is
// foldable if it is a PackOp.
auto load = dyn_cast<IREE::Flow::DispatchTensorLoadOp>(
op->getOperand(0).getDefiningOp());
if (isa<tensor::PackOp>(op) && load &&
isFullSlice(load.getMixedOffsets(), load.getMixedSizes(),
load.getMixedStrides(), load.getSourceType().getShape())) {
// If the producer is a dispatch tensor load, then the `op` is decomposable
// if it is a PackOp.
if (isa<tensor::PackOp>(op) &&
op->getOperand(0).getDefiningOp<IREE::Flow::DispatchTensorLoadOp>()) {
return success();
}
// If all consumers are full slice dispatch tensor stores, then the `op` is
// foldable if it is an UnPackOp.
// If all consumers are dispatch tensor stores, then the `op` is decomposable
// if it is an UnPackOp.
if (isa<tensor::UnPackOp>(op) &&
llvm::all_of(op->getUsers(), [&](Operation *user) {
auto store = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(user);
return store &&
isFullSlice(store.getMixedOffsets(), store.getMixedSizes(),
store.getMixedStrides(),
store.getTargetType().getShape());
return isa<IREE::Flow::DispatchTensorStoreOp>(user);
})) {
return success();
}
Expand All @@ -368,7 +348,7 @@ static LogicalResult isFoldableIntoInterfaceTensor(Operation *op) {
void DecomposeBoundaryPackUnPackOpsPass::runOnOperation() {
if (failed(commonRunOnOperation(&getContext(), getOperation(),
/*useOnlyReshapes=*/true, tileOuterToOne,
isFoldableIntoInterfaceTensor))) {
isUnpaddedAndAtBoundary))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func.func @load_non_full_slice() {
return
}
// CHECK-LABEL: func.func @load_non_full_slice
// CHECK: tensor.pack
// CHECK-NOT: tensor.pack

// -----

Expand All @@ -152,7 +152,7 @@ func.func @store_non_full_slice() {
return
}
// CHECK-LABEL: func.func @store_non_full_slice
// CHECK: tensor.unpack
// CHECK-NOT: tensor.unpack

// -----

Expand Down

0 comments on commit 9d36cfa

Please sign in to comment.