Skip to content

Commit

Permalink
[spirv] Drop leading unit dim in f16 fused dequant matvec dispatch (#…
Browse files Browse the repository at this point in the history
…14752)

For f16 dequant+matvec, we may see each thread handling 8xf16
and have loop carried values with shape like vector<1x8xf16>.
Run ForOpCanonicalizationPass before breaking down large vectors
to make sure we actually drop the unit dim there to make
the following steps work.

Fixes #14740
  • Loading branch information
antiagainst authored Aug 19, 2023
1 parent f50d0d9 commit 5bfc37a
Show file tree
Hide file tree
Showing 6 changed files with 314 additions and 14 deletions.
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ static void addMemRefLoweringPasses(OpPassManager &pm) {
// Turn scalar load/store from memrefs into vectorized ones if possible. This
// gives better memory access patterns, which is very important for perf.
pm.addPass(createSPIRVVectorizeLoadStore());
// Perform optimizations that need to across the scf.for region boundary.
pm.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ iree_lit_test_suite(
"pipeline_matmul_cooperative_ops.mlir",
"pipeline_matmul_promotion.mlir",
"pipeline_matmul_vectorization.mlir",
"pipeline_matvec.mlir",
"pipeline_reduction_subgroup.mlir",
"pipeline_sub_byte_dequant.mlir",
"set_transform_strategy.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ iree_lit_test_suite(
"pipeline_matmul_cooperative_ops.mlir"
"pipeline_matmul_promotion.mlir"
"pipeline_matmul_vectorization.mlir"
"pipeline_matvec.mlir"
"pipeline_reduction_subgroup.mlir"
"pipeline_sub_byte_dequant.mlir"
"set_transform_strategy.mlir"
Expand Down
214 changes: 204 additions & 10 deletions compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
]>
]>

hal.executable @i4_dequant_matvec {
hal.executable @i4_dequant_matvec_f32 {
hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
max_compute_shared_memory_size = 32768,
max_compute_workgroup_invocations = 512,
max_compute_workgroup_size = [512, 512, 512],
subgroup_size = 64>>
}> {
hal.executable.export @i4_dequant_matvec layout(#pipeline_layout)
hal.executable.export @i4_dequant_matvec_f32 layout(#pipeline_layout)
builtin.module {
func.func @i4_dequant_matvec() {
func.func @i4_dequant_matvec_f32() {
%cst = arith.constant 0.000000e+00 : f32
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>>
Expand Down Expand Up @@ -57,10 +57,10 @@ hal.executable @i4_dequant_matvec {

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1], [0, 2, 128]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce>
// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec
// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec_f32
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index]
// CHECK: func.func @i4_dequant_matvec()
// CHECK: func.func @i4_dequant_matvec_f32()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]

Expand All @@ -76,17 +76,17 @@ hal.executable @i4_dequant_matvec {
]>
]>

hal.executable @i4_dequant_matvec {
hal.executable @i4_dequant_matvec_f32 {
hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
max_compute_shared_memory_size = 32768,
max_compute_workgroup_invocations = 1024,
max_compute_workgroup_size = [1024, 1024, 1024],
subgroup_size = 64>>
}> {
hal.executable.export @i4_dequant_matvec layout(#pipeline_layout)
hal.executable.export @i4_dequant_matvec_f32 layout(#pipeline_layout)
builtin.module {
func.func @i4_dequant_matvec() {
func.func @i4_dequant_matvec_f32() {
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f32
%c4294967296_i64 = arith.constant 4294967296 : i64
Expand Down Expand Up @@ -138,10 +138,204 @@ hal.executable @i4_dequant_matvec {

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1, 1], [0, 0, 0, 32, 128]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce>
// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec
// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec_f32
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [1024 : index, 1 : index, 1 : index]
// CHECK: func.func @i4_dequant_matvec()
// CHECK: func.func @i4_dequant_matvec_f32()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>,
#hal.descriptor_set.binding<3, storage_buffer>,
#hal.descriptor_set.binding<4, storage_buffer>
]>
]>

hal.executable @i4_dequant_matvec_f32 {
hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
max_compute_shared_memory_size = 32768,
max_compute_workgroup_invocations = 1024,
max_compute_workgroup_size = [1024, 1024, 1024],
subgroup_size = 64>>
}> {
hal.executable.export @i4_dequant_matvec_f32 layout(#pipeline_layout)
builtin.module {
func.func @i4_dequant_matvec_f32() {
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = hal.interface.constant.load[5] : i32
%6 = hal.interface.constant.load[6] : i32
%7 = hal.interface.constant.load[7] : i32
%8 = hal.interface.constant.load[8] : i32
%9 = arith.index_castui %0 : i32 to index
%10 = arith.index_castui %1 : i32 to index
%11 = arith.index_castui %2 : i32 to index
%12 = arith.extui %3 : i32 to i64
%13 = arith.extui %4 : i32 to i64
%14 = arith.shli %13, %c32_i64 : i64
%15 = arith.ori %12, %14 : i64
%16 = arith.index_castui %15 : i64 to index
%17 = arith.extui %5 : i32 to i64
%18 = arith.extui %6 : i32 to i64
%19 = arith.shli %18, %c32_i64 : i64
%20 = arith.ori %17, %19 : i64
%21 = arith.index_castui %20 : i64 to index
%22 = arith.extui %7 : i32 to i64
%23 = arith.extui %8 : i32 to i64
%24 = arith.shli %23, %c32_i64 : i64
%25 = arith.ori %22, %24 : i64
%26 = arith.index_castui %25 : i64 to index
%27 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xi4>>
%28 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>>
%29 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>>
%30 = flow.dispatch.workload.ordinal %26, 0 : index
%31 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%16) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x86x128xf32>>{%30}
%32 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%21) : !flow.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%30}
%33 = flow.dispatch.tensor.load %27, offsets = [0, 0, 0], sizes = [4096, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86x128xi4>> -> tensor<4096x86x128xi4>
%34 = flow.dispatch.tensor.load %28, offsets = [0, 0], sizes = [4096, 86], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>> -> tensor<4096x86xf32>
%35 = flow.dispatch.tensor.load %29, offsets = [0, 0], sizes = [4096, 86], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>> -> tensor<4096x86xf32>
%36 = flow.dispatch.tensor.load %31, offsets = [0, 0, 0], sizes = [%30, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x86x128xf32>>{%30} -> tensor<?x86x128xf32>
%37 = tensor.empty(%30) : tensor<?x4096xf32>
%38 = tensor.empty() : tensor<4096x86x128xf32>
%39 = linalg.fill ins(%cst : f32) outs(%37 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
%40 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%33, %34, %35 : tensor<4096x86x128xi4>, tensor<4096x86xf32>, tensor<4096x86xf32>) outs(%38 : tensor<4096x86x128xf32>) {
^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32):
%42 = arith.extui %in : i4 to i32
%43 = arith.uitofp %42 : i32 to f32
%44 = arith.subf %43, %in_1 : f32
%45 = arith.mulf %44, %in_0 : f32
linalg.yield %45 : f32
} -> tensor<4096x86x128xf32>
%41 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
ins(%36, %40 : tensor<?x86x128xf32>, tensor<4096x86x128xf32>) outs(%39 : tensor<?x4096xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%42 = arith.mulf %in, %in_0 : f32
%43 = arith.addf %42, %out : f32
linalg.yield %43 : f32
} -> tensor<?x4096xf32>
flow.dispatch.tensor.store %41, %32, offsets = [0, 0], sizes = [%30, 4096], strides = [1, 1] : tensor<?x4096xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%30}
return
}
}
}
}

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 2, 128]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce>
// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec_f32
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index]
// CHECK: func.func @i4_dequant_matvec_f32()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>,
#hal.descriptor_set.binding<3, storage_buffer>,
#hal.descriptor_set.binding<4, storage_buffer>
]>
]>
hal.executable @i4_dequant_matvec_f16 {
hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.4, [Shader, Float16, StorageBuffer16BitAccess, GroupNonUniform, GroupNonUniformShuffle], [SPV_KHR_16bit_storage]>,
Unknown:IntegratedGPU,
#spirv.resource_limits<
max_compute_shared_memory_size = 32768,
max_compute_workgroup_invocations = 1024,
max_compute_workgroup_size = [1024, 1024, 64],
subgroup_size = 32>>
}> {
hal.executable.export @i4_dequant_matvec_f16 layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @i4_dequant_matvec_f16() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xi4>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x1xf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x86x1xf16>>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x1x86x128xf16>>
%4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x1x4096xf16>>
%5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4096, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86x128xi4>> -> tensor<4096x86x128xi4>
%6 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [4096, 86, 1], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86x1xf16>> -> tensor<4096x86x1xf16>
%7 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [4096, 86, 1], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x86x1xf16>> -> tensor<4096x86x1xf16>
%8 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0], sizes = [1, 1, 86, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x86x128xf16>> -> tensor<1x1x86x128xf16>
%9 = tensor.empty() : tensor<1x1x4096xf16>
%10 = tensor.empty() : tensor<4096x86x128xf16>
%11 = linalg.fill ins(%cst : f16) outs(%9 : tensor<1x1x4096xf16>) -> tensor<1x1x4096xf16>
%12 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%5, %6, %7 : tensor<4096x86x128xi4>, tensor<4096x86x1xf16>, tensor<4096x86x1xf16>) outs(%10 : tensor<4096x86x128xf16>) {
^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
%14 = arith.extui %in : i4 to i32
%15 = arith.uitofp %14 : i32 to f16
%16 = arith.subf %15, %in_1 : f16
%17 = arith.mulf %16, %in_0 : f16
linalg.yield %17 : f16
} -> tensor<4096x86x128xf16>
%13 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
ins(%8, %12 : tensor<1x1x86x128xf16>, tensor<4096x86x128xf16>) outs(%11 : tensor<1x1x4096xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%14 = arith.mulf %in, %in_0 : f16
%15 = arith.addf %14, %out : f16
linalg.yield %15 : f16
} -> tensor<1x1x4096xf16>
flow.dispatch.tensor.store %13, %4, offsets = [0, 0, 0], sizes = [1, 1, 4096], strides = [1, 1, 1] : tensor<1x1x4096xf16> -> !flow.dispatch.tensor<writeonly:tensor<1x1x4096xf16>>
return
}
}
}
}

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1, 1], [0, 0, 0, 2, 128]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce>
// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec_f16
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [32 : index, 1 : index, 1 : index]
// CHECK: func.func @i4_dequant_matvec_f16()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@
#hal.descriptor_set.binding<4, storage_buffer>
]>
]>
hal.executable @i4_dequant_matvec {
hal.executable @i4_dequant_matvec_f32 {
hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniform, GroupNonUniformShuffle], []>, Unknown:IntegratedGPU, #spirv.resource_limits<
max_compute_shared_memory_size = 32768,
max_compute_workgroup_invocations = 512,
max_compute_workgroup_size = [512, 512, 512],
subgroup_size = 64>>
}> {
hal.executable.export @i4_dequant_matvec layout(#pipeline_layout) {
hal.executable.export @i4_dequant_matvec_f32 layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @i4_dequant_matvec() {
func.func @i4_dequant_matvec_f32() {
%cst = arith.constant 0.000000e+00 : f32
%10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<4096x86x128xi4>>
%11 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor<readonly:tensor<4096x86xf32>>
Expand Down Expand Up @@ -58,7 +58,7 @@ hal.executable @i4_dequant_matvec {
}
}

// CHECK-LABEL: func.func @i4_dequant_matvec()
// CHECK-LABEL: func.func @i4_dequant_matvec_f32()

// CHECK: %[[FOR:.+]] = scf.for %arg0 = %c0 to %c86 step %c2 iter_args({{.+}}) -> (vector<1x4xf32>)
// CHECK: %[[READ0:.+]] = vector.transfer_read {{.+}} : memref<4096x86x128xi4, #hal.descriptor_type<storage_buffer>>, vector<4xi4>
Expand Down
Loading

0 comments on commit 5bfc37a

Please sign in to comment.