From 1bf6f059d9f1e29917096da25fee9884425cd1bd Mon Sep 17 00:00:00 2001 From: yzhang93 Date: Wed, 9 Oct 2024 11:49:32 -0700 Subject: [PATCH] [CombineStridedOps] Add a combinable case --- .../Transforms/AMDAIEDmaUtils.cpp | 94 ++++++++++++------- .../AMDAIETemporaryAllocBufferization.cpp | 1 - .../Transforms/test/AMDAIEDmaUtilsTest.cpp | 14 +++ .../Transforms/test/combine_strided_ops.mlir | 44 +++++++++ 4 files changed, 120 insertions(+), 33 deletions(-) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEDmaUtils.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEDmaUtils.cpp index 229e86e53..e1fdc46f8 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEDmaUtils.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEDmaUtils.cpp @@ -60,8 +60,14 @@ bool areAccessPatternsCombinable(const SmallVector &offsetsA, } if (strideA != strideB) return false; } + + // Don't check the outermost dimension of size at this point. + SmallVector innerSizesA; + SmallVector innerSizesB; + std::copy(sizesA.begin() + 1, sizesA.end(), std::back_inserter(innerSizesA)); + std::copy(sizesB.begin() + 1, sizesB.end(), std::back_inserter(innerSizesB)); for (auto &&[sizeA, sizeB] : - llvm::zip(llvm::reverse(sizesA), llvm::reverse(sizesB))) { + llvm::zip(llvm::reverse(innerSizesA), llvm::reverse(innerSizesB))) { std::optional maybeSizeA = getConstantIntValue(sizeA); std::optional maybeSizeB = getConstantIntValue(sizeB); // Handle static and constant value with same int value. @@ -71,6 +77,20 @@ bool areAccessPatternsCombinable(const SmallVector &offsetsA, if (sizeA != sizeB) return false; } + // Edge case for sizesA[0] != sizesB[0]. + if (offsetsB.size() == offsetsA.size() && sizesA[0] != sizesB[0]) { + std::optional constOffsetA = getConstantIntValue(offsetsA[0]); + std::optional constSizeA = getConstantIntValue(sizesA[0]); + std::optional constOffsetB = getConstantIntValue(offsetsB[0]); + std::optional constSizeB = getConstantIntValue(sizesB[0]); + if (constOffsetA && constOffsetB && constSizeA && constSizeB) { + int64_t offsetDiff = constOffsetB.value() - constOffsetA.value(); + if (constSizeA.value() != offsetDiff) return false; + } else { + return false; + } + } + bool foundDiff{false}; for (auto iter : llvm::enumerate( llvm::zip(llvm::reverse(offsetsA), llvm::reverse(offsetsB)))) { @@ -169,40 +189,50 @@ LogicalResult combineAccessPatterns(RewriterBase &rewriter, if (!size) return failure(); newSizes[0] = rewriter.getI64IntegerAttr(size.value() + 1); } else { - // Sizes are the same, so add a new dimension with 'offset == 0', 'size == - // 2' and 'stride == offsetDiff'. - newOffsets.push_back(rewriter.getI64IntegerAttr(0)); - int64_t offsetDiff; - int64_t strideMultiplier; - for (auto iter : llvm::enumerate(llvm::zip(offsetsA, offsetsB))) { - const OpFoldResult &offsetA = std::get<0>(iter.value()); - const OpFoldResult &offsetB = std::get<1>(iter.value()); - newOffsets.push_back(offsetA); - if (offsetA != offsetB) { - std::optional constOffsetA = getConstantIntValue(offsetA); - std::optional constOffsetB = getConstantIntValue(offsetB); - if (!constOffsetA || !constOffsetB) { - return emitError(rewriter.getUnknownLoc()) - << "differing offsets should be constants"; - } - offsetDiff = constOffsetB.value() - constOffsetA.value(); - std::optional maybeStride = - getConstantIntValue(stridesA[iter.index()]); - if (!maybeStride) { - return emitError(rewriter.getUnknownLoc()) - << "no constant stride found at the same index where the " - "offset " - "difference occurs"; + // Edge case for sizesA[0] != sizesB[0]. + if (sizesA[0] != sizesB[0]) { + newOffsets = offsetsA; + newSizes = sizesA; + newStrides = stridesA; + std::optional sizeA = getConstantIntValue(sizesA[0]); + std::optional sizeB = getConstantIntValue(sizesB[0]); + if (!sizeA || !sizeB) return failure(); + newSizes[0] = rewriter.getI64IntegerAttr(sizeA.value() + sizeB.value()); + } else { + // All dims of sizes are the same, so add a new dimension with + // 'offset == 0', 'size == 2' and 'stride == offsetDiff'. + newOffsets.push_back(rewriter.getI64IntegerAttr(0)); + int64_t offsetDiff; + int64_t strideMultiplier; + for (auto iter : llvm::enumerate(llvm::zip(offsetsA, offsetsB))) { + const OpFoldResult &offsetA = std::get<0>(iter.value()); + const OpFoldResult &offsetB = std::get<1>(iter.value()); + newOffsets.push_back(offsetA); + if (offsetA != offsetB) { + std::optional constOffsetA = getConstantIntValue(offsetA); + std::optional constOffsetB = getConstantIntValue(offsetB); + if (!constOffsetA || !constOffsetB) { + return emitError(rewriter.getUnknownLoc()) + << "differing offsets should be constants"; + } + offsetDiff = constOffsetB.value() - constOffsetA.value(); + std::optional maybeStride = + getConstantIntValue(stridesA[iter.index()]); + if (!maybeStride) { + return emitError(rewriter.getUnknownLoc()) + << "no constant stride found at the same index where the " + "offset " + "difference occurs"; + } + strideMultiplier = maybeStride.value(); } - strideMultiplier = maybeStride.value(); } + newSizes.push_back(rewriter.getI64IntegerAttr(2)); + newSizes.append(sizesA.begin(), sizesA.end()); + newStrides.push_back( + rewriter.getI64IntegerAttr(offsetDiff * strideMultiplier)); + newStrides.append(stridesA.begin(), stridesA.end()); } - newSizes.push_back(rewriter.getI64IntegerAttr(2)); - newSizes.append(sizesA.begin(), sizesA.end()); - newStrides.push_back( - rewriter.getI64IntegerAttr(offsetDiff * strideMultiplier)); - newStrides.append(stridesA.begin(), stridesA.end()); - ; } assert(newOffsets.size() == newSizes.size() && "expected same number of new offsets and sizes"); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIETemporaryAllocBufferization.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIETemporaryAllocBufferization.cpp index 480dfda6d..8b56096c3 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIETemporaryAllocBufferization.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIETemporaryAllocBufferization.cpp @@ -48,7 +48,6 @@ LogicalResult bufferizeTemporaryMemrefs(Operation *parentOp) { }); } - // Note: we don't erase allocs/deallocs, we leave this for canonicalization. return success(); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/AMDAIEDmaUtilsTest.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/AMDAIEDmaUtilsTest.cpp index ab15b1fe4..0a015a08a 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/AMDAIEDmaUtilsTest.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/AMDAIEDmaUtilsTest.cpp @@ -111,6 +111,8 @@ TEST_F(AccessPatternCombinationTest, CombinableAccessPatterns) { EXPECT_TRUE(checkAreAccessPatternsCombinable({0, 2, 0}, {16, 16, 32}, {32, 64, 1}, {0, 2, 32}, {16, 16, 32}, {32, 64, 1}, 4)); + EXPECT_TRUE(checkAreAccessPatternsCombinable({32, 0}, {64, 64}, {128, 1}, + {96, 0}, {32, 64}, {128, 1}, 4)); // size(A) > size(B) EXPECT_TRUE(checkAreAccessPatternsCombinable( {0, 0, 0}, {2, 16, 32}, {32, 64, 1}, {0, 64}, {16, 32}, {64, 1}, 4)); @@ -168,6 +170,12 @@ TEST_F(AccessPatternCombinationTest, NonCombinableAccessPatterns) { {0, 0}, {16, 32}, {64, 1}, {0, 0, 96}, {2, 16, 32}, {32, 64, 1}, 4)); EXPECT_FALSE(checkAreAccessPatternsCombinable( {0, 0}, {16, 32}, {64, 1}, {0, 1, 0}, {2, 16, 32}, {32, 64, 1}, 4)); + + // size(A) == size(B) Incompatible offset + EXPECT_FALSE(checkAreAccessPatternsCombinable( + {32, 0}, {64, 64}, {128, 1}, {32, 0}, {32, 64}, {128, 1}, 4)); + EXPECT_FALSE(checkAreAccessPatternsCombinable( + {32, 0}, {32, 64}, {128, 1}, {96, 0}, {64, 64}, {128, 1}, 4)); } TEST_F(AccessPatternCombinationTest, CombineAccessPatterns) { @@ -197,6 +205,8 @@ TEST_F(AccessPatternCombinationTest, CombineAccessPatterns) { checkCombineAccessPatterns({8, 0, 0}, {16, 8, 16}, {16, 8, 1}, {40, 0, 0}, {16, 8, 16}, {16, 8, 1}, {0, 8, 0, 0}, {2, 16, 8, 16}, {512, 16, 8, 1}, 4); + checkCombineAccessPatterns({32, 0}, {64, 64}, {128, 1}, {96, 0}, {32, 64}, + {128, 1}, {32, 0}, {96, 64}, {128, 1}, 4); // size(A) > size(B) checkCombineAccessPatterns({0, 0}, {2, 32}, {64, 1}, {128}, {32}, {1}, {0, 0}, {3, 32}, {64, 1}, 3); @@ -255,6 +265,10 @@ TEST_F(AccessPatternCombinationTest, FailCombineAccessPatterns) { {3, 32}, {64, 1}, 3, false); checkCombineAccessPatterns({0}, {32}, {1}, {0, 96}, {2, 32}, {64, 1}, {0, 0}, {3, 32}, {64, 1}, 3, false); + + // size(A) == size(B) Incompatible offset + checkCombineAccessPatterns({32, 0}, {32, 64}, {128, 1}, {96, 0}, {64, 64}, + {128, 1}, {32, 0}, {96, 64}, {128, 1}, 4, false); } } // namespace diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/combine_strided_ops.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/combine_strided_ops.mlir index fd0a49bc9..25dd958c8 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/combine_strided_ops.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/combine_strided_ops.mlir @@ -230,6 +230,28 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} // ----- +// CHECK-LABEL: @combine_source_same_dims_diff_sizes +// CHECK: %[[CONNECTION:.+]] = amdaie.connection +// CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([] [] [], [0, 0] [128, 64] [128, 1]) +// CHECK-NOT: amdaie.npu.dma_cpy_nd +#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}> +module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} { + func.func @combine_source_same_dims_diff_sizes(%arg0: !amdaie.logicalobjectfifo>, %arg1: !amdaie.logicalobjectfifo>) { + amdaie.workgroup { + %0 = amdaie.connection(%arg0, %arg1) : (!amdaie.logicalobjectfifo>, !amdaie.logicalobjectfifo>) + amdaie.controlcode { + amdaie.npu.dma_cpy_nd %0([] [] [], [0, 0] [32, 64] [128, 1]) + amdaie.npu.dma_cpy_nd %0([] [] [], [32, 0] [64, 64] [128, 1]) + amdaie.npu.dma_cpy_nd %0([] [] [], [96, 0] [32, 64] [128, 1]) + amdaie.end + } + } + return + } +} + +// ----- + // CHECK-LABEL: @combine_source_values // CHECK: %[[CONNECTION:.+]] = amdaie.connection // CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([] [] [], [0, 0, 0, 0] [2, 16, 8, 16] [32, 32, 8, 1]) @@ -332,6 +354,28 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} // ----- +// CHECK-LABEL: @combine_target_same_dims_diff_sizes +// CHECK: %[[CONNECTION:.+]] = amdaie.connection +// CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([0, 0] [128, 64] [128, 1], [] [] []) +// CHECK-NOT: amdaie.npu.dma_cpy_nd +#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}> +module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} { + func.func @combine_target_same_dims_diff_sizes(%arg0: !amdaie.logicalobjectfifo>, %arg1: !amdaie.logicalobjectfifo>) { + amdaie.workgroup { + %0 = amdaie.connection(%arg0, %arg1) : (!amdaie.logicalobjectfifo>, !amdaie.logicalobjectfifo>) + amdaie.controlcode { + amdaie.npu.dma_cpy_nd %0([0, 0] [32, 64] [128, 1], [] [] []) + amdaie.npu.dma_cpy_nd %0([32, 0] [64, 64] [128, 1], [] [] []) + amdaie.npu.dma_cpy_nd %0([96, 0] [32, 64] [128, 1], [] [] []) + amdaie.end + } + } + return + } +} + +// ----- + // CHECK-LABEL: @combine_target_diff_dims // CHECK: %[[CONNECTION:.+]] = amdaie.connection // CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([0, 0, 0, 32] [3, 16, 8, 16] [64, 32, 8, 1], [] [] [])