[vulkan] Update default RDNA GPU subgroup size to 32 (iree-org#18207)

Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
nithinsubbiah authored Aug 19, 2024
1 parent 30040c7 commit 95d5562
Showing 18 changed files with 42 additions and 61 deletions.
3 changes: 1 addition & 2 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
@@ -98,8 +98,7 @@ createGPUTensorAlloc(GPUPromoteSharedMemPattern promoteSharedMemPattern =
 
 // Distributes vector ops to all threads/warps in a GPU workgroup.
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createConvertVectorReductionToGPUPass(bool expandSubgroupReduction = true,
-                                      bool pickLargestSubroupSize = false);
+createConvertVectorReductionToGPUPass(bool expandSubgroupReduction = true);
 
 enum class ReorderWorkgroupsStrategy { None, Swizzle, Transpose };
 
@@ -193,10 +193,8 @@ static Value simpleWarpShuffleFunction(Location loc, OpBuilder &builder,
 
 struct VectorReductionToGPUPass final
     : impl::VectorReductionToGPUPassBase<VectorReductionToGPUPass> {
-  VectorReductionToGPUPass(bool expandSubgroupReduction,
-                           bool pickLargestSubroupSize)
-      : expandSubgroupReduction(expandSubgroupReduction),
-        pickLargestSubroupSize(pickLargestSubroupSize) {}
+  VectorReductionToGPUPass(bool expandSubgroupReduction)
+      : expandSubgroupReduction(expandSubgroupReduction) {}
 
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
@@ -258,8 +256,7 @@ struct VectorReductionToGPUPass final
     // 4. Distribute transfer write operations and propagate vector
     // distribution.
    {
-      std::optional<int> subgroupSize =
-          getGPUSubgroupSize(funcOp, pickLargestSubroupSize);
+      std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
      if (!subgroupSize) {
        funcOp->emitOpError("missing subgroup size");
        return signalPassFailure();
@@ -316,16 +313,13 @@ struct VectorReductionToGPUPass final
 
 private:
  bool expandSubgroupReduction;
- bool pickLargestSubroupSize;
};
 
} // namespace
 
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createConvertVectorReductionToGPUPass(bool expandSubgroupReduction,
-                                      bool pickLargestSubroupSize) {
-  return std::make_unique<VectorReductionToGPUPass>(expandSubgroupReduction,
-                                                    pickLargestSubroupSize);
+createConvertVectorReductionToGPUPass(bool expandSubgroupReduction) {
+  return std::make_unique<VectorReductionToGPUPass>(expandSubgroupReduction);
}
 
} // namespace mlir::iree_compiler
@@ -374,17 +374,14 @@ def IREEGPU_TargetAttr : AttrDef<IREEGPU_Dialect, "Target"> {
      return *llvm::max_element(getWgp().getSubgroupSizeChoices().asArrayRef());
    }
    // Returns the preferred subgroup size. If the target supports multiple
-    // subgroup sizes, pickLargest controls whether to return the largest one.
+    // subgroup sizes, pick the smallest one.
    //
    // AMD RDNA GPUs supports multiple subgroup sizes and the preferred one
    // differ given the API--HIP prefers 32 while Vulkan prefers 64.
-    // TODO: We should be able to force Vulkan side to use 32 consistently
-    // too with subgroup size control; it might have perf implications though.
-    int getPreferredSubgroupSize(bool pickLargest=false) const {
-      if (pickLargest) {
-        return getMaxSubgroupSize();
-      }
-      return getMinSubgroupSize();
+    // We force Vulkan side to use 32 to be consistent with the HIP backend;
+    // might have implications on perf.
+    int getPreferredSubgroupSize() const {
+      return *llvm::min_element(getWgp().getSubgroupSizeChoices().asArrayRef());
    }
 
    // Hardware feature related APIs
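The net effect of this TargetAttr change is easiest to see in isolation: with RDNA-style subgroup size choices {32, 64}, the accessor now always yields 32, where the old pickLargest=true callers got 64. A minimal standalone C++ sketch of that selection (not the IREE API; the function name and values here are illustrative stand-ins):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    // Sketch of the new behavior: always take the smallest supported
    // subgroup size, instead of branching on a pickLargest flag.
    int preferredSubgroupSize(const std::vector<int> &choices) {
      assert(!choices.empty() && "target must list at least one subgroup size");
      return *std::min_element(choices.begin(), choices.end());
    }

    int main() {
      const std::vector<int> rdnaChoices = {32, 64}; // RDNA: wave32 and wave64.
      assert(preferredSubgroupSize(rdnaChoices) == 32); // Old largest-pick: 64.
      return 0;
    }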
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -889,7 +889,7 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
 
  // vector -> simt gpu + vector
  funcPassManager.addPass(createConvertVectorReductionToGPUPass(
-      /*expandSubgroupReduction=*/true, /*pickLargestSubgroupSize=*/false));
+      /*expandSubgroupReduction=*/true));
  funcPassManager.addPass(createCanonicalizerPass());
  funcPassManager.addPass(createCSEPass());
}
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
@@ -37,7 +37,7 @@ static LogicalResult setAMDMatmulConfig(linalg::LinalgOp op,
          AMDCoopMatrixSoftwarePipelineStoreStage)))
    return success();
 
-  int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+  int subgroupSize = target.getPreferredSubgroupSize();
  const std::array<int64_t, 2> workgroupXY = {subgroupSize / 2, 8};
  std::array<int64_t, 3> threadMNK;
  auto inputType =
@@ -67,7 +67,7 @@ static LogicalResult setAMDMatmulConfig(linalg::LinalgOp op,
 
 LogicalResult setAMDCodeGenConfig(IREE::GPU::TargetAttr target,
                                   Operation *rootOp) {
-  int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+  int subgroupSize = target.getPreferredSubgroupSize();
 
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp)) {
    if (isMatmulOrBatchMatmul(linalgOp))
10 changes: 5 additions & 5 deletions compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -702,7 +702,7 @@ LogicalResult setMatmulOpConfig(IREE::GPU::TargetAttr target,
    llvm::dbgs() << ")\n";
  });
 
-  int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+  int subgroupSize = target.getPreferredSubgroupSize();
  const int maxBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
 
  // We want a 2-stage pipeline without multi-buffering if the depth is 0 to
@@ -908,7 +908,7 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op,
 
  // AMD RDNA architectures supports both wave32 and wave64 modes. Prefer to use
  // wave32 mode for better performance.
-  int64_t subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/false);
+  int64_t subgroupSize = target.getPreferredSubgroupSize();
 
  // Infer if lhs or rhs is transposed to help generate better schedule.
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
@@ -999,7 +999,7 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op,
 static LogicalResult setFftOpConfig(IREE::GPU::TargetAttr target,
                                     IREE::LinalgExt::FftOp op) {
  LLVM_DEBUG(llvm::dbgs() << "trying to deduce config as fft...\n");
-  int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+  int subgroupSize = target.getPreferredSubgroupSize();
  auto pipeline = CodeGenPipeline::SPIRVBaseDistribute;
 
  std::array<int64_t, 3> workgroupSize = {subgroupSize, 1, 1};
@@ -1121,7 +1121,7 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target,
  if (!foundSingleReductionOutput)
    return failure();
 
-  int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+  int subgroupSize = target.getPreferredSubgroupSize();
 
  // Tile all the parallel dimension to 1.
  SmallVector<unsigned> partitionedLoops =
@@ -1281,7 +1281,7 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target,
        funcOp, op, TileSizesListType{}, pipeline, workgroupSize);
  }
 
-  int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+  int subgroupSize = target.getPreferredSubgroupSize();
  const unsigned loopDepth = partitionedLoops.back() + 1;
 
  // Configurations we need to decide.
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -594,7 +594,7 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &funcPassManager) {
 
  // Handle vector reduction operations specifically.
  funcPassManager.addPass(createConvertVectorReductionToGPUPass(
-      /*expandSubgroupReduction=*/false, /*pickLargestSubgroupSize=*/true));
+      /*expandSubgroupReduction=*/false));
  // Perform normal vector unrolling and lowering transformations. This breaks
  // vectors down to native machine size.
  addSPIRVVectorLoweringPasses(funcPassManager);
@@ -199,10 +199,7 @@ spirv::ResourceLimitsAttr convertLimits(IREE::GPU::TargetAttr target) {
        spirv::ScopeAttr::get(context, spirv::Scope::Subgroup)));
  }
 
-  // This is mostly to match RDNA behavior on Vulkan--RDNA supports either 32 or
-  // 64 as subgroup sizes; the default subgroup size is 64.
-  const int preferredSubgroupSize =
-      target.getPreferredSubgroupSize(/*pickLargest=*/true);
+  const int preferredSubgroupSize = target.getPreferredSubgroupSize();
 
  return spirv::ResourceLimitsAttr::get(
      context, wgp.getMaxWorkgroupMemoryBytes(),
@@ -185,8 +185,7 @@ void SPIRVTileAndPromotePass::runOnOperation() {
 
  SmallVector<int64_t> &workgroupSize = maybeWorkgroupSize.value();
  int64_t totalThreads = workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
-  std::optional<int> subgroupSize =
-      getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
+  std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
  if (!subgroupSize) {
    funcOp.emitError("failed to query subgroup size");
    return signalPassFailure();
@@ -366,8 +366,7 @@ class SPIRVTileToCooperativeOpsPass final
    // Then tile and distribute to subgroups.
 
    {
-      std::optional<int> subgroupSize =
-          getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
+      std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
      if (!subgroupSize) {
        funcOp.emitError("failed to query subgroup size");
        return signalPassFailure();
6 changes: 2 additions & 4 deletions compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp
@@ -56,8 +56,7 @@ LogicalResult verifySPIRVMatmulPromoteVectorizePassPipeline(
  LLVM_DEBUG(llvm::dbgs() << "target: " << target << "\n");
 
  auto funcOp = op->getParentOfType<mlir::FunctionOpInterface>();
-  std::optional<int> subgroupSize =
-      getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
+  std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
  if (!subgroupSize)
    return funcOp->emitError("failed to query subgroup size");
  const int maxThreads = target.getWgp().getMaxThreadCountPerWorkgroup();
@@ -169,8 +168,7 @@ LogicalResult verifySPIRVCooperativeMatrixVectorizePassPipeline(
  LLVM_DEBUG(llvm::dbgs() << "target: " << target << "\n");
 
  auto funcOp = op->getParentOfType<mlir::FunctionOpInterface>();
-  std::optional<int> subgroupSize =
-      getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
+  std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
  if (!subgroupSize)
    return funcOp->emitError("failed to query subgroup size");
  const int maxThreads = target.getWgp().getMaxThreadCountPerWorkgroup();
@@ -31,8 +31,8 @@ func.func @nhwc_conv_pointwise_2x64x64x320() {
  return
}
 
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 64, 64], [1, 1, 8, 8], [0, 0, 0, 0, 1, 1, 8], [0, 1, 0, 0]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseVectorize workgroup_size = [8, 8, 1]>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 4, 4, 64], [1, 2, 2, 8], [0, 0, 0, 0, 1, 1, 8], [0, 1, 0, 0]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseVectorize workgroup_size = [8, 2, 2]>
 // CHECK: func.func @nhwc_conv_pointwise_2x64x64x320()
 // CHECK-SAME: translation_info = #[[TRANSLATION]]
 // CHECK: linalg.conv_2d_nhwc_hwcf
@@ -20,8 +20,8 @@ func.func @batch_matmul_f32_16x4096x40x4096() {
  return
}
 
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 512, 8, 16]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [2, 64, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 256, 8, 32]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [2, 32, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
 // CHECK: func.func @batch_matmul_f32_16x4096x40x4096()
 // CHECK-SAME: translation_info = #[[TRANSLATION]]
 // CHECK: linalg.batch_matmul
@@ -53,7 +53,7 @@ func.func @matmul_f16_64x640x320() {
}
 
 // CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 128, 32]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 16, 1], {pipeline_depth = 2 : i64, store_stage = 0 : i64}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 2 : i64, store_stage = 0 : i64}>
 // CHECK: func.func @matmul_f16_64x640x320()
 // CHECK-SAME: translation_info = #[[TRANSLATION]]
 // CHECK: linalg.matmul
@@ -82,8 +82,8 @@ func.func @batch_matmul_f32_16x4096x40x4096() {
  return
}
 
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 256, 16, 32]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [4, 32, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 128, 16, 32]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [4, 16, 1], {pipeline_depth = 2 : i64, store_stage = 0 : i64}>
 // CHECK: func.func @batch_matmul_f32_16x4096x40x4096()
 // CHECK-SAME: translation_info = #[[TRANSLATION]]
 // CHECK: linalg.batch_matmul
@@ -120,8 +120,8 @@ func.func @batch_matmul_f16_1x4096x4096x512() {
  return
}
 
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 256, 32]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [32, 8, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 128, 32]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
 // CHECK: func.func @batch_matmul_f16_1x4096x4096x512()
 // CHECK-SAME: translation_info = #[[TRANSLATION]]
 // CHECK: linalg.batch_matmul
@@ -184,8 +184,8 @@ func.func @matmul_multi_reduce_i4xf32xf32() {
  return
}
 
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 128, 1, 16]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [32, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 1, 16]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
 // CHECK: func.func @matmul_multi_reduce_i4xf32xf32()
 // CHECK-SAME: translation_info = #[[TRANSLATION]]
 // CHECK: linalg.generic
@@ -190,6 +190,6 @@ func.func @matmul_256x1024x8() {
  return
}
 
-// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [32, 8, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
 // CHECK-LABEL: func.func @matmul_256x1024x8
 // CHECK-SAME: translation_info = #[[$TRANSLATION]]
@@ -31,7 +31,7 @@ hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-s
 // CHECK-SAME: AMD,
 // CHECK-SAME: #spirv.resource_limits<max_compute_shared_memory_size = 65536,
 // CHECK-SAME: max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024 : i32, 1024 : i32, 1024 : i32],
-// CHECK-SAME: subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64,
+// CHECK-SAME: min_subgroup_size = 32, max_subgroup_size = 64,
 // CHECK-SAME: cooperative_matrix_properties_khr = [
 // CHECK-SAME: #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>,
 // CHECK-SAME: #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>
5 changes: 2 additions & 3 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -998,15 +998,14 @@ IREE::GPU::TargetAttr getGPUTargetAttr(Operation *op) {
  return getCLGPUTarget(op->getContext());
}
 
-std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func,
-                                      bool pickLargest) {
+std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func) {
  // First try to see if there is a subgroup size chosen in the CodeGen pipeline
  // configuration.
  if (std::optional<int64_t> subgroupSize = getSubgroupSize(func))
    return subgroupSize.value();
  // Then try to find the subgroup size from the target description.
  if (IREE::GPU::TargetAttr target = getGPUTargetAttr(func))
-    return target.getPreferredSubgroupSize(pickLargest);
+    return target.getPreferredSubgroupSize();
  return std::nullopt;
}
 
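For context on the call sites above, the lookup order in getGPUSubgroupSize is unchanged by this commit; only the fallback lost its flag. A hedged C++ sketch of the two-step resolution, using hypothetical stand-in types (PipelineConfig and TargetDesc are not IREE types):

    #include <optional>

    // Hypothetical stand-ins for IREE's translation info and target attribute.
    struct PipelineConfig {
      std::optional<int> subgroupSize; // explicit per-function choice, if any
    };
    struct TargetDesc {
      int preferredSubgroupSize; // after this commit: smallest supported size
    };

    // Sketch of the resolution order: an explicit pipeline choice still wins;
    // otherwise fall back to the target's preferred size, which no longer
    // depends on a pickLargest argument.
    std::optional<int> querySubgroupSize(const PipelineConfig *config,
                                         const TargetDesc *target) {
      if (config && config->subgroupSize)
        return config->subgroupSize;
      if (target)
        return target->preferredSubgroupSize;
      return std::nullopt;
    }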
3 changes: 1 addition & 2 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -158,8 +158,7 @@ IREE::GPU::TargetAttr getGPUTargetAttr(Operation *op);
 /// Returns the GPU subgroup size chosen for the current CodeGen pipeline if
 /// exists; otherwise returns the subgroup size from the GPU target description.
 /// Returns std::nullopt if none found.
-std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func,
-                                      bool pickLargest);
+std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func);
 
 } // namespace mlir::iree_compiler
 
2 changes: 1 addition & 1 deletion tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -388,7 +388,7 @@ def get_test_compilation_infos(
    elif compilation_info_id == CompilationInfoId.SPIRVCooperativeMatrixVectorize:
        tile_workgroup_size_pairs = [
            TileWorkgroupSizePair(
-                [[64, 64], [16, 64], [0, 0, 16], [16, 16, 16]], [64, 4, 1]
+                [[64, 128], [32, 64], [0, 0, 32], [16, 16, 16]], [64, 2, 1]
            )
        ]
    elif compilation_info_id == CompilationInfoId.SPIRVVectorizeNVIDIA:
