Support data tiling + microkernels for batch_matmul #14431
Comments
Assigning to @pzread, who is looking into this.
I posted an initial result of lowering `batch_matmul` to `batch_mmt4d`, similar to the current tiling configuration of `mmt4d`. Currently we see slight regressions on T5 and BertLarge, but I'll need to dig into the root causes.
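For reference, here is a minimal sketch (not actual compiler output) of what this lowering targets for the 512x512x512x64 (BxMxNxK) batch_matmul shape in these models, assuming 16x16x1 (m0 x n0 x k0) inner tiles; the function name and the already-packed operand shapes are illustrative.

```mlir
// LHS is B x M1 x K1 x m0 x k0, RHS is B x N1 x K1 x n0 x k0 (already laid out
// in the transposed "mmt" order by tensor.pack), result is B x M1 x N1 x m0 x n0.
func.func @batch_mmt4d_sketch(%lhs: tensor<512x32x64x16x1xf32>,
                              %rhs: tensor<512x32x64x16x1xf32>,
                              %acc: tensor<512x32x32x16x16xf32>) -> tensor<512x32x32x16x16xf32> {
  %0 = linalg.batch_mmt4d
         ins(%lhs, %rhs : tensor<512x32x64x16x1xf32>, tensor<512x32x64x16x1xf32>)
         outs(%acc : tensor<512x32x32x16x16xf32>) -> tensor<512x32x32x16x16xf32>
  return %0 : tensor<512x32x32x16x16xf32>
}
```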
Profiled T5LargeTFBatch32 with single threading (local-sync) on #14542. On the slowest …, here is the corresponding input IR segment in T5LargeTFBatch32:
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
%51 = tensor.empty() : tensor<512x512x512xf32>
%52 = linalg.fill ins(%cst_10 : f32) outs(%51 : tensor<512x512x512xf32>) -> tensor<512x512x512xf32>
%53 = linalg.batch_matmul ins(%collapsed_17, %collapsed_18 : tensor<512x512x64xf32>, tensor<512x64x512xf32>) outs(%52 : tensor<512x512x512xf32>) -> tensor<512x512x512xf32>
%expanded_19 = tensor.expand_shape %53 [[0, 1], [2], [3]] : tensor<512x512x512xf32> into tensor<32x16x512x512xf32>
%54 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_19, %41 : tensor<32x16x512x512xf32>, tensor<32x16x512x512xf32>) outs(%12 : tensor<32x16x512x512xf32>) {
^bb0(%in: f32 loc("t5large32_stripped_linalg.mlir":3767:10), %in_1336: f32 loc("t5large32_stripped_linalg.mlir":3767:20), %out: f32 loc("t5large32_stripped_linalg.mlir":3767:35)):
%2823 = arith.addf %in, %in_1336 : f32
linalg.yield %2823 : f32
} -> tensor<32x16x512x512xf32>
%55 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%54, %cst_6 : tensor<32x16x512x512xf32>, tensor<32x16x512x512xf32>) outs(%12 : tensor<32x16x512x512xf32>) {
^bb0(%in: f32 loc("t5large32_stripped_linalg.mlir":3772:10), %in_1336: f32 loc("t5large32_stripped_linalg.mlir":3772:20), %out: f32 loc("t5large32_stripped_linalg.mlir":3772:35)):
%2823 = arith.addf %in, %in_1336 : f32
linalg.yield %2823 : f32
} -> tensor<32x16x512x512xf32>
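For context, a simplified, hand-written sketch of how the data-tiling path wraps this batch_matmul with encodings (the real SetEncoding pass also introduces dynamic padding, which is omitted here; the attribute aliases and function name are made up for readability):

```mlir
#enc_lhs = #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS, original_type = tensor<512x512x64xf32>>
#enc_rhs = #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS, original_type = tensor<512x64x512xf32>>
#enc_res = #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>
func.func @encoded_batch_matmul(%lhs: tensor<512x512x64xf32>, %rhs: tensor<512x64x512xf32>) -> tensor<512x512x512xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // Attach encodings to the operands so materialization can pick target-specific tile sizes later.
  %lhs_e = iree_linalg_ext.set_encoding %lhs : tensor<512x512x64xf32> -> tensor<512x512x64xf32, #enc_lhs>
  %rhs_e = iree_linalg_ext.set_encoding %rhs : tensor<512x64x512xf32> -> tensor<512x64x512xf32, #enc_rhs>
  %empty = tensor.empty() : tensor<512x512x512xf32, #enc_res>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<512x512x512xf32, #enc_res>) -> tensor<512x512x512xf32, #enc_res>
  %mm = linalg.batch_matmul ins(%lhs_e, %rhs_e : tensor<512x512x64xf32, #enc_lhs>, tensor<512x64x512xf32, #enc_rhs>)
          outs(%fill : tensor<512x512x512xf32, #enc_res>) -> tensor<512x512x512xf32, #enc_res>
  // Decode the result before the elementwise adds that follow in the model.
  %res = iree_linalg_ext.unset_encoding %mm : tensor<512x512x512xf32, #enc_res> -> tensor<512x512x512xf32>
  return %res : tensor<512x512x512xf32>
}
```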
Nice progress! Could you dump the unpack + generic dispatch? We probably want to do a microbenchmark for it.
Here is the dispatch of that unpack + generic:
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "cascadelake", cpu_features = "+cmov,+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vnni,+adx,+clflushopt,+clwb,+cx16,+cx8,+crc32,+f16c,+fsgsbase,+fxsr,+invpcid,+lzcnt,+movbe,+pku,+prfchw,+rdrnd,+rdseed,+sahf,+x87,+xsave,+xsavec,+xsaveopt,+xsaves", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 64 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = true}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#pipeline_layout = #hal.pipeline.layout<push_constants = 5, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>
#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}>
module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
hal.executable public @forward_dispatch_15 {
hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
hal.executable.export public @forward_dispatch_15_generic_512x512x512_f32 ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @forward_dispatch_15_generic_512x512x512_f32() {
%cst = arith.constant 9.99999971E-10 : f32
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = hal.interface.constant.load[3] : i32
%4 = hal.interface.constant.load[4] : i32
%5 = arith.index_castui %0 : i32 to index
%6 = arith.index_castui %1 {stream.alignment = 268435456 : index, stream.values = [0 : index, 268435456 : index]} : i32 to index
%7 = arith.index_castui %2 {stream.alignment = 67108864 : index, stream.values = [0 : index, 67108864 : index, 134217728 : index, 805306368 : index, 1543503872 : index, 2080374784 : index]} : i32 to index
%8 = arith.index_castui %3 : i32 to index
%9 = arith.index_castui %4 : i32 to index
%10 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>>
%11 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%7) : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
%12 = flow.dispatch.workload.ordinal %8, 0 : index
%13 = flow.dispatch.workload.ordinal %9, 1 : index
%14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>>>{%12, %13}
%15 = flow.dispatch.tensor.load %14, offsets = [0, 0, 0], sizes = [512, %12, %13], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>>>{%12, %13} -> tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>>
%16 = flow.dispatch.tensor.load %10, offsets = [0, 0, 0], sizes = [512, 512, 512], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>> -> tensor<512x512x512xf32>
%17 = tensor.empty() : tensor<512x512x512xf32>
%18 = iree_linalg_ext.unset_encoding %15 : tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>> -> tensor<512x?x?xf32>
%extracted_slice = tensor.extract_slice %18[0, 0, 0] [512, 512, 512] [1, 1, 1] : tensor<512x?x?xf32> to tensor<512x512x512xf32>
%19 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice, %16 : tensor<512x512x512xf32>, tensor<512x512x512xf32>) outs(%17 : tensor<512x512x512xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%20 = arith.addf %in, %in_0 : f32
%21 = arith.addf %20, %cst : f32
linalg.yield %21 : f32
} -> tensor<512x512x512xf32>
flow.dispatch.tensor.store %19, %11, offsets = [0, 0, 0], sizes = [512, 512, 512], strides = [1, 1, 1] : tensor<512x512x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
return
}
}
}
}
}
One issue I just found is that the leading unit dims (1x16x16xf32) before and after the linalg.transpose (which itself has no leading unit dims) cause an extra pair of transfer write/read and an alloca. It can be cleaned up with …
// -----// IR Dump After DecomposePackUnPackOps (iree-codegen-decompose-pack-unpack-ops) //----- //
func.func @forward_dispatch_15_generic_512x512x512_f32() {
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c512 = arith.constant 512 : index
%cst = arith.constant 9.99999971E-10 : f32
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = arith.index_castui %0 : i32 to index
%4 = arith.index_castui %1 {stream.alignment = 268435456 : index, stream.values = [0 : index, 268435456 : index]} : i32 to index
%5 = arith.index_castui %2 {stream.alignment = 67108864 : index, stream.values = [0 : index, 67108864 : index, 134217728 : index, 805306368 : index, 1543503872 : index, 2080374784 : index]} : i32 to index
%6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>>
%7 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
%8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%workgroup_id_z = hal.interface.workgroup.id[2] : index
%workgroup_count_z = hal.interface.workgroup.count[2] : index
%9 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_z]
%10 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_z]
scf.for %arg0 = %9 to %c512 step %10 {
%11 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
%12 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
scf.for %arg1 = %11 to %c512 step %12 {
%13 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%14 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
scf.for %arg2 = %13 to %c512 step %14 {
%15 = flow.dispatch.tensor.load %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
%16 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg1)
%17 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
%18 = flow.dispatch.tensor.load %8, offsets = [%arg0, %16, %17, 0, 0], sizes = [64, 4, 4, 16, 16], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>> -> tensor<64x4x4x16x16xf32>
%19 = flow.dispatch.tensor.load %6, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
%20 = scf.for %arg3 = %c0 to %c64 step %c1 iter_args(%arg4 = %15) -> (tensor<64x64x64xf32>) {
%21 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %arg4) -> (tensor<64x64x64xf32>) {
%22 = scf.for %arg7 = %c0 to %c64 step %c16 iter_args(%arg8 = %arg6) -> (tensor<64x64x64xf32>) {
%extracted_slice = tensor.extract_slice %19[%arg3, %arg5, %arg7] [1, 16, 16] [1, 1, 1] : tensor<64x64x64xf32> to tensor<1x16x16xf32>
%23 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg5)
%24 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg7)
%extracted_slice_0 = tensor.extract_slice %18[%arg3, %23, %24, 0, 0] [1, 1, 1, 16, 16] [1, 1, 1, 1, 1] : tensor<64x4x4x16x16xf32> to tensor<1x1x1x16x16xf32>
%extracted_slice_1 = tensor.extract_slice %arg8[%arg3, %arg5, %arg7] [1, 16, 16] [1, 1, 1] : tensor<64x64x64xf32> to tensor<1x16x16xf32>
%extracted_slice_2 = tensor.extract_slice %extracted_slice_0[0, 0, 0, 0, 0] [1, 1, 1, 16, 16] [1, 1, 1, 1, 1] : tensor<1x1x1x16x16xf32> to tensor<16x16xf32>
%25 = tensor.empty() : tensor<16x16xf32>
%transposed = linalg.transpose ins(%extracted_slice_2 : tensor<16x16xf32>) outs(%25 : tensor<16x16xf32>) permutation = [0, 1]
%inserted_slice = tensor.insert_slice %transposed into %extracted_slice_1[0, 0, 0] [1, 16, 16] [1, 1, 1] : tensor<16x16xf32> into tensor<1x16x16xf32>
%26 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice : tensor<1x16x16xf32>) outs(%inserted_slice : tensor<1x16x16xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 64], [1, 16, 16], [0, 0, 0], [0, 0, 0]]>} {
^bb0(%in: f32, %out: f32):
%27 = arith.addf %out, %in : f32
%28 = arith.addf %27, %cst : f32
linalg.yield %28 : f32
} -> tensor<1x16x16xf32>
%inserted_slice_3 = tensor.insert_slice %26 into %arg8[%arg3, %arg5, %arg7] [1, 16, 16] [1, 1, 1] : tensor<1x16x16xf32> into tensor<64x64x64xf32>
scf.yield %inserted_slice_3 : tensor<64x64x64xf32>
}
scf.yield %22 : tensor<64x64x64xf32>
}
scf.yield %21 : tensor<64x64x64xf32>
}
flow.dispatch.tensor.store %20, %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : tensor<64x64x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
}
}
}
return
}
// -----// IR Dump After GenericVectorization (iree-codegen-generic-vectorization) //----- //
func.func @forward_dispatch_15_generic_512x512x512_f32() {
%cst = arith.constant dense<9.99999971E-10> : vector<1x16x16xf32>
%cst_0 = arith.constant 0.000000e+00 : f32
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c512 = arith.constant 512 : index
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = arith.index_castui %0 : i32 to index
%4 = arith.index_castui %1 {stream.alignment = 268435456 : index, stream.values = [0 : index, 268435456 : index]} : i32 to index
%5 = arith.index_castui %2 {stream.alignment = 67108864 : index, stream.values = [0 : index, 67108864 : index, 134217728 : index, 805306368 : index, 1543503872 : index, 2080374784 : index]} : i32 to index
%6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>>
%7 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
%8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%workgroup_id_z = hal.interface.workgroup.id[2] : index
%workgroup_count_z = hal.interface.workgroup.count[2] : index
%9 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_z]
%10 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_z]
scf.for %arg0 = %9 to %c512 step %10 {
%11 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
%12 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
scf.for %arg1 = %11 to %c512 step %12 {
%13 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%14 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
scf.for %arg2 = %13 to %c512 step %14 {
%15 = flow.dispatch.tensor.load %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
%16 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg1)
%17 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
%18 = flow.dispatch.tensor.load %8, offsets = [%arg0, %16, %17, 0, 0], sizes = [64, 4, 4, 16, 16], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>> -> tensor<64x4x4x16x16xf32>
%19 = flow.dispatch.tensor.load %6, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
%20 = scf.for %arg3 = %c0 to %c64 step %c1 iter_args(%arg4 = %15) -> (tensor<64x64x64xf32>) {
%21 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %arg4) -> (tensor<64x64x64xf32>) {
%22 = scf.for %arg7 = %c0 to %c64 step %c16 iter_args(%arg8 = %arg6) -> (tensor<64x64x64xf32>) {
%23 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg5)
%24 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg7)
%extracted_slice = tensor.extract_slice %arg8[%arg3, %arg5, %arg7] [1, 16, 16] [1, 1, 1] : tensor<64x64x64xf32> to tensor<1x16x16xf32>
%25 = vector.transfer_read %18[%arg3, %23, %24, %c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<64x4x4x16x16xf32>, vector<16x16xf32>
%26 = vector.transfer_write %25, %extracted_slice[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<1x16x16xf32>
%27 = vector.transfer_read %19[%arg3, %arg5, %arg7], %cst_0 {in_bounds = [true, true, true]} : tensor<64x64x64xf32>, vector<1x16x16xf32>
%28 = vector.transfer_read %26[%c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true]} : tensor<1x16x16xf32>, vector<1x16x16xf32>
%29 = arith.addf %28, %27 : vector<1x16x16xf32>
%30 = arith.addf %29, %cst : vector<1x16x16xf32>
%31 = vector.transfer_write %30, %arg8[%arg3, %arg5, %arg7] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, tensor<64x64x64xf32>
scf.yield %31 : tensor<64x64x64xf32>
}
scf.yield %22 : tensor<64x64x64xf32>
}
scf.yield %21 : tensor<64x64x64xf32>
}
flow.dispatch.tensor.store %20, %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : tensor<64x64x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
}
}
}
return
}
// -----// IR Dump After OptimizeVectorTransfer (iree-codegen-optimize-vector-transfer) //----- //
func.func @forward_dispatch_15_generic_512x512x512_f32() {
%cst = arith.constant dense<9.99999971E-10> : vector<16x16xf32>
%cst_0 = arith.constant 0.000000e+00 : f32
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c512 = arith.constant 512 : index
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = arith.index_castui %0 : i32 to index
%4 = arith.index_castui %1 {stream.alignment = 268435456 : index, stream.values = [0 : index, 268435456 : index]} : i32 to index
%5 = arith.index_castui %2 {stream.alignment = 67108864 : index, stream.values = [0 : index, 67108864 : index, 134217728 : index, 805306368 : index, 1543503872 : index, 2080374784 : index]} : i32 to index
%6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>>
%7 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
%8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%workgroup_id_z = hal.interface.workgroup.id[2] : index
%workgroup_count_z = hal.interface.workgroup.count[2] : index
%9 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_z]
%10 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_z]
%11 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
%12 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
%13 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%14 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
scf.for %arg0 = %9 to %c512 step %10 {
scf.for %arg1 = %11 to %c512 step %12 {
%15 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg1)
scf.for %arg2 = %13 to %c512 step %14 {
%16 = flow.dispatch.tensor.load %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
%17 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
%18 = flow.dispatch.tensor.load %8, offsets = [%arg0, %15, %17, 0, 0], sizes = [64, 4, 4, 16, 16], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>> -> tensor<64x4x4x16x16xf32>
%19 = flow.dispatch.tensor.load %6, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
%20 = scf.for %arg3 = %c0 to %c64 step %c1 iter_args(%arg4 = %16) -> (tensor<64x64x64xf32>) {
%21 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %arg4) -> (tensor<64x64x64xf32>) {
%22 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg5)
%23 = scf.for %arg7 = %c0 to %c64 step %c16 iter_args(%arg8 = %arg6) -> (tensor<64x64x64xf32>) {
%24 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg7)
%25 = vector.transfer_read %18[%arg3, %22, %24, %c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<64x4x4x16x16xf32>, vector<16x16xf32>
%26 = vector.transfer_read %19[%arg3, %arg5, %arg7], %cst_0 {in_bounds = [true, true]} : tensor<64x64x64xf32>, vector<16x16xf32>
%27 = arith.addf %25, %26 : vector<16x16xf32>
%28 = arith.addf %27, %cst : vector<16x16xf32>
%29 = vector.transfer_write %28, %arg8[%arg3, %arg5, %arg7] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<64x64x64xf32>
scf.yield %29 : tensor<64x64x64xf32>
}
scf.yield %23 : tensor<64x64x64xf32>
}
scf.yield %21 : tensor<64x64x64xf32>
}
flow.dispatch.tensor.store %20, %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : tensor<64x64x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
}
}
}
return
}
Here are the steps to reproduce microbenchmarks on the slowest dispatch. Download the sampled mlir below (verified it is compiled in the same way as in the model):
// unpack_generic.mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
func.func @unpack_generic(%arg0: tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>>, %arg1: tensor<512x512x512xf32>) -> tensor<512x512x512xf32> {
%cst = arith.constant 9.99999971E-10 : f32
%0 = tensor.empty() : tensor<512x512x512xf32>
%1 = iree_linalg_ext.unset_encoding %arg0 : tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>> -> tensor<512x?x?xf32>
%extracted_slice = tensor.extract_slice %1[0, 0, 0] [512, 512, 512] [1, 1, 1] : tensor<512x?x?xf32> to tensor<512x512x512xf32>
%2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice, %arg1 : tensor<512x512x512xf32>, tensor<512x512x512xf32>) outs(%0 : tensor<512x512x512xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%3 = arith.addf %in, %in_0 : f32
%4 = arith.addf %3, %cst : f32
linalg.yield %4 : f32
} -> tensor<512x512x512xf32>
return %2 : tensor<512x512x512xf32>
}
}
Check out #14542 and run (note the flags below):
iree-compile \
unpack_generic.mlir \
-o 'unpack_generic.vmfb' \
--iree-hal-target-backends=llvm-cpu \
--iree-input-type=none \
--iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu \
--iree-llvmcpu-target-cpu=cascadelake \
--iree-flow-enable-data-tiling \
--iree-llvmcpu-enable-microkernels \
--iree-hal-benchmark-dispatch-repeat-count=48
And benchmark with:
Here is a sample benchmark output (from c2-standard-60):
The full mlir dump after each pass can be found at: https://gist.github.com/pzread/088bcaf67302d814c329ff0b58ef8d9c#file-unpack_generic-dump-mlir
The full ASM dump can be found at: …
@bjacob @MaheshRavishankar FYI. The unpack (and fusion) performance looks bad. We might need some input from Benoit about what the final code should look like. I will study it this week and will circle back to the issue.
I have an improved microbenchmark file, adapted from what @pzread has:
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
util.global private @__var1 {noinline} = dense<1.000000e+00> : tensor<512x32x32x16x16xf32>
util.global private @__var2 {noinline} = dense<1.000000e+00> : tensor<512x512x512xf32>
util.global private @__var3 {noinline} = dense<2.000000e+00> : tensor<512x512x512xf32>
func.func @unpack_generic() -> tensor<512x512x512xf32> {
%cst = arith.constant 9.99999971E-10 : f32
%0 = tensor.empty() : tensor<512x512x512xf32>
%ptr___var1 = util.global.address @__var1 : !util.ptr<tensor<512x32x32x16x16xf32>>
%1 = util.global.load.indirect %ptr___var1 : !util.ptr<tensor<512x32x32x16x16xf32>> -> tensor<512x32x32x16x16xf32>
%ptr___var2 = util.global.address @__var2 : !util.ptr<tensor<512x512x512xf32>>
%2 = util.global.load.indirect %ptr___var2 : !util.ptr<tensor<512x512x512xf32>> -> tensor<512x512x512xf32>
%unpack = tensor.unpack %1 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %0 : tensor<512x32x32x16x16xf32> -> tensor<512x512x512xf32>
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%unpack, %2 : tensor<512x512x512xf32>, tensor<512x512x512xf32>) outs(%0 : tensor<512x512x512xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%4 = arith.addf %in, %in_0 : f32
%5 = arith.addf %4, %cst : f32
linalg.yield %5 : f32
} -> tensor<512x512x512xf32>
return %3 : tensor<512x512x512xf32>
}
func.func @unpack() -> tensor<512x512x512xf32> {
%0 = tensor.empty() : tensor<512x512x512xf32>
%ptr___var1 = util.global.address @__var1 : !util.ptr<tensor<512x32x32x16x16xf32>>
%1 = util.global.load.indirect %ptr___var1 : !util.ptr<tensor<512x32x32x16x16xf32>> -> tensor<512x32x32x16x16xf32>
%ptr___var2 = util.global.address @__var2 : !util.ptr<tensor<512x512x512xf32>>
%2 = util.global.load.indirect %ptr___var2 : !util.ptr<tensor<512x512x512xf32>> -> tensor<512x512x512xf32>
%unpack = tensor.unpack %1 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %0 : tensor<512x32x32x16x16xf32> -> tensor<512x512x512xf32>
return %unpack : tensor<512x512x512xf32>
}
func.func @generic() -> tensor<512x512x512xf32> {
%cst = arith.constant 9.99999971E-10 : f32
%0 = tensor.empty() : tensor<512x512x512xf32>
%ptr___var2 = util.global.address @__var2 : !util.ptr<tensor<512x512x512xf32>>
%1 = util.global.load.indirect %ptr___var2 : !util.ptr<tensor<512x512x512xf32>> -> tensor<512x512x512xf32>
%ptr___var3 = util.global.address @__var3 : !util.ptr<tensor<512x512x512xf32>>
%2 = util.global.load.indirect %ptr___var3 : !util.ptr<tensor<512x512x512xf32>> -> tensor<512x512x512xf32>
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %2 : tensor<512x512x512xf32>, tensor<512x512x512xf32>) outs(%0 : tensor<512x512x512xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%4 = arith.addf %in, %in_0 : f32
%5 = arith.addf %4, %cst : f32
linalg.yield %5 : f32
} -> tensor<512x512x512xf32>
return %3 : tensor<512x512x512xf32>
}
}
To benchmark, run …
What I got on my local workstation (which has a scaled CPU frequency compared to the cloud machine):
The fusion case is not as good as the standalone generic op, which matches the result that @pzread shared with me. It is because we flatten the dimensions of linalg ops in …
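To illustrate the flattening point: for a purely elementwise generic like this one, collapsing the dimensions turns the 3-D op into a 1-D op over contiguous memory, which is roughly what the standalone dispatch benefits from. The sketch below is hand-written and only illustrative (512*512*512 = 134217728); it is not the actual pass output.

```mlir
#map1d = affine_map<(d0) -> (d0)>
func.func @flattened_add(%a: tensor<512x512x512xf32>, %b: tensor<512x512x512xf32>) -> tensor<134217728xf32> {
  // Collapse all three parallel dims into one, then run the elementwise add as a 1-D generic.
  %flat_a = tensor.collapse_shape %a [[0, 1, 2]] : tensor<512x512x512xf32> into tensor<134217728xf32>
  %flat_b = tensor.collapse_shape %b [[0, 1, 2]] : tensor<512x512x512xf32> into tensor<134217728xf32>
  %init = tensor.empty() : tensor<134217728xf32>
  %sum = linalg.generic {indexing_maps = [#map1d, #map1d, #map1d], iterator_types = ["parallel"]}
      ins(%flat_a, %flat_b : tensor<134217728xf32>, tensor<134217728xf32>)
      outs(%init : tensor<134217728xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %s = arith.addf %in, %in_0 : f32
    linalg.yield %s : f32
  } -> tensor<134217728xf32>
  return %sum : tensor<134217728xf32>
}
```

When the unpack is fused into the same dispatch, the iteration space has to stay multi-dimensional to express the tiled layout, which presumably is why the fused unpack + generic case cannot linearize in the same way.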
Nice find. Incidentally, I am working on improving and generalizing the …
I misread what you were saying. So linearization helps performance... then the question is whether you can keep the linearization and still generate good code for unpack + linalg.generic. That seems a bit hard, but looking at the IR might help.
I think the distribution messes with the cache. If we disable distribution (with …):
@pzread maybe we can try …
This would give us a sense of whether it's helping in the single-threaded case or not.
Note: the numbers are with … IMO, we can just start with single-threaded, and then figure out how to propagate it to multi-threaded.
I am not sure how we can land it if we don't address the multi-threaded case.
I meant that we should make sure it is as good as expected in the single-threaded case; we can figure out a decent configuration for multi-threaded later. I can dig deeper tomorrow and Friday.
I think the batch dim of …
I feel that we now rely on … The other possible workaround is to adjust the tile sizes for unpack ops in https://github.com/openxla/iree/blob/main/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp#L1993, which can make sure that a batch is not spanned across different workgroups (i.e., we may set the tile sizes to [1, X, Y] in this case).
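Roughly, the [1, X, Y] idea means each distributed tile covers exactly one batch, so no batch is ever split across workgroups. The sketch below is only illustrative of that shape of tiling (the loop structure, tile sizes, and function name are assumptions, not what KernelDispatch.cpp currently emits):

```mlir
func.func @unpack_one_batch_per_tile(%packed: tensor<512x32x32x16x16xf32>) -> tensor<512x512x512xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c512 = arith.constant 512 : index
  %empty = tensor.empty() : tensor<512x512x512xf32>
  // Step 1 on the batch dimension: each iteration unpacks one full 512x512 batch.
  %result = scf.for %b = %c0 to %c512 step %c1 iter_args(%acc = %empty) -> (tensor<512x512x512xf32>) {
    %src = tensor.extract_slice %packed[%b, 0, 0, 0, 0] [1, 32, 32, 16, 16] [1, 1, 1, 1, 1]
        : tensor<512x32x32x16x16xf32> to tensor<32x32x16x16xf32>
    %dst = tensor.extract_slice %acc[%b, 0, 0] [1, 512, 512] [1, 1, 1]
        : tensor<512x512x512xf32> to tensor<512x512xf32>
    %unpacked = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %dst
        : tensor<32x32x16x16xf32> -> tensor<512x512xf32>
    %inserted = tensor.insert_slice %unpacked into %acc[%b, 0, 0] [1, 512, 512] [1, 1, 1]
        : tensor<512x512xf32> into tensor<512x512x512xf32>
    scf.yield %inserted : tensor<512x512x512xf32>
  }
  return %result : tensor<512x512x512xf32>
}
```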
Update on …: here is the IR dump: https://gist.githubusercontent.com/hanhanW/f0dc1d9aa0e479136a29411070bebe4a/raw
Add pattern to materialize `batch_matmul` with data-tiling encoding to `batch_mmt4d`. Tracking issue: #14431
Populate lowering config for `batch_mmt4d` op in CPU backend. Tracking issue: #14431
I decided to do a more complete analysis on the two models with batch_matmul data-tiling + ukernels enabled. A Google doc version with better table rendering can be found here. Scripts and code to reproduce the e2e benchmark results can be found in #14542 (see the README and scripts under …).

Tl;dr

We benchmarked the batch_matmul data-tiling + ukernel path on the BertLargeTF and T5LargeTF models to check performance. The current results show slight regressions on both models (max at +8.8%), so we analyzed the results and tried to identify potential improvements as next steps.

E2E Model Benchmarks

Here we choose BertLargeTF (Batch1, Batch32) and T5LargeTF (Batch1, Batch32) to measure the e2e performance of batch_matmul data-tiling + ukernel (referred to as "data-tiling" below). The benchmarking was performed on c2-standard-60 (cascadelake) and the numbers are averages of 5 repeated runs. The results are shown in the table below:
There are slight regressions on both models and at all thread counts, but not by much (max at +8.8%). The sections below provide a detailed analysis based on the 1-thread (local-sync) results. Note that one reason we only see slight regressions in e2e latencies is that batch_matmul doesn't occupy a large portion of these models (only 10.7% of the run time in T5LargeTFBatch32 and 6% in BertLargeTFBatch32).

Breakdown: T5LargeTFBatch32

There are 4 batch_matmul dispatches in the baseline module and 2 batch_matmul dispatches in the data-tiling module due to deduplication (but 4 unpack dispatches). The run time and characteristics of each dispatch are listed in the table below. The symbol on each row links the related dispatches between the two modules: the same symbol means they are the same workload on both sides (originating from the same parts of the source MLIR).

T5LargeTFBatch32 baseline batch_matmul dispatches characteristics
“Full runs” means the number of complete batch_matmul runs (each run may issue multiple dispatch launches due to workgroup distribution). “Arith ops” in a dispatch is estimated from high-level MLIR ops (matmul = 2, addf = 1, etc.) and is used to estimate GFlops.

T5LargeTFBatch32 data-tiling batch_matmul dispatches characteristics
Analysis

There are 2 types of batch_matmul, 512x512x512x64 (BxNxMxK) and 512x512x64x512, and both run slower with data-tiling:
Pack and unpack have significant overhead. The 11.07s pack time on 512x512x64x512 looks abnormal and needs further investigation. We also found that the 13.77s unpack time on 512x512x512x64 can be reduced to 8s if … Another way to look at the data is that the estimated GFlops of the baseline batch_matmul are close to or exceed 100 GFlops. The batch_matmul (batch_mmt4d) ukernel also only reaches similar GFlops, which means the ukernel doesn't help much in this case.

Next steps

In this case, here are 4 improvements that can be made:
Or, if 100 GFlops on the non-data-tiling batch_matmul is considered good enough, we might need to control when to apply data-tiling and skip these cases.

Breakdown: BertLargeTFBatch32

Similar to the analysis on T5LargeTFBatch32, we first analyze the batch_matmul dispatches:

BertLargeTFBatch32 baseline batch_matmul dispatches characteristics
BertLargeTFBatch32 data-tiling batch_matmul dispatches characteristics
Analysis

There are 2 types of batch_matmul: 512x384x384x64 (BxNxMxK) and 512x384x64x384. The table below shows that data-tiling is slower on both types, mainly due to the pack/unpack overhead. Also, in this case there is no generic op fused in to hide the overhead.
Next steps

This does not look like a data-tiling-friendly case, as there is no generic op to fuse with pack/unpack. In the MLIR we see that the predecessors and successors of batch_matmul are reshape ops, and before and after the reshape ops are generic ops. A potential improvement is to move pack/unpack across the generic ops to pull them in for fusion, but the reshape ops in between will need to be handled.

Microbenchmarks

We sampled the heaviest batch_matmul from T5LargeTFBatch32 to further understand the performance of a single batch_matmul.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
func.func @batch_matmul_generic(%arg0: tensor<512x512x64xf32>, %arg1: tensor<512x64x512xf32>, %arg2: tensor<512x512x512xf32>) -> tensor<512x512x512xf32> {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant dense<9.99999971E-10> : tensor<512x512x512xf32>
%0 = tensor.empty() : tensor<512x512x512xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x512x512xf32>) -> tensor<512x512x512xf32>
%2 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x512x64xf32>, tensor<512x64x512xf32>) outs(%1 : tensor<512x512x512xf32>) -> tensor<512x512x512xf32>
%3 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2, %arg2, %cst_0 : tensor<512x512x512xf32>, tensor<512x512x512xf32>, tensor<512x512x512xf32>) outs(%0 : tensor<512x512x512xf32>) {
^bb0(%in: f32, %in_1: f32, %in_2: f32, %out: f32):
%5 = arith.addf %in, %in_1 : f32
%6 = arith.addf %5, %in_2 : f32
linalg.yield %6 : f32
} -> tensor<512x512x512xf32>
return %3 : tensor<512x512x512xf32>
}
}
The dispatch is repeated 50 times with `--iree-hal-benchmark-dispatch-repeat-count=50`:
The results match what we saw in the e2e benchmark breakdown. Note that when only benchmarking the generic op, by default … That made us suspect there might be issues in distribution tiling, so we benchmarked with …:
Conclusion

In T5LargeTF and BertLargeTF we identified that there might be performance issues with distribution-level tiling, and that in some cases the pack/unpack ops have no generic ops to fuse with to hide their overhead. Also, in the baseline some batch_matmul ops are already running quite well. So the action items could be:
Hi @MaheshRavishankar @bjacob, I posted an analysis of batch_matmul data-tiling on e2e models: #14431 (comment). I'd be happy to hear any comments 🙂
Great analysis! A lot to unpack there, which I am not going to be able to do right now, but I will come back to this in a couple of weeks.
Thanks for the analysis. I find it striking that … For example, let's look at the final table in the above comment. For … Already this shows that data-tiling + ukernel is having the intended effect of enabling a much faster matmul kernel (0.15) vs the baseline (0.22) --- the matmul kernel part is being sped up by ~1.5x here. The question becomes: how do we ensure that the pack/unpack overhead, and other ops in the workload such as that generic op being fused on the unpack side, don't ruin this 1.5x speedup? These numbers … So my guess is there's something really wrong going on in the codegen for unpack; if you disassemble it, something obvious should show. Next, the … All in all, these measurements show that data-tiling + ukernel is having the intended effect of speeding up the matmul kernel, but work is needed on the rest, particularly the unpack, to reap the e2e benefits. Note: we have a ukernel for unpack; we just don't use it outside of VMVX at the moment. We could always give it a try here. We want to avoid it because it will prevent fusions, and unpack is a place where we think fusions are going to be important. But if the unpack codegen is this bad, that matters even more than fusions. It depends how much you want to unblock this in the short term. Long term, we really should improve codegen so we can keep the fusions and not need an unpack ukernel.
A couple of high-level points after a skim of the analysis from CheYu:
Really great analysis! I think something is wrong in the pack/unpack codegen. The number of elements to be packed is the same in the example, but one takes 0.63 sec and the other takes 2.07 sec.
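If the two packs in question are the LHS and RHS packs of the 512x512x512x64 batch_matmul (an assumption on my part, since both pack 512*512*64 = 16777216 elements), their layouts would look roughly like this with 16x16x1 inner tiles; the shapes below are hand-derived from the mmt4d convention, not copied from the dump:

```mlir
func.func @pack_lhs(%lhs: tensor<512x512x64xf32>, %dest: tensor<512x32x64x16x1xf32>) -> tensor<512x32x64x16x1xf32> {
  // B x M x K -> B x M1 x K1 x m0 x k0 with m0 = 16, k0 = 1.
  %0 = tensor.pack %lhs inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %dest
      : tensor<512x512x64xf32> -> tensor<512x32x64x16x1xf32>
  return %0 : tensor<512x32x64x16x1xf32>
}
func.func @pack_rhs(%rhs: tensor<512x64x512xf32>, %dest: tensor<512x32x64x16x1xf32>) -> tensor<512x32x64x16x1xf32> {
  // B x K x N -> B x N1 x K1 x n0 x k0 with n0 = 16, k0 = 1; note the extra outer transpose.
  %0 = tensor.pack %rhs outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 1] into %dest
      : tensor<512x64x512xf32> -> tensor<512x32x64x16x1xf32>
  return %0 : tensor<512x32x64x16x1xf32>
}
```

The element counts match, but the memory access patterns differ, which might be where the cost difference comes from.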
Closing the issue because the functionality is complete.
batch_matmul is just a matmul plus a batch dimension. We have all the utilities for encodings and microkernels, and we can reuse them to enable data tiling + microkernels for batch_matmul. For the data-tiling part, we'll need to set/unset encodings. For the microkernel part, we can tile the batch dimension with size 1; it then becomes an mmt4d op, which we can lower to microkernel ops.
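A sketch of that tiling idea (illustrative shapes and names, assuming 16x16x1 inner tiles for the 512x512x512x64 case): once the batch dimension of the packed batch_mmt4d is tiled with size 1, each per-batch slice is a plain mmt4d, which already has a microkernel lowering.

```mlir
func.func @tile_batch_to_mmt4d(%lhs: tensor<512x32x64x16x1xf32>,
                               %rhs: tensor<512x32x64x16x1xf32>,
                               %acc: tensor<512x32x32x16x16xf32>) -> tensor<512x32x32x16x16xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c512 = arith.constant 512 : index
  // Tile the batch dimension with size 1; the rank-reduced slices feed a plain mmt4d.
  %result = scf.for %b = %c0 to %c512 step %c1 iter_args(%out = %acc) -> (tensor<512x32x32x16x16xf32>) {
    %lhs_b = tensor.extract_slice %lhs[%b, 0, 0, 0, 0] [1, 32, 64, 16, 1] [1, 1, 1, 1, 1]
        : tensor<512x32x64x16x1xf32> to tensor<32x64x16x1xf32>
    %rhs_b = tensor.extract_slice %rhs[%b, 0, 0, 0, 0] [1, 32, 64, 16, 1] [1, 1, 1, 1, 1]
        : tensor<512x32x64x16x1xf32> to tensor<32x64x16x1xf32>
    %acc_b = tensor.extract_slice %out[%b, 0, 0, 0, 0] [1, 32, 32, 16, 16] [1, 1, 1, 1, 1]
        : tensor<512x32x32x16x16xf32> to tensor<32x32x16x16xf32>
    %mm = linalg.mmt4d ins(%lhs_b, %rhs_b : tensor<32x64x16x1xf32>, tensor<32x64x16x1xf32>)
        outs(%acc_b : tensor<32x32x16x16xf32>) -> tensor<32x32x16x16xf32>
    %ins = tensor.insert_slice %mm into %out[%b, 0, 0, 0, 0] [1, 32, 32, 16, 16] [1, 1, 1, 1, 1]
        : tensor<32x32x16x16xf32> into tensor<512x32x32x16x16xf32>
    scf.yield %ins : tensor<512x32x32x16x16xf32>
  }
  return %result : tensor<512x32x32x16x16xf32>
}
```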