diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index a4f2849f9658..29552a24465f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -202,6 +202,8 @@ static void addMemRefLoweringPasses(OpPassManager &pm) { // Turn scalar load/store from memrefs into vectorized ones if possible. This // gives better memory access patterns, which is very important for perf. pm.addPass(createSPIRVVectorizeLoadStore()); + // Perform optimizations that need to across the scf.for region boundary. + pm.addNestedPass(createForOpCanonicalizationPass()); // Perform various vector-level cross-op optimizations like load-store // forwarding, shape casting and casting op cancelling. pm.addNestedPass(createOptimizeVectorTransferPass( diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel index b69b1853ddb8..24672b77a74a 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel @@ -53,6 +53,7 @@ iree_lit_test_suite( "pipeline_matmul_cooperative_ops.mlir", "pipeline_matmul_promotion.mlir", "pipeline_matmul_vectorization.mlir", + "pipeline_matvec.mlir", "pipeline_reduction_subgroup.mlir", "pipeline_sub_byte_dequant.mlir", "set_transform_strategy.mlir", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt index 2d496bb1a64b..5f53e555c8cd 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt @@ -49,6 +49,7 @@ iree_lit_test_suite( "pipeline_matmul_cooperative_ops.mlir" "pipeline_matmul_promotion.mlir" "pipeline_matmul_vectorization.mlir" + "pipeline_matvec.mlir" "pipeline_reduction_subgroup.mlir" "pipeline_sub_byte_dequant.mlir" "set_transform_strategy.mlir" diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir index 77c0377114ae..f60f5a8d4bdb 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_matvec.mlir @@ -10,7 +10,7 @@ ]> ]> -hal.executable @i4_dequant_matvec { +hal.executable @i4_dequant_matvec_f32 { hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", { spirv.target_env = #spirv.target_env<#spirv.vce, Unknown:IntegratedGPU, #spirv.resource_limits< max_compute_shared_memory_size = 32768, @@ -18,9 +18,9 @@ hal.executable @i4_dequant_matvec { max_compute_workgroup_size = [512, 512, 512], subgroup_size = 64>> }> { - hal.executable.export @i4_dequant_matvec layout(#pipeline_layout) + hal.executable.export @i4_dequant_matvec_f32 layout(#pipeline_layout) builtin.module { - func.func @i4_dequant_matvec() { + func.func @i4_dequant_matvec_f32() { %cst = arith.constant 0.000000e+00 : f32 %10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor> %11 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor> @@ -57,10 +57,10 @@ hal.executable @i4_dequant_matvec { // CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info -// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec +// 
CHECK-LABEL: hal.executable.export public @i4_dequant_matvec_f32 // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index] -// CHECK: func.func @i4_dequant_matvec() +// CHECK: func.func @i4_dequant_matvec_f32() // CHECK: linalg.generic // CHECK-SAME: lowering_config = #[[$CONFIG]] @@ -76,7 +76,7 @@ hal.executable @i4_dequant_matvec { ]> ]> -hal.executable @i4_dequant_matvec { +hal.executable @i4_dequant_matvec_f32 { hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", { spirv.target_env = #spirv.target_env<#spirv.vce, Unknown:IntegratedGPU, #spirv.resource_limits< max_compute_shared_memory_size = 32768, @@ -84,9 +84,9 @@ hal.executable @i4_dequant_matvec { max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64>> }> { - hal.executable.export @i4_dequant_matvec layout(#pipeline_layout) + hal.executable.export @i4_dequant_matvec_f32 layout(#pipeline_layout) builtin.module { - func.func @i4_dequant_matvec() { + func.func @i4_dequant_matvec_f32() { %c32_i64 = arith.constant 32 : i64 %cst = arith.constant 0.000000e+00 : f32 %c4294967296_i64 = arith.constant 4294967296 : i64 @@ -138,10 +138,204 @@ hal.executable @i4_dequant_matvec { // CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info -// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec +// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec_f32 // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK-SAME: workgroup_size = [1024 : index, 1 : index, 1 : index] -// CHECK: func.func @i4_dequant_matvec() +// CHECK: func.func @i4_dequant_matvec_f32() +// CHECK: linalg.generic +// CHECK-SAME: lowering_config = #[[$CONFIG]] + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer>, + #hal.descriptor_set.binding<3, storage_buffer>, + #hal.descriptor_set.binding<4, storage_buffer> + ]> +]> + +hal.executable @i4_dequant_matvec_f32 { + hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, Unknown:IntegratedGPU, #spirv.resource_limits< + max_compute_shared_memory_size = 32768, + max_compute_workgroup_invocations = 1024, + max_compute_workgroup_size = [1024, 1024, 1024], + subgroup_size = 64>> + }> { + hal.executable.export @i4_dequant_matvec_f32 layout(#pipeline_layout) + builtin.module { + func.func @i4_dequant_matvec_f32() { + %c32_i64 = arith.constant 32 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.constant.load[0] : i32 + %1 = hal.interface.constant.load[1] : i32 + %2 = hal.interface.constant.load[2] : i32 + %3 = hal.interface.constant.load[3] : i32 + %4 = hal.interface.constant.load[4] : i32 + %5 = hal.interface.constant.load[5] : i32 + %6 = hal.interface.constant.load[6] : i32 + %7 = hal.interface.constant.load[7] : i32 + %8 = hal.interface.constant.load[8] : i32 + %9 = arith.index_castui %0 : i32 to index + %10 = arith.index_castui %1 : i32 to index + %11 = arith.index_castui %2 : i32 to index + %12 = arith.extui %3 : i32 to i64 + %13 = arith.extui %4 : i32 to i64 + %14 = arith.shli %13, %c32_i64 : i64 + %15 = arith.ori %12, %14 : i64 + %16 = arith.index_castui %15 : i64 to index + %17 = arith.extui %5 : i32 to i64 + %18 = arith.extui %6 : i32 to i64 + %19 = arith.shli %18, %c32_i64 : i64 + %20 = arith.ori %17, %19 : i64 + %21 = 
arith.index_castui %20 : i64 to index + %22 = arith.extui %7 : i32 to i64 + %23 = arith.extui %8 : i32 to i64 + %24 = arith.shli %23, %c32_i64 : i64 + %25 = arith.ori %22, %24 : i64 + %26 = arith.index_castui %25 : i64 to index + %27 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%9) flags(ReadOnly) : !flow.dispatch.tensor> + %28 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%10) flags(ReadOnly) : !flow.dispatch.tensor> + %29 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : !flow.dispatch.tensor> + %30 = flow.dispatch.workload.ordinal %26, 0 : index + %31 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%16) flags(ReadOnly) : !flow.dispatch.tensor>{%30} + %32 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%21) : !flow.dispatch.tensor>{%30} + %33 = flow.dispatch.tensor.load %27, offsets = [0, 0, 0], sizes = [4096, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<4096x86x128xi4> + %34 = flow.dispatch.tensor.load %28, offsets = [0, 0], sizes = [4096, 86], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<4096x86xf32> + %35 = flow.dispatch.tensor.load %29, offsets = [0, 0], sizes = [4096, 86], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<4096x86xf32> + %36 = flow.dispatch.tensor.load %31, offsets = [0, 0, 0], sizes = [%30, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor>{%30} -> tensor + %37 = tensor.empty(%30) : tensor + %38 = tensor.empty() : tensor<4096x86x128xf32> + %39 = linalg.fill ins(%cst : f32) outs(%37 : tensor) -> tensor + %40 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%33, %34, %35 : tensor<4096x86x128xi4>, tensor<4096x86xf32>, tensor<4096x86xf32>) outs(%38 : tensor<4096x86x128xf32>) { + ^bb0(%in: i4, %in_0: f32, %in_1: f32, %out: f32): + %42 = arith.extui %in : i4 to i32 + %43 = arith.uitofp %42 : i32 to f32 + %44 = arith.subf %43, %in_1 : f32 + %45 = arith.mulf %44, %in_0 : f32 + linalg.yield %45 : f32 + } -> tensor<4096x86x128xf32> + %41 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction", "reduction"]} + ins(%36, %40 : tensor, tensor<4096x86x128xf32>) outs(%39 : tensor) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %42 = arith.mulf %in, %in_0 : f32 + %43 = arith.addf %42, %out : f32 + linalg.yield %43 : f32 + } -> tensor + flow.dispatch.tensor.store %41, %32, offsets = [0, 0], sizes = [%30, 4096], strides = [1, 1] : tensor -> !flow.dispatch.tensor>{%30} + return + } + } + } +} + +// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec_f32 +// CHECK-SAME: translation_info = #[[$TRANSLATION]] +// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index] +// CHECK: func.func @i4_dequant_matvec_f32() +// CHECK: linalg.generic +// CHECK-SAME: lowering_config = #[[$CONFIG]] + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, 
storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer>, + #hal.descriptor_set.binding<3, storage_buffer>, + #hal.descriptor_set.binding<4, storage_buffer> + ]> +]> +hal.executable @i4_dequant_matvec_f16 { + hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env< + #spirv.vce, + Unknown:IntegratedGPU, + #spirv.resource_limits< + max_compute_shared_memory_size = 32768, + max_compute_workgroup_invocations = 1024, + max_compute_workgroup_size = [1024, 1024, 64], + subgroup_size = 32>> + }> { + hal.executable.export @i4_dequant_matvec_f16 layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @i4_dequant_matvec_f16() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4096, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<4096x86x128xi4> + %6 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [4096, 86, 1], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<4096x86x1xf16> + %7 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [4096, 86, 1], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<4096x86x1xf16> + %8 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0], sizes = [1, 1, 86, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x1x86x128xf16> + %9 = tensor.empty() : tensor<1x1x4096xf16> + %10 = tensor.empty() : tensor<4096x86x128xf16> + %11 = linalg.fill ins(%cst : f16) outs(%9 : tensor<1x1x4096xf16>) -> tensor<1x1x4096xf16> + %12 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, 0)>, + affine_map<(d0, d1, d2) -> (d0, d1, 0)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%5, %6, %7 : tensor<4096x86x128xi4>, tensor<4096x86x1xf16>, tensor<4096x86x1xf16>) outs(%10 : tensor<4096x86x128xf16>) { + ^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16): + %14 = arith.extui %in : i4 to i32 + %15 = arith.uitofp %14 : i32 to f16 + %16 = arith.subf %15, %in_1 : f16 + %17 = arith.mulf %16, %in_0 : f16 + linalg.yield %17 : f16 + } -> tensor<4096x86x128xf16> + %13 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} + ins(%8, %12 : tensor<1x1x86x128xf16>, tensor<4096x86x128xf16>) outs(%11 : tensor<1x1x4096xf16>) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %14 = arith.mulf %in, %in_0 : f16 + %15 = 
arith.addf %14, %out : f16 + linalg.yield %15 : f16 + } -> tensor<1x1x4096xf16> + flow.dispatch.tensor.store %13, %4, offsets = [0, 0, 0], sizes = [1, 1, 4096], strides = [1, 1, 1] : tensor<1x1x4096xf16> -> !flow.dispatch.tensor> + return + } + } + } +} + +// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK-LABEL: hal.executable.export public @i4_dequant_matvec_f16 +// CHECK-SAME: translation_info = #[[$TRANSLATION]] +// CHECK-SAME: workgroup_size = [32 : index, 1 : index, 1 : index] +// CHECK: func.func @i4_dequant_matvec_f16() // CHECK: linalg.generic // CHECK-SAME: lowering_config = #[[$CONFIG]] diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matvec.mlir index c5382deec09b..1d3618c2c488 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matvec.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_matvec.mlir @@ -9,7 +9,7 @@ #hal.descriptor_set.binding<4, storage_buffer> ]> ]> -hal.executable @i4_dequant_matvec { +hal.executable @i4_dequant_matvec_f32 { hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", { spirv.target_env = #spirv.target_env<#spirv.vce, Unknown:IntegratedGPU, #spirv.resource_limits< max_compute_shared_memory_size = 32768, @@ -17,13 +17,13 @@ hal.executable @i4_dequant_matvec { max_compute_workgroup_size = [512, 512, 512], subgroup_size = 64>> }> { - hal.executable.export @i4_dequant_matvec layout(#pipeline_layout) { + hal.executable.export @i4_dequant_matvec_f32 layout(#pipeline_layout) { ^bb0(%arg0: !hal.device): %x, %y, %z = flow.dispatch.workgroup_count_from_slice hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @i4_dequant_matvec() { + func.func @i4_dequant_matvec_f32() { %cst = arith.constant 0.000000e+00 : f32 %10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : !flow.dispatch.tensor> %11 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : !flow.dispatch.tensor> @@ -58,7 +58,7 @@ hal.executable @i4_dequant_matvec { } } -// CHECK-LABEL: func.func @i4_dequant_matvec() +// CHECK-LABEL: func.func @i4_dequant_matvec_f32() // CHECK: %[[FOR:.+]] = scf.for %arg0 = %c0 to %c86 step %c2 iter_args({{.+}}) -> (vector<1x4xf32>) // CHECK: %[[READ0:.+]] = vector.transfer_read {{.+}} : memref<4096x86x128xi4, #hal.descriptor_type>, vector<4xi4> diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir new file mode 100644 index 000000000000..7c73753093e1 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir @@ -0,0 +1,102 @@ +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-codegen-linalg-to-spirv-pipeline)))' %s | FileCheck %s + +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer>, + #hal.descriptor_set.binding<3, storage_buffer>, + #hal.descriptor_set.binding<4, storage_buffer> + ]> +]> +hal.executable @i4_dequant_matvec_f16 { + hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", { + spirv.target_env = #spirv.target_env<#spirv.vce, Unknown:IntegratedGPU, #spirv.resource_limits< + max_compute_shared_memory_size = 32768, + max_compute_workgroup_invocations = 1024, + 
max_compute_workgroup_size = [1024, 1024, 64], + subgroup_size = 32>> + }> { + hal.executable.export @i4_dequant_matvec_f16 layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @i4_dequant_matvec_f16() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %4 = hal.interface.binding.subspan set(0) binding(4) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4096, 86, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<4096x86x128xi4> + %6 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [4096, 86, 1], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<4096x86x1xf16> + %7 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [4096, 86, 1], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<4096x86x1xf16> + %8 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0], sizes = [1, 1, 86, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x1x86x128xf16> + %9 = tensor.empty() : tensor<1x1x4096xf16> + %10 = tensor.empty() : tensor<4096x86x128xf16> + %11 = linalg.fill ins(%cst : f16) outs(%9 : tensor<1x1x4096xf16>) -> tensor<1x1x4096xf16> + %12 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, 0)>, + affine_map<(d0, d1, d2) -> (d0, d1, 0)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%5, %6, %7 : tensor<4096x86x128xi4>, tensor<4096x86x1xf16>, tensor<4096x86x1xf16>) outs(%10 : tensor<4096x86x128xf16>) { + ^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16): + %14 = arith.extui %in : i4 to i32 + %15 = arith.uitofp %14 : i32 to f16 + %16 = arith.subf %15, %in_1 : f16 + %17 = arith.mulf %16, %in_0 : f16 + linalg.yield %17 : f16 + } -> tensor<4096x86x128xf16> + %13 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} + ins(%8, %12 : tensor<1x1x86x128xf16>, tensor<4096x86x128xf16>) outs(%11 : tensor<1x1x4096xf16>) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %14 = arith.mulf %in, %in_0 : f16 + %15 = arith.addf %14, %out : f16 + linalg.yield %15 : f16 + } -> tensor<1x1x4096xf16> + flow.dispatch.tensor.store %13, %4, offsets = [0, 0, 0], sizes = [1, 1, 4096], strides = [1, 1, 1] : tensor<1x1x4096xf16> -> !flow.dispatch.tensor> + return + } + } + } +} + +// CHECK-LABEL: spirv.func @i4_dequant_matvec_f16() + +// CHECK: %[[C15:.+]] = spirv.Constant 15 : i32 + +// CHECK: spirv.mlir.loop + +// Load the quantized weight and get 8xi4 out of it. 
+// CHECK: spirv.Load "StorageBuffer" %{{.+}} : i32 +// CHECK-COUNT-8: spirv.BitwiseAnd %{{.+}}, %[[C15]] : i32 + +// CHECK-COUNT-2: spirv.ConvertUToF %{{.+}} : vector<4xi32> to vector<4xf16> +// CHECK-COUNT-2: spirv.FSub %{{.+}}, %{{.+}} : vector<4xf16> +// CHECK-COUNT-4: spirv.FMul %{{.+}}, %{{.+}} : vector<4xf16> +// CHECK-COUNT-2: spirv.FAdd %{{.+}}, %{{.+}} : vector<4xf16> +// CHECK-COUNT-2: spirv.Bitcast %{{.+}} : vector<4xf16> to vector<2xf32> +// CHECK-COUNT-2: spirv.VectorShuffle {{.+}} : vector<2xf32> -> vector<4xf32> + +// CHECK: spirv.mlir.merge + +// CHECK: %[[LD:.+]] = spirv.Load "Function" %4 : vector<4xf32> +// CHECK: %[[VS0:.+]] = spirv.VectorShuffle [0 : i32, 1 : i32] %[[LD]] +// CHECK: spirv.Bitcast %[[VS0]] : vector<2xf32> to vector<4xf16> +// CHECK: %[[VS1:.+]] = spirv.VectorShuffle [2 : i32, 3 : i32] %[[LD]] +// CHECK: spirv.Bitcast %[[VS1]] : vector<2xf32> to vector<4xf16> + +// CHECK-COUNT-5: spirv.GroupNonUniformShuffleXor + +// CHECK: spirv.mlir.selection
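
The new pipeline_matvec.mlir test above expects each group of quantized weights to be fetched as a single 32-bit load and unpacked with eight BitwiseAnd operations against the constant 15 (%[[C15]]), then dequantized as (weight - zero_point) * scale in f16 before feeding the reduction; the Bitcast/VectorShuffle checks show the f16 accumulator crossing the scf.for boundary packed into vector<4xf32>, which is the case the added ForOpCanonicalization pass targets. Below is a rough host-side C++ sketch of the unpacking step only, assuming low-to-high nibble order; the function name is illustrative and not part of the patch.

  #include <array>
  #include <cstdint>

  // Unpack eight 4-bit weights from one 32-bit word by shifting and masking
  // with 15, mirroring the eight spirv.BitwiseAnd checks in the test.
  std::array<uint32_t, 8> unpackI4x8(uint32_t word) {
    std::array<uint32_t, 8> vals;
    for (int i = 0; i < 8; ++i)
      vals[i] = (word >> (4 * i)) & 0xFu;
    return vals;
  }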