diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp index f8169411fc22..fed4470e8580 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp @@ -314,51 +314,31 @@ static bool hasPadding(Operation *op) { } /// Control function for decomposing pack and unpack ops. Returns true if the -/// op is a pack or unpack op, and its reshapes can be folded with a producer -/// or consumer interface tensor op. To be foldable, the following conditions -/// must be met: -/// +/// op is an unpadded pack or unpack op, and it is at the boundary of a +/// dispatch. The following conditions need to be met: /// 1. The PackOp or UnPackOp must have no padding. /// 2. If the op is a PackOp, then its producer must be a dispatch tensor load. /// 3. If the op is an UnPackOp, then all of its consumers must be dispatch /// tensor stores. -/// 4. Any dispatch tensor load producers or dispatch tensor store consumers -/// must be full slices. -static LogicalResult isFoldableIntoInterfaceTensor(Operation *op) { - // Full slice means zero offsets, unit strides, and sizes match full tensor - // shape. - auto isFullSlice = - [](ArrayRef offsets, ArrayRef sizes, - ArrayRef strides, ArrayRef fullTensorShape) { - return areAllConstantIntValue(offsets, 0) && - areAllConstantIntValue(strides, 1) && - areConstantIntValues(sizes, fullTensorShape); - }; - if (!isa(op)) { +static LogicalResult isUnpaddedAndAtBoundary(Operation *op) { + if (!isa(op) && !isa(op)) { return failure(); } if (hasPadding(op)) { return failure(); } - // If the producer is a full slice dispatch tensor load, then the `op` is - // foldable if it is a PackOp. - auto load = dyn_cast( - op->getOperand(0).getDefiningOp()); - if (isa(op) && load && - isFullSlice(load.getMixedOffsets(), load.getMixedSizes(), - load.getMixedStrides(), load.getSourceType().getShape())) { + // If the producer is a dispatch tensor load, then the `op` is decomposable + // if it is a PackOp. + if (isa(op) && + op->getOperand(0).getDefiningOp()) { return success(); } - // If all consumers are full slice dispatch tensor stores, then the `op` is - // foldable if it is an UnPackOp. + // If all consumers are dispatch tensor stores, then the `op` is decomposable + // if it is an UnPackOp. if (isa(op) && llvm::all_of(op->getUsers(), [&](Operation *user) { - auto store = dyn_cast(user); - return store && - isFullSlice(store.getMixedOffsets(), store.getMixedSizes(), - store.getMixedStrides(), - store.getTargetType().getShape()); + return isa(user); })) { return success(); } @@ -368,7 +348,7 @@ static LogicalResult isFoldableIntoInterfaceTensor(Operation *op) { void DecomposeBoundaryPackUnPackOpsPass::runOnOperation() { if (failed(commonRunOnOperation(&getContext(), getOperation(), /*useOnlyReshapes=*/true, tileOuterToOne, - isFoldableIntoInterfaceTensor))) { + isUnpaddedAndAtBoundary))) { return signalPassFailure(); } } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir index 096043ba8897..6ff5bed59060 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir @@ -133,7 +133,7 @@ func.func @load_non_full_slice() { return } // CHECK-LABEL: func.func @load_non_full_slice -// CHECK: tensor.pack +// CHECK-NOT: tensor.pack // ----- @@ -152,7 +152,7 @@ func.func @store_non_full_slice() { return } // CHECK-LABEL: func.func @store_non_full_slice -// CHECK: tensor.unpack +// CHECK-NOT: tensor.unpack // -----