[GPUVectorDistribute] Distribute vector.step
This commit enables distribution of vector.step by recreating the
distributed constant and its dynamic offsets.

Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
manupak committed Oct 28, 2024
1 parent 1aa5825 commit aca550b
Showing 2 changed files with 329 additions and 0 deletions.
@@ -678,6 +678,196 @@ struct DistributeBatchOuterToLayoutConversions final
}
};

struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
using OpDistributionPattern::OpDistributionPattern;

// This is a helper aggregate to hold the information about a dimension.
// E.g. a 3x4x2 shape will have
// lengths = [3, 4, 2] and strides = [8, 2, 1].
struct DimInfo {
std::optional<Value> dimIdx;
int64_t dimLen;
int64_t dimStride;
};

// This is a helper function to extract the remaining dimensions with their
// original strides once the distributed dimensions are taken out.
// E.g. for a 3 x 4 x 2 shape whose middle dimension (length 4) is
// distributed across threads, this will return the remaining dimensions
// that have lengths = [3, 2] and strides = [8, 1].
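// Walking through that example: originalLen = 24 and the distributed
// dimension has dimLen = 4, dimStride = 2, so the loop first emits
// {len = 24 / (4 * 2) = 3, stride = 4 * 2 = 8} and then the trailing
// entry {len = 2, stride = 1}.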
SmallVector<DimInfo> getRemainingDims(ArrayRef<DimInfo> distributedStrides,
int64_t originalLen) const {
SmallVector<DimInfo> remainingDims;
int64_t currLen = originalLen;
for (const DimInfo &dInfo : distributedStrides) {
if (dInfo.dimStride != 0) {
int64_t dStride = dInfo.dimStride;
int64_t dLen = dInfo.dimLen;
int64_t higherStride = dLen * dStride;
if (higherStride < currLen) {
remainingDims.push_back(
{std::nullopt, currLen / higherStride, higherStride});
}
currLen = dStride;
}
}
remainingDims.push_back({std::nullopt, currLen, 1});
return remainingDims;
}

// This is a helper to extract lengths of all dimensions
SmallVector<int64_t> getLens(ArrayRef<DimInfo> dimInfos) const {
SmallVector<int64_t> lens;
for (const DimInfo &dInfo : dimInfos) {
lens.push_back(dInfo.dimLen);
}
return lens;
}

// Once the distributed dimensions are removed, the strides of the
// remaining dimensions are no longer packed. This is a helper to
// obtain the packed strides of the remaining dimensions.
// (See getRemainingDims above for an example of remaining dimensions.)
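// E.g. remaining dimensions with lengths = [3, 2] cover 6 elements in
// total, so their packed strides are [2, 1].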
SmallVector<int64_t> getPackedStrides(ArrayRef<DimInfo> dims) const {
SmallVector<int64_t> lens = getLens(dims);
int64_t elementCount = ShapedType::getNumElements(lens);
SmallVector<int64_t> packedStrides;
int64_t currStride = elementCount;
for (int64_t len : lens) {
currStride = currStride / len;
packedStrides.push_back(currStride);
}
return packedStrides;
}

// This function emulates the slicing of the otherwise large step constant
// across threads and subgroups.
VectorValue generateSlicedStep(OpBuilder &builder, Location loc,
ArrayRef<DimInfo> distributedDims,
int64_t distributedLen,
int64_t originalLen) const {
SmallVector<DimInfo> remainingDims =
getRemainingDims(distributedDims, originalLen);
SmallVector<int64_t> remainingPackedStrides =
getPackedStrides(remainingDims);
llvm::reverse(remainingDims);
llvm::reverse(remainingPackedStrides);

SmallVector<APInt> offsets;
offsets.reserve(distributedLen);
// As a more complex example of what the math below achieves, consider
// the shape
//         wave
//          |   threads
//          V   V
//      2 x 3 x 4 x 2 = 0 1 2 ... 47
// where vector.step : vector<48xindex> is to be distributed.
// --------------------------------------------------------
// Then the distribution should be as follows:
// wave0:
// t0: 0 1 24 25
// t1: 2 3 26 27
// t2: 4 5 28 29
// t3: 6 7 30 31
//
// wave1:
// t0: 8 9 32 33
// t1: 10 11 34 35
// t2: 12 13 36 37
// t3: 14 15 38 39
// ... etc
//
// So the constant offset we generate below initially holds the wave0 / t0
// values; thread- and subgroup-dependent offsets, each weighted by its
// stride, are then added on top.
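// With the example above, the remaining dimensions have lengths = [2, 2]
// and strides = [24, 1] (packed strides = [2, 1]), so the loop below
// produces the per-thread base offsets [0, 1, 24, 25], i.e. the
// wave0 / t0 row.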
for (size_t i = 0; i < distributedLen; i++) {
int64_t offset = 0;
for (const auto &[dimInfo, packedStride] :
zip(remainingDims, remainingPackedStrides)) {
offset += ((i / packedStride) % dimInfo.dimLen) * dimInfo.dimStride;
}
offsets.push_back(APInt(/*width=*/64, offset));
}
VectorType offsetType =
VectorType::get({distributedLen}, builder.getIndexType());
auto constOffset = builder.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(offsetType, offsets));
Value finalOffset = constOffset;
for (const DimInfo &dimInfo : distributedDims) {
assert(dimInfo.dimIdx.has_value());
if (dimInfo.dimStride != 0) {
auto strideVal =
builder.create<arith::ConstantIndexOp>(loc, dimInfo.dimStride);
auto dimIdxOffsetPerElem = builder.create<arith::MulIOp>(
loc, strideVal, dimInfo.dimIdx.value());
auto dimIdxOffset = builder.create<vector::BroadcastOp>(
loc, offsetType, dimIdxOffsetPerElem);
finalOffset =
builder.create<arith::AddIOp>(loc, finalOffset, dimIdxOffset);
}
}
return cast<VectorValue>(finalOffset);
}

DistributeStep(MLIRContext *context, Value threadId, int64_t subgroupSize)
: OpDistributionPattern(context), threadId(threadId),
subgroupSize(subgroupSize) {}
LogicalResult matchAndRewrite(vector::StepOp stepOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
Location loc = stepOp.getLoc();
VectorValue result = stepOp.getResult();
NestedLayoutAttr resultLayout =
dyn_cast<NestedLayoutAttr>(signature[result]);
if (!resultLayout) {
return rewriter.notifyMatchFailure(
stepOp, "missing nested layout for step op result");
}
SmallVector<Value> subgroupIndices, threadIndices;
populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, resultLayout,
subgroupIndices, threadIndices);
ArrayRef<int64_t> subgroupStrides = resultLayout.getSubgroupStrides();
ArrayRef<int64_t> subgroupLengths = resultLayout.getSubgroupTile();
ArrayRef<int64_t> threadStrides = resultLayout.getThreadStrides();
ArrayRef<int64_t> threadLengths = resultLayout.getThreadTile();
// Step op by definition should be single dimensional.
assert(subgroupIndices.size() == 1);
assert(threadIndices.size() == 1);
assert(subgroupLengths.size() == 1);
assert(threadLengths.size() == 1);
assert(subgroupStrides.size() == 1);
assert(threadStrides.size() == 1);
auto distributedShape = signature[result].getDistributedShape();

int64_t distributedElements = ShapedType::getNumElements(distributedShape);
int64_t originalElements = result.getType().getNumElements();
SmallVector<DimInfo, 2> distributedDims{
{subgroupIndices[0], subgroupLengths[0], subgroupStrides[0]},
{threadIndices[0], threadLengths[0], threadStrides[0]}};
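// getRemainingDims expects the distributed dims ordered from the largest
// element stride to the smallest, so sort them in descending stride order.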
sort(distributedDims, [](const DimInfo &lhs, const DimInfo &rhs) {
return lhs.dimStride > rhs.dimStride;
});
VectorValue slicedStepOp = generateSlicedStep(
rewriter, loc, distributedDims, distributedElements, originalElements);
VectorType finalSlicedStepOpType =
VectorType::get({distributedShape}, result.getType().getElementType());
auto finalSlicedStepOp = rewriter.create<vector::ShapeCastOp>(
loc, finalSlicedStepOpType, slicedStepOp);
replaceOpWithDistributedValues(rewriter, stepOp, {finalSlicedStepOp});
return success();
}

Value threadId;
int64_t subgroupSize;
};

} // namespace

void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
@@ -690,6 +880,7 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
maxBitsPerShuffle);
patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
patterns.add<DistributeStep>(patterns.getContext(), threadId, subgroupSize);
}

}; // namespace mlir::iree_compiler
@@ -0,0 +1,138 @@
// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --canonicalize --cse --mlir-print-local-scope %s | FileCheck %s

#nested = #iree_vector_ext.nested_layout<
subgroup_tile = [1],
batch_tile = [4],
outer_tile = [1],
thread_tile = [4],
element_tile = [1],

subgroup_strides = [0],
thread_strides = [16]
>

func.func @step_1() -> vector<16xindex> {
%step = vector.step : vector<16xindex>
%stepl = iree_vector_ext.to_layout %step to layout(#nested) : vector<16xindex>
return %stepl : vector<16xindex>
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @step_1
// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 16) mod 4)>()[%thread_id_x]
// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c16 : index
// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<4xindex>
// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<4xindex>
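// Each thread owns batch_tile(4) x element_tile(1) = 4 elements, so the
// thread-invariant base offsets are [0, 1, 2, 3]; the thread index
// (delinearized with stride 16) is then scaled by 16 and added.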

// -----

#nested = #iree_vector_ext.nested_layout<
subgroup_tile = [1],
batch_tile = [1],
outer_tile = [3],
thread_tile = [4],
element_tile = [2],

subgroup_strides = [0],
thread_strides = [2]
>

func.func @step_2() -> vector<24xindex> {
%step = vector.step : vector<24xindex>
%stepl = iree_vector_ext.to_layout %step to layout(#nested) : vector<24xindex>
return %stepl : vector<24xindex>
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @step_2
// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 8, 9, 16, 17]> : vector<6xindex>
// CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 2) mod 4)>()[%thread_id_x]
// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c2 : index
// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<6xindex>
// CHECK: %[[OFFSET:.+]] = arith.addi %[[TID_STRIDEV]], %[[CST]] : vector<6xindex>
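// Each thread owns outer_tile(3) x element_tile(2) = 6 elements. The 4
// threads (stride 2) interleave pairs of elements, so the thread-invariant
// base offsets are [0, 1, 8, 9, 16, 17] and thread t adds 2 * t.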

// -----

#nested = #iree_vector_ext.nested_layout<
subgroup_tile = [3],
batch_tile = [1],
outer_tile = [2],
thread_tile = [4],
element_tile = [2],

subgroup_strides = [8],
thread_strides = [2]
>

func.func @step_3() -> vector<48xindex> {
%step = vector.step : vector<48xindex>
%stepl = iree_vector_ext.to_layout %step to layout(#nested) : vector<48xindex>
return %stepl : vector<48xindex>
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @step_3
// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 24, 25]> : vector<4xindex>
// CHECK: %[[WID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 512) mod 3)>()[%thread_id_x]
// CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 2) mod 4)>()[%thread_id_x]
// CHECK: %[[WID_STRIDE:.+]] = arith.muli %[[WID]], %c8 : index
// CHECK: %[[WID_STRIDEV:.+]] = vector.broadcast %[[WID_STRIDE]] : index to vector<4xindex>
// CHECK: %[[OFFSET0:.+]] = arith.addi %[[WID_STRIDEV]], %[[CST]] : vector<4xindex>
// CHECK: %[[TID_STRIDE:.+]] = arith.muli %[[TID]], %c2 : index
// CHECK: %[[TID_STRIDEV:.+]] = vector.broadcast %[[TID_STRIDE]] : index to vector<4xindex>
// CHECK: %[[OFFSET1:.+]] = arith.addi %[[OFFSET0]], %[[TID_STRIDEV]] : vector<4xindex>
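// Each thread owns outer_tile(2) x element_tile(2) = 4 elements with base
// offsets [0, 1, 24, 25]; subgroup w then adds 8 * w and thread t adds
// 2 * t, e.g. wave1 / t0 ends up with [8, 9, 32, 33].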

// -----

#nested = #iree_vector_ext.nested_layout<
subgroup_tile = [1],
batch_tile = [1],
outer_tile = [1],
thread_tile = [16],
element_tile = [8],

subgroup_strides = [0],
thread_strides = [1]
>

func.func @step_4() -> vector<128xindex> {
%step = vector.step : vector<128xindex>
%stepl = iree_vector_ext.to_layout %step to layout(#nested) : vector<128xindex>
return %stepl : vector<128xindex>
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @step_4
// CHECK: %[[CST:.+]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112]> : vector<8xindex>
// CHECK: %[[TID:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%thread_id_x]
// CHECK: %[[TIDV:.+]] = vector.broadcast %[[TID]] : index to vector<8xindex>
// CHECK: %[[OFFSET:.+]] = arith.addi %[[TIDV]], %[[CST]] : vector<8xindex>
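// With a thread stride of 1, adjacent elements go to adjacent threads:
// thread t owns {t, t + 16, ..., t + 112}. The thread-invariant base
// offsets are [0, 16, ..., 112] and the thread index is broadcast and
// added directly (the multiply by the unit stride is canonicalized away).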
