Skip to content

Commit

Permalink
[spirv] Allow dynamic parallel dims in subgroup reduction pipeline (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
antiagainst authored Aug 19, 2023
1 parent 839375f commit f50d0d9
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 9 deletions.
24 changes: 15 additions & 9 deletions compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1198,12 +1198,19 @@ static LogicalResult setReductionConfig(const spirv::TargetEnv &targetEnv,

auto funcOp = op->getParentOfType<FunctionOpInterface>();
auto walkResult = funcOp.walk([](linalg::LinalgOp op) {
if (op.hasDynamicShape())
return WalkResult::interrupt();
using utils::IteratorType;
SmallVector<IteratorType, 4> kinds = op.getIteratorTypesArray();
SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
for (auto [kind, bound] : llvm::zip_equal(kinds, bounds)) {
if (kind == IteratorType::reduction && ShapedType::isDynamic(bound))
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
if (walkResult.wasInterrupted()) {
LLVM_DEBUG(llvm::dbgs() << "failed: dynamic shapes in reduction dims\n");
return failure();
}

// This pipeline eventually generates non-uniform group shuffle ops, which
// requires special capability.
Expand Down Expand Up @@ -1278,8 +1285,8 @@ static LogicalResult setReductionConfig(const spirv::TargetEnv &targetEnv,
targetEnv.getResourceLimits().getMaxComputeWorkgroupInvocations();
int64_t groupSize = dimSize / vectorSize;
if (groupSize > maxWorkgroupSize) {
groupSize = GreatestCommonDivisor({64, uint64_t(groupSize)},
{64, uint64_t(maxWorkgroupSize)})
groupSize = GreatestCommonDivisor(APInt(64, uint64_t(groupSize)),
APInt(64, uint64_t(maxWorkgroupSize)))
.getZExtValue();
}
// Current warp reduction pattern is a two step butterfly warp reduce.
Expand Down Expand Up @@ -1309,8 +1316,8 @@ static LogicalResult setReductionConfig(const spirv::TargetEnv &targetEnv,
int64_t bound = bounds[dim];
if (i == reductionDims.size() - 1)
bound /= vectorSize;
APInt size = GreatestCommonDivisor({64, uint64_t(remaingGroupSize)},
{64, uint64_t(bound)});
APInt size = GreatestCommonDivisor(APInt(64, uint64_t(remaingGroupSize)),
APInt(64, uint64_t(bound)));
reductionTileSizes[dim] = size.getSExtValue();
if (i == reductionDims.size() - 1)
reductionTileSizes[dim] *= vectorSize;
Expand Down Expand Up @@ -1469,11 +1476,10 @@ static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits,
// If there are more than 3 parallel dim try to tile the extra higher level
// dimensions to 1 for extra dimensions.
if (isa<linalg::GenericOp>(linalgOp.getOperation())) {
SmallVector<int64_t> ranges = linalgOp.getStaticLoopRanges();
for (int64_t i = 0, e = workgroupTileSizes.size(); i < e; i++) {
if (workgroupTileSizes[i] != 0)
break;
if (ranges[i] != 1)
if (loopBounds[i] != 1)
workgroupTileSizes[i] = 1;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,111 @@ hal.executable @i4_dequant_matvec {
// CHECK: func.func @i4_dequant_matvec()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>,
#hal.descriptor_set.binding<3, storage_buffer>,
#hal.descriptor_set.binding<4, storage_buffer>
]>
]>

hal.executable @i4_dequant_matvec {
hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
max_compute_shared_memory_size = 32768,
max_compute_workgroup_invocations = 1024,
max_compute_workgroup_size = [1024, 1024, 1024],
subgroup_size = 64>>
}> {
hal.executable.export @i4_dequant_matvec layout(#pipeline_layout)
builtin.module {
func.func @i4_dequant_matvec() {
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = hal.interface.constant.load[5] : i32
%6 = hal.interface.constant.load[6] : i32
%7 = hal.interface.constant.load[7] : i32
%8 = hal.interface.constant.load[8] : i32
%9 = arith.index_castui %0 : i32 to index
%10 = arith.index_castui %1 : i32 to index
%11 = arith.index_castui %2 : i32 to index
%12 = arith.extui %3 : i32 to i64
%13 = arith.extui %4 : i32 to i64
%14 = arith.shli %13, %c32_i64 : i64
%15 = arith.ori %12, %14 : i64
%16 = arith.index_castui %15 : i64 to index
%17 = arith.extui %5 : i32 to i64
%18 = arith.extui %6 : i32 to i64
%19 = arith.shli %18, %c32_i64 : i64
%20 = arith.ori %17, %19 : i64
%21 = arith.index_castui %20 : i64 to index
%22 = arith.extui %7 : i32 to i64
%23 = arith.extui %8 : i32 to i64
%24 = arith.shli %23, %c32_i64 : i64
%25 = arith.ori %22, %24 : i64
%26 = arith.index_castui %25 : i64 to index
%27 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xi4>>
%28 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>>
%29 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>>
%30 = flow.dispatch.workload.ordinal %26, 0 : index
%31 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%16) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x86x128xf32>>{%30}
%32 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%21) : !flow.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%30}
%33 = flow.dispatch.tensor.load %27, offsets = [0, 0, 0], sizes = [4096, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86x128xi4>> -> tensor<4096x86x128xi4>
%34 = flow.dispatch.tensor.load %28, offsets = [0, 0], sizes = [4096, 86], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>> -> tensor<4096x86xf32>
%35 = flow.dispatch.tensor.load %29, offsets = [0, 0], sizes = [4096, 86], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>> -> tensor<4096x86xf32>
%36 = flow.dispatch.tensor.load %31, offsets = [0, 0, 0], sizes = [%30, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x86x128xf32>>{%30} -> tensor<?x86x128xf32>
%37 = tensor.empty(%30) : tensor<?x4096xf32>
%38 = tensor.empty() : tensor<4096x86x128xf32>
%39 = linalg.fill ins(%cst : f32) outs(%37 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
%40 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%33, %34, %35 : tensor<4096x86x128xi4>, tensor<4096x86xf32>, tensor<4096x86xf32>) outs(%38 : tensor<4096x86x128xf32>) {
^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32):
%42 = arith.extui %in : i4 to i32
%43 = arith.uitofp %42 : i32 to f32
%44 = arith.subf %43, %in_1 : f32
%45 = arith.mulf %44, %in_0 : f32
linalg.yield %45 : f32
} -> tensor<4096x86x128xf32>
%41 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
ins(%36, %40 : tensor<?x86x128xf32>, tensor<4096x86x128xf32>) outs(%39 : tensor<?x4096xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%42 = arith.mulf %in, %in_0 : f32
%43 = arith.addf %42, %out : f32
linalg.yield %43 : f32
} -> tensor<?x4096xf32>
flow.dispatch.tensor.store %41, %32, offsets = [0, 0], sizes = [%30, 4096], strides = [1, 1] : tensor<?x4096xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%30}
return
}
}
}
}

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 2, 128]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce>
// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index]
// CHECK: func.func @i4_dequant_matvec()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]

0 comments on commit f50d0d9

Please sign in to comment.