Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CombineStridedOps] Add a combinable case #839

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,14 @@ bool areAccessPatternsCombinable(const SmallVector<OpFoldResult> &offsetsA,
}
if (strideA != strideB) return false;
}

// Don't check the outermost dimension of size at this point.
SmallVector<OpFoldResult> innerSizesA;
SmallVector<OpFoldResult> innerSizesB;
std::copy(sizesA.begin() + 1, sizesA.end(), std::back_inserter(innerSizesA));
std::copy(sizesB.begin() + 1, sizesB.end(), std::back_inserter(innerSizesB));
for (auto &&[sizeA, sizeB] :
llvm::zip(llvm::reverse(sizesA), llvm::reverse(sizesB))) {
llvm::zip(llvm::reverse(innerSizesA), llvm::reverse(innerSizesB))) {
std::optional<int64_t> maybeSizeA = getConstantIntValue(sizeA);
std::optional<int64_t> maybeSizeB = getConstantIntValue(sizeB);
// Handle static and constant value with same int value.
Expand All @@ -71,6 +77,20 @@ bool areAccessPatternsCombinable(const SmallVector<OpFoldResult> &offsetsA,
if (sizeA != sizeB) return false;
}

// Edge case for sizesA[0] != sizesB[0].
if (offsetsB.size() == offsetsA.size() && sizesA[0] != sizesB[0]) {
std::optional<int64_t> constOffsetA = getConstantIntValue(offsetsA[0]);
std::optional<int64_t> constSizeA = getConstantIntValue(sizesA[0]);
std::optional<int64_t> constOffsetB = getConstantIntValue(offsetsB[0]);
std::optional<int64_t> constSizeB = getConstantIntValue(sizesB[0]);
if (constOffsetA && constOffsetB && constSizeA && constSizeB) {
int64_t offsetDiff = constOffsetB.value() - constOffsetA.value();
if (constSizeA.value() != offsetDiff) return false;
} else {
return false;
}
}

bool foundDiff{false};
for (auto iter : llvm::enumerate(
llvm::zip(llvm::reverse(offsetsA), llvm::reverse(offsetsB)))) {
Expand Down Expand Up @@ -169,40 +189,50 @@ LogicalResult combineAccessPatterns(RewriterBase &rewriter,
if (!size) return failure();
newSizes[0] = rewriter.getI64IntegerAttr(size.value() + 1);
} else {
// Sizes are the same, so add a new dimension with 'offset == 0', 'size ==
// 2' and 'stride == offsetDiff'.
newOffsets.push_back(rewriter.getI64IntegerAttr(0));
int64_t offsetDiff;
int64_t strideMultiplier;
for (auto iter : llvm::enumerate(llvm::zip(offsetsA, offsetsB))) {
const OpFoldResult &offsetA = std::get<0>(iter.value());
const OpFoldResult &offsetB = std::get<1>(iter.value());
newOffsets.push_back(offsetA);
if (offsetA != offsetB) {
std::optional<int64_t> constOffsetA = getConstantIntValue(offsetA);
std::optional<int64_t> constOffsetB = getConstantIntValue(offsetB);
if (!constOffsetA || !constOffsetB) {
return emitError(rewriter.getUnknownLoc())
<< "differing offsets should be constants";
}
offsetDiff = constOffsetB.value() - constOffsetA.value();
std::optional<int64_t> maybeStride =
getConstantIntValue(stridesA[iter.index()]);
if (!maybeStride) {
return emitError(rewriter.getUnknownLoc())
<< "no constant stride found at the same index where the "
"offset "
"difference occurs";
// Edge case for sizesA[0] != sizesB[0].
if (sizesA[0] != sizesB[0]) {
newOffsets = offsetsA;
newSizes = sizesA;
newStrides = stridesA;
std::optional<int64_t> sizeA = getConstantIntValue(sizesA[0]);
std::optional<int64_t> sizeB = getConstantIntValue(sizesB[0]);
if (!sizeA || !sizeB) return failure();
newSizes[0] = rewriter.getI64IntegerAttr(sizeA.value() + sizeB.value());
} else {
// All dims of sizes are the same, so add a new dimension with
// 'offset == 0', 'size == 2' and 'stride == offsetDiff'.
newOffsets.push_back(rewriter.getI64IntegerAttr(0));
int64_t offsetDiff;
int64_t strideMultiplier;
for (auto iter : llvm::enumerate(llvm::zip(offsetsA, offsetsB))) {
const OpFoldResult &offsetA = std::get<0>(iter.value());
const OpFoldResult &offsetB = std::get<1>(iter.value());
newOffsets.push_back(offsetA);
if (offsetA != offsetB) {
std::optional<int64_t> constOffsetA = getConstantIntValue(offsetA);
std::optional<int64_t> constOffsetB = getConstantIntValue(offsetB);
if (!constOffsetA || !constOffsetB) {
return emitError(rewriter.getUnknownLoc())
<< "differing offsets should be constants";
}
offsetDiff = constOffsetB.value() - constOffsetA.value();
std::optional<int64_t> maybeStride =
getConstantIntValue(stridesA[iter.index()]);
if (!maybeStride) {
return emitError(rewriter.getUnknownLoc())
<< "no constant stride found at the same index where the "
"offset "
"difference occurs";
}
strideMultiplier = maybeStride.value();
}
strideMultiplier = maybeStride.value();
}
newSizes.push_back(rewriter.getI64IntegerAttr(2));
newSizes.append(sizesA.begin(), sizesA.end());
newStrides.push_back(
rewriter.getI64IntegerAttr(offsetDiff * strideMultiplier));
newStrides.append(stridesA.begin(), stridesA.end());
}
newSizes.push_back(rewriter.getI64IntegerAttr(2));
newSizes.append(sizesA.begin(), sizesA.end());
newStrides.push_back(
rewriter.getI64IntegerAttr(offsetDiff * strideMultiplier));
newStrides.append(stridesA.begin(), stridesA.end());
;
}
assert(newOffsets.size() == newSizes.size() &&
"expected same number of new offsets and sizes");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ LogicalResult bufferizeTemporaryMemrefs(Operation *parentOp) {
});
}


// Note: we don't erase allocs/deallocs, we leave this for canonicalization.

return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ TEST_F(AccessPatternCombinationTest, CombinableAccessPatterns) {
EXPECT_TRUE(checkAreAccessPatternsCombinable({0, 2, 0}, {16, 16, 32},
{32, 64, 1}, {0, 2, 32},
{16, 16, 32}, {32, 64, 1}, 4));
EXPECT_TRUE(checkAreAccessPatternsCombinable({32, 0}, {64, 64}, {128, 1},
{96, 0}, {32, 64}, {128, 1}, 4));
// size(A) > size(B)
EXPECT_TRUE(checkAreAccessPatternsCombinable(
{0, 0, 0}, {2, 16, 32}, {32, 64, 1}, {0, 64}, {16, 32}, {64, 1}, 4));
Expand Down Expand Up @@ -168,6 +170,12 @@ TEST_F(AccessPatternCombinationTest, NonCombinableAccessPatterns) {
{0, 0}, {16, 32}, {64, 1}, {0, 0, 96}, {2, 16, 32}, {32, 64, 1}, 4));
EXPECT_FALSE(checkAreAccessPatternsCombinable(
{0, 0}, {16, 32}, {64, 1}, {0, 1, 0}, {2, 16, 32}, {32, 64, 1}, 4));

// size(A) == size(B) Incompatible offset
EXPECT_FALSE(checkAreAccessPatternsCombinable(
{32, 0}, {64, 64}, {128, 1}, {32, 0}, {32, 64}, {128, 1}, 4));
EXPECT_FALSE(checkAreAccessPatternsCombinable(
{32, 0}, {32, 64}, {128, 1}, {96, 0}, {64, 64}, {128, 1}, 4));
}

TEST_F(AccessPatternCombinationTest, CombineAccessPatterns) {
Expand Down Expand Up @@ -197,6 +205,8 @@ TEST_F(AccessPatternCombinationTest, CombineAccessPatterns) {
checkCombineAccessPatterns({8, 0, 0}, {16, 8, 16}, {16, 8, 1}, {40, 0, 0},
{16, 8, 16}, {16, 8, 1}, {0, 8, 0, 0},
{2, 16, 8, 16}, {512, 16, 8, 1}, 4);
checkCombineAccessPatterns({32, 0}, {64, 64}, {128, 1}, {96, 0}, {32, 64},
{128, 1}, {32, 0}, {96, 64}, {128, 1}, 4);
// size(A) > size(B)
checkCombineAccessPatterns({0, 0}, {2, 32}, {64, 1}, {128}, {32}, {1}, {0, 0},
{3, 32}, {64, 1}, 3);
Expand Down Expand Up @@ -255,6 +265,10 @@ TEST_F(AccessPatternCombinationTest, FailCombineAccessPatterns) {
{3, 32}, {64, 1}, 3, false);
checkCombineAccessPatterns({0}, {32}, {1}, {0, 96}, {2, 32}, {64, 1}, {0, 0},
{3, 32}, {64, 1}, 3, false);

// size(A) == size(B) Incompatible offset
checkCombineAccessPatterns({32, 0}, {32, 64}, {128, 1}, {96, 0}, {64, 64},
{128, 1}, {32, 0}, {96, 64}, {128, 1}, 4, false);
}

} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,28 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}

// -----

// Test case: three consecutive DMA copies on the same connection whose
// non-empty access patterns (the second operand group; per the test name this
// is the source side) have the same number of dimensions, identical strides
// [128, 1] and identical inner size 64, but different outermost sizes
// (32, 64, 32). Each op's outermost offset equals the previous op's outermost
// offset plus its outermost size (0 + 32 = 32, 32 + 64 = 96).
// The FileCheck directives below expect exactly one remaining copy with
// offsets [0, 0], sizes [128, 64] (32 + 64 + 32 in the outermost dimension)
// and strides [128, 1], and no further copy ops after it.
// CHECK-LABEL: @combine_source_same_dims_diff_sizes
// CHECK: %[[CONNECTION:.+]] = amdaie.connection
// CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([] [] [], [0, 0] [128, 64] [128, 1])
// CHECK-NOT: amdaie.npu.dma_cpy_nd
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @combine_source_same_dims_diff_sizes(%arg0: !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>, %arg1: !amdaie.logicalobjectfifo<memref<128x128xi32>>) {
amdaie.workgroup {
%0 = amdaie.connection(%arg0, %arg1) : (!amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>, !amdaie.logicalobjectfifo<memref<128x128xi32>>)
amdaie.controlcode {
// Outermost offsets/sizes: 0/32, then 32/64, then 96/32.
amdaie.npu.dma_cpy_nd %0([] [] [], [0, 0] [32, 64] [128, 1])
amdaie.npu.dma_cpy_nd %0([] [] [], [32, 0] [64, 64] [128, 1])
amdaie.npu.dma_cpy_nd %0([] [] [], [96, 0] [32, 64] [128, 1])
amdaie.end
}
}
return
}
}

// -----

// CHECK-LABEL: @combine_source_values
// CHECK: %[[CONNECTION:.+]] = amdaie.connection
// CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([] [] [], [0, 0, 0, 0] [2, 16, 8, 16] [32, 32, 8, 1])
Expand Down Expand Up @@ -332,6 +354,28 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}

// -----

// Test case: three consecutive DMA copies on the same connection whose
// non-empty access patterns (the first operand group; per the test name this
// is the target side) have the same number of dimensions, identical strides
// [128, 1] and identical inner size 64, but different outermost sizes
// (32, 64, 32). Each op's outermost offset equals the previous op's outermost
// offset plus its outermost size (0 + 32 = 32, 32 + 64 = 96).
// The FileCheck directives below expect exactly one remaining copy with
// offsets [0, 0], sizes [128, 64] (32 + 64 + 32 in the outermost dimension)
// and strides [128, 1], and no further copy ops after it.
// CHECK-LABEL: @combine_target_same_dims_diff_sizes
// CHECK: %[[CONNECTION:.+]] = amdaie.connection
// CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([0, 0] [128, 64] [128, 1], [] [] [])
// CHECK-NOT: amdaie.npu.dma_cpy_nd
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @combine_target_same_dims_diff_sizes(%arg0: !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>, %arg1: !amdaie.logicalobjectfifo<memref<128x128xi32>>) {
amdaie.workgroup {
%0 = amdaie.connection(%arg0, %arg1) : (!amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>, !amdaie.logicalobjectfifo<memref<128x128xi32>>)
amdaie.controlcode {
// Outermost offsets/sizes: 0/32, then 32/64, then 96/32.
amdaie.npu.dma_cpy_nd %0([0, 0] [32, 64] [128, 1], [] [] [])
amdaie.npu.dma_cpy_nd %0([32, 0] [64, 64] [128, 1], [] [] [])
amdaie.npu.dma_cpy_nd %0([96, 0] [32, 64] [128, 1], [] [] [])
amdaie.end
}
}
return
}
}

// -----

// CHECK-LABEL: @combine_target_diff_dims
// CHECK: %[[CONNECTION:.+]] = amdaie.connection
// CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([0, 0, 0, 32] [3, 16, 8, 16] [64, 32, 8, 1], [] [] [])
Expand Down
Loading