From 9d36cfa0a95a606387b65a551cf33ba5d1fb91ee Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:30:37 -0700 Subject: [PATCH] [Codegen] Don't require full slice to decompose boundary pack and unpack ops (#18906) This PR loosens the restrictions on decomposing boundary pack and unpack ops. The current restriction is that the dispatch.tensor.load/store ops are full slices, but this is not necessary for the current use case in the TileAndFuse pipeline. Instead, it is better for the time being to decompose non-padded pack/unpack ops at function boundaries regardless of the dispatch.tensor.load/store ops being full slices, because decomposing such ops later on can cause issues with DPS. The DPS issues are tracked in https://github.com/iree-org/iree/issues/18902, but we can loosen the restrictions regardless, since it does not pose any issues to decompose in such cases. Signed-off-by: Max Dawkins --- .../Codegen/Common/DecomposePackUnPackOps.cpp | 44 +++++-------------- .../decompose_boundary_pack_unpack_ops.mlir | 4 +- 2 files changed, 14 insertions(+), 34 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp index f8169411fc22..fed4470e8580 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp @@ -314,51 +314,31 @@ static bool hasPadding(Operation *op) { } /// Control function for decomposing pack and unpack ops. Returns true if the -/// op is a pack or unpack op, and its reshapes can be folded with a producer -/// or consumer interface tensor op. To be foldable, the following conditions -/// must be met: -/// +/// op is an unpadded pack or unpack op, and it is at the boundary of a +/// dispatch. The following conditions need to be met: /// 1. The PackOp or UnPackOp must have no padding. /// 2. If the op is a PackOp, then its producer must be a dispatch tensor load. /// 3. If the op is an UnPackOp, then all of its consumers must be dispatch /// tensor stores. -/// 4. Any dispatch tensor load producers or dispatch tensor store consumers -/// must be full slices. -static LogicalResult isFoldableIntoInterfaceTensor(Operation *op) { - // Full slice means zero offsets, unit strides, and sizes match full tensor - // shape. - auto isFullSlice = - [](ArrayRef offsets, ArrayRef sizes, - ArrayRef strides, ArrayRef fullTensorShape) { - return areAllConstantIntValue(offsets, 0) && - areAllConstantIntValue(strides, 1) && - areConstantIntValues(sizes, fullTensorShape); - }; - if (!isa(op)) { +static LogicalResult isUnpaddedAndAtBoundary(Operation *op) { + if (!isa(op) && !isa(op)) { return failure(); } if (hasPadding(op)) { return failure(); } - // If the producer is a full slice dispatch tensor load, then the `op` is - // foldable if it is a PackOp. - auto load = dyn_cast( - op->getOperand(0).getDefiningOp()); - if (isa(op) && load && - isFullSlice(load.getMixedOffsets(), load.getMixedSizes(), - load.getMixedStrides(), load.getSourceType().getShape())) { + // If the producer is a dispatch tensor load, then the `op` is decomposable + // if it is a PackOp. + if (isa(op) && + op->getOperand(0).getDefiningOp()) { return success(); } - // If all consumers are full slice dispatch tensor stores, then the `op` is - // foldable if it is an UnPackOp. + // If all consumers are dispatch tensor stores, then the `op` is decomposable + // if it is an UnPackOp. if (isa(op) && llvm::all_of(op->getUsers(), [&](Operation *user) { - auto store = dyn_cast(user); - return store && - isFullSlice(store.getMixedOffsets(), store.getMixedSizes(), - store.getMixedStrides(), - store.getTargetType().getShape()); + return isa(user); })) { return success(); } @@ -368,7 +348,7 @@ static LogicalResult isFoldableIntoInterfaceTensor(Operation *op) { void DecomposeBoundaryPackUnPackOpsPass::runOnOperation() { if (failed(commonRunOnOperation(&getContext(), getOperation(), /*useOnlyReshapes=*/true, tileOuterToOne, - isFoldableIntoInterfaceTensor))) { + isUnpaddedAndAtBoundary))) { return signalPassFailure(); } } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir index 096043ba8897..6ff5bed59060 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir @@ -133,7 +133,7 @@ func.func @load_non_full_slice() { return } // CHECK-LABEL: func.func @load_non_full_slice -// CHECK: tensor.pack +// CHECK-NOT: tensor.pack // ----- @@ -152,7 +152,7 @@ func.func @store_non_full_slice() { return } // CHECK-LABEL: func.func @store_non_full_slice -// CHECK: tensor.unpack +// CHECK-NOT: tensor.unpack // -----