
Support data tiling + microkernels for batch_matmul #14431

Closed
hanhanW opened this issue Jul 18, 2023 · 24 comments
Labels: codegen/llvm (LLVM code generation compiler backend), codegen (Shared code generation infrastructure and dialects)


hanhanW commented Jul 18, 2023

batch_matmul is just a matmul plus a batch dimension. We already have all the utilities for encodings and microkernels, and we can reuse them to enable data tiling + microkernels for batch_matmul. For the data-tiling part, we'll need to set/unset encodings. For the microkernel part, we can tile the batch dimension with size 1; the op then becomes an mmt4d op, which we can lower to microkernel ops.
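
A rough sketch of the flow, assuming an f32 batch_matmul of shape 512x512x64 times 512x64x512 and illustrative inner tile sizes M0 = N0 = 16, K0 = 1 (the actual tile sizes are chosen by the materialization pass per target):

    // Data tiling: after set_encoding + materialization, the packed operands
    // feed a batch_mmt4d op (LHS: B x M/M0 x K/K0 x M0 x K0,
    // RHS: B x N/N0 x K/K0 x N0 x K0, ACC: B x M/M0 x N/N0 x M0 x N0).
    %bmm = linalg.batch_mmt4d
             ins(%lhs_packed, %rhs_packed
                   : tensor<512x32x64x16x1xf32>, tensor<512x32x64x16x1xf32>)
             outs(%acc_packed : tensor<512x32x32x16x16xf32>)
             -> tensor<512x32x32x16x16xf32>

    // Microkernels: tiling the batch dimension to 1 (and folding the unit dim)
    // leaves a plain mmt4d per iteration, which lowers to the existing ukernel.
    %mm = linalg.mmt4d
            ins(%lhs_slice, %rhs_slice : tensor<32x64x16x1xf32>, tensor<32x64x16x1xf32>)
            outs(%acc_slice : tensor<32x32x16x16xf32>)
            -> tensor<32x32x16x16xf32>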

### Tasks
- [x] Set encodings on batch_matmul in Flow::SetEncodingPass
- [x] Add materialization configs to CPUMaterializationPass and Materialize batch_matmul to batch_mmt4d
- [x] Add a batch_mmt4d named op to LinAlg
- [x] Create a TileBatchMmt4dToMmt4d pass which tiles and rewrites a batch_mmt4d op to mmt4d_op

hanhanW commented Jul 28, 2023

Assigning to @pzread, who is looking into this.

hanhanW added the codegen/llvm (LLVM code generation compiler backend) label on Jul 28, 2023

pzread commented Aug 11, 2023

I posted an initial result of lowering batch_matmul -> batch_mmt4d -> mmt4d here: #14542 (comment)

Similar to the current tiling configuration of batch_matmul, the batch dim of batch_mmt4d is tiled to 1 at the distribution level.

Currently we see slight regressions on T5 and BertLarge, but I'll need to dig into the root causes.


pzread commented Aug 17, 2023

Profiled T5LargeTFBatch32 with single threading (local-sync) on #14542.

On the slowest batch_matmul, the original batch_matmul + generic op takes 17s while the data-tiled + ukernel batch_matmul takes 12s, but the following unpack + generic op takes a lot of time (13s). Given that the pack ops are reasonably fast (LHS 1.22s, RHS 1.67s), there might be some codegen issues, especially in unpack.


Here is the corresponding input IR segment in T5LargeTFBatch32:

    #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
    %51 = tensor.empty() : tensor<512x512x512xf32>
    %52 = linalg.fill ins(%cst_10 : f32) outs(%51 : tensor<512x512x512xf32>) -> tensor<512x512x512xf32>
    %53 = linalg.batch_matmul ins(%collapsed_17, %collapsed_18 : tensor<512x512x64xf32>, tensor<512x64x512xf32>) outs(%52 : tensor<512x512x512xf32>) -> tensor<512x512x512xf32>
    %expanded_19 = tensor.expand_shape %53 [[0, 1], [2], [3]] : tensor<512x512x512xf32> into tensor<32x16x512x512xf32>
    %54 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_19, %41 : tensor<32x16x512x512xf32>, tensor<32x16x512x512xf32>) outs(%12 : tensor<32x16x512x512xf32>) {
    ^bb0(%in: f32 loc("t5large32_stripped_linalg.mlir":3767:10), %in_1336: f32 loc("t5large32_stripped_linalg.mlir":3767:20), %out: f32 loc("t5large32_stripped_linalg.mlir":3767:35)):
      %2823 = arith.addf %in, %in_1336 : f32
      linalg.yield %2823 : f32
    } -> tensor<32x16x512x512xf32>
    %55 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%54, %cst_6 : tensor<32x16x512x512xf32>, tensor<32x16x512x512xf32>) outs(%12 : tensor<32x16x512x512xf32>) {
    ^bb0(%in: f32 loc("t5large32_stripped_linalg.mlir":3772:10), %in_1336: f32 loc("t5large32_stripped_linalg.mlir":3772:20), %out: f32 loc("t5large32_stripped_linalg.mlir":3772:35)):
      %2823 = arith.addf %in, %in_1336 : f32
      linalg.yield %2823 : f32
    } -> tensor<32x16x512x512xf32>


hanhanW commented Aug 17, 2023

Nice progress! Could you dump the unpack + generic dispatch? We probably want to microbenchmark unpack, generic, and unpack + generic separately. Also, we want to look at the pack codegen dumps/asm, so we can make sure that the 16x16 transpose trick kicks in.


pzread commented Aug 17, 2023

Here is the dispatch of that unpack + generic:

#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "cascadelake", cpu_features = "+cmov,+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vnni,+adx,+clflushopt,+clwb,+cx16,+cx8,+crc32,+f16c,+fsgsbase,+fxsr,+invpcid,+lzcnt,+movbe,+pku,+prfchw,+rdrnd,+rdseed,+sahf,+x87,+xsave,+xsavec,+xsaveopt,+xsaves", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 64 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = true}>
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#pipeline_layout = #hal.pipeline.layout<push_constants = 5, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>
#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}>
module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
  hal.executable public @forward_dispatch_15 {
    hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ {
      hal.executable.export public @forward_dispatch_15_generic_512x512x512_f32 ordinal(0) layout(#pipeline_layout) {
      ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
        %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2
        hal.return %x, %y, %z : index, index, index
      }
      builtin.module {
        func.func @forward_dispatch_15_generic_512x512x512_f32() {
          %cst = arith.constant 9.99999971E-10 : f32
          %0 = hal.interface.constant.load[0] : i32
          %1 = hal.interface.constant.load[1] : i32
          %2 = hal.interface.constant.load[2] : i32
          %3 = hal.interface.constant.load[3] : i32
          %4 = hal.interface.constant.load[4] : i32
          %5 = arith.index_castui %0 : i32 to index
          %6 = arith.index_castui %1 {stream.alignment = 268435456 : index, stream.values = [0 : index, 268435456 : index]} : i32 to index
          %7 = arith.index_castui %2 {stream.alignment = 67108864 : index, stream.values = [0 : index, 67108864 : index, 134217728 : index, 805306368 : index, 1543503872 : index, 2080374784 : index]} : i32 to index
          %8 = arith.index_castui %3 : i32 to index
          %9 = arith.index_castui %4 : i32 to index
          %10 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%6) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>>
          %11 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%7) : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
          %12 = flow.dispatch.workload.ordinal %8, 0 : index
          %13 = flow.dispatch.workload.ordinal %9, 1 : index
          %14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%5) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>>>{%12, %13}
          %15 = flow.dispatch.tensor.load %14, offsets = [0, 0, 0], sizes = [512, %12, %13], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>>>{%12, %13} -> tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>>
          %16 = flow.dispatch.tensor.load %10, offsets = [0, 0, 0], sizes = [512, 512, 512], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>> -> tensor<512x512x512xf32>
          %17 = tensor.empty() : tensor<512x512x512xf32>
          %18 = iree_linalg_ext.unset_encoding %15 : tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>> -> tensor<512x?x?xf32>
          %extracted_slice = tensor.extract_slice %18[0, 0, 0] [512, 512, 512] [1, 1, 1] : tensor<512x?x?xf32> to tensor<512x512x512xf32>
          %19 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice, %16 : tensor<512x512x512xf32>, tensor<512x512x512xf32>) outs(%17 : tensor<512x512x512xf32>) {
          ^bb0(%in: f32, %in_0: f32, %out: f32):
            %20 = arith.addf %in, %in_0 : f32
            %21 = arith.addf %20, %cst : f32
            linalg.yield %21 : f32
          } -> tensor<512x512x512xf32>
          flow.dispatch.tensor.store %19, %11, offsets = [0, 0, 0], sizes = [512, 512, 512], strides = [1, 1, 1] : tensor<512x512x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
          return
        }
      }
    }
  }
}

One issue I just found is that the leading unit dims (1x16x16xf32) before and after the linalg.transpose (which has no leading unit dims) cause an extra pair of transfer write/read and an alloca. It can be cleaned up with OptimizeVectorTransfer, but the run time didn't really change. The snippet below shows some key steps in the codegen:

// -----// IR Dump After DecomposePackUnPackOps (iree-codegen-decompose-pack-unpack-ops) //----- //
func.func @forward_dispatch_15_generic_512x512x512_f32() {
  %c64 = arith.constant 64 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %c512 = arith.constant 512 : index
  %cst = arith.constant 9.99999971E-10 : f32
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = hal.interface.constant.load[2] : i32
  %3 = arith.index_castui %0 : i32 to index
  %4 = arith.index_castui %1 {stream.alignment = 268435456 : index, stream.values = [0 : index, 268435456 : index]} : i32 to index
  %5 = arith.index_castui %2 {stream.alignment = 67108864 : index, stream.values = [0 : index, 67108864 : index, 134217728 : index, 805306368 : index, 1543503872 : index, 2080374784 : index]} : i32 to index
  %6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>>
  %7 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %workgroup_id_z = hal.interface.workgroup.id[2] : index
  %workgroup_count_z = hal.interface.workgroup.count[2] : index
  %9 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_z]
  %10 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_z]
  scf.for %arg0 = %9 to %c512 step %10 {
    %11 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
    %12 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
    scf.for %arg1 = %11 to %c512 step %12 {
      %13 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
      %14 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
      scf.for %arg2 = %13 to %c512 step %14 {
        %15 = flow.dispatch.tensor.load %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
        %16 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg1)
        %17 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
        %18 = flow.dispatch.tensor.load %8, offsets = [%arg0, %16, %17, 0, 0], sizes = [64, 4, 4, 16, 16], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>> -> tensor<64x4x4x16x16xf32>
        %19 = flow.dispatch.tensor.load %6, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
        %20 = scf.for %arg3 = %c0 to %c64 step %c1 iter_args(%arg4 = %15) -> (tensor<64x64x64xf32>) {
          %21 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %arg4) -> (tensor<64x64x64xf32>) {
            %22 = scf.for %arg7 = %c0 to %c64 step %c16 iter_args(%arg8 = %arg6) -> (tensor<64x64x64xf32>) {
              %extracted_slice = tensor.extract_slice %19[%arg3, %arg5, %arg7] [1, 16, 16] [1, 1, 1] : tensor<64x64x64xf32> to tensor<1x16x16xf32>
              %23 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg5)
              %24 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg7)
              %extracted_slice_0 = tensor.extract_slice %18[%arg3, %23, %24, 0, 0] [1, 1, 1, 16, 16] [1, 1, 1, 1, 1] : tensor<64x4x4x16x16xf32> to tensor<1x1x1x16x16xf32>
              %extracted_slice_1 = tensor.extract_slice %arg8[%arg3, %arg5, %arg7] [1, 16, 16] [1, 1, 1] : tensor<64x64x64xf32> to tensor<1x16x16xf32>
              %extracted_slice_2 = tensor.extract_slice %extracted_slice_0[0, 0, 0, 0, 0] [1, 1, 1, 16, 16] [1, 1, 1, 1, 1] : tensor<1x1x1x16x16xf32> to tensor<16x16xf32>
              %25 = tensor.empty() : tensor<16x16xf32>
              %transposed = linalg.transpose ins(%extracted_slice_2 : tensor<16x16xf32>) outs(%25 : tensor<16x16xf32>) permutation = [0, 1] 
              %inserted_slice = tensor.insert_slice %transposed into %extracted_slice_1[0, 0, 0] [1, 16, 16] [1, 1, 1] : tensor<16x16xf32> into tensor<1x16x16xf32>
              %26 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice : tensor<1x16x16xf32>) outs(%inserted_slice : tensor<1x16x16xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 64], [1, 16, 16], [0, 0, 0], [0, 0, 0]]>} {
              ^bb0(%in: f32, %out: f32):
                %27 = arith.addf %out, %in : f32
                %28 = arith.addf %27, %cst : f32
                linalg.yield %28 : f32
              } -> tensor<1x16x16xf32>
              %inserted_slice_3 = tensor.insert_slice %26 into %arg8[%arg3, %arg5, %arg7] [1, 16, 16] [1, 1, 1] : tensor<1x16x16xf32> into tensor<64x64x64xf32>
              scf.yield %inserted_slice_3 : tensor<64x64x64xf32>
            }
            scf.yield %22 : tensor<64x64x64xf32>
          }
          scf.yield %21 : tensor<64x64x64xf32>
        }
        flow.dispatch.tensor.store %20, %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : tensor<64x64x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
      }
    }
  }
  return
}

// -----// IR Dump After GenericVectorization (iree-codegen-generic-vectorization) //----- //
func.func @forward_dispatch_15_generic_512x512x512_f32() {
  %cst = arith.constant dense<9.99999971E-10> : vector<1x16x16xf32>
  %cst_0 = arith.constant 0.000000e+00 : f32
  %c64 = arith.constant 64 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %c512 = arith.constant 512 : index
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = hal.interface.constant.load[2] : i32
  %3 = arith.index_castui %0 : i32 to index
  %4 = arith.index_castui %1 {stream.alignment = 268435456 : index, stream.values = [0 : index, 268435456 : index]} : i32 to index
  %5 = arith.index_castui %2 {stream.alignment = 67108864 : index, stream.values = [0 : index, 67108864 : index, 134217728 : index, 805306368 : index, 1543503872 : index, 2080374784 : index]} : i32 to index
  %6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>>
  %7 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %workgroup_id_z = hal.interface.workgroup.id[2] : index
  %workgroup_count_z = hal.interface.workgroup.count[2] : index
  %9 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_z]
  %10 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_z]
  scf.for %arg0 = %9 to %c512 step %10 {
    %11 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
    %12 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
    scf.for %arg1 = %11 to %c512 step %12 {
      %13 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
      %14 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
      scf.for %arg2 = %13 to %c512 step %14 {
        %15 = flow.dispatch.tensor.load %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
        %16 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg1)
        %17 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
        %18 = flow.dispatch.tensor.load %8, offsets = [%arg0, %16, %17, 0, 0], sizes = [64, 4, 4, 16, 16], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>> -> tensor<64x4x4x16x16xf32>
        %19 = flow.dispatch.tensor.load %6, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
        %20 = scf.for %arg3 = %c0 to %c64 step %c1 iter_args(%arg4 = %15) -> (tensor<64x64x64xf32>) {
          %21 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %arg4) -> (tensor<64x64x64xf32>) {
            %22 = scf.for %arg7 = %c0 to %c64 step %c16 iter_args(%arg8 = %arg6) -> (tensor<64x64x64xf32>) {
              %23 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg5)
              %24 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg7)
              %extracted_slice = tensor.extract_slice %arg8[%arg3, %arg5, %arg7] [1, 16, 16] [1, 1, 1] : tensor<64x64x64xf32> to tensor<1x16x16xf32>
              %25 = vector.transfer_read %18[%arg3, %23, %24, %c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<64x4x4x16x16xf32>, vector<16x16xf32>
              %26 = vector.transfer_write %25, %extracted_slice[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<1x16x16xf32>
              %27 = vector.transfer_read %19[%arg3, %arg5, %arg7], %cst_0 {in_bounds = [true, true, true]} : tensor<64x64x64xf32>, vector<1x16x16xf32>
              %28 = vector.transfer_read %26[%c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true]} : tensor<1x16x16xf32>, vector<1x16x16xf32>
              %29 = arith.addf %28, %27 : vector<1x16x16xf32>
              %30 = arith.addf %29, %cst : vector<1x16x16xf32>
              %31 = vector.transfer_write %30, %arg8[%arg3, %arg5, %arg7] {in_bounds = [true, true, true]} : vector<1x16x16xf32>, tensor<64x64x64xf32>
              scf.yield %31 : tensor<64x64x64xf32>
            }
            scf.yield %22 : tensor<64x64x64xf32>
          }
          scf.yield %21 : tensor<64x64x64xf32>
        }
        flow.dispatch.tensor.store %20, %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : tensor<64x64x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
      }
    }
  }
  return
}

// -----// IR Dump After OptimizeVectorTransfer (iree-codegen-optimize-vector-transfer) //----- //
func.func @forward_dispatch_15_generic_512x512x512_f32() {
  %cst = arith.constant dense<9.99999971E-10> : vector<16x16xf32>
  %cst_0 = arith.constant 0.000000e+00 : f32
  %c64 = arith.constant 64 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %c512 = arith.constant 512 : index
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = hal.interface.constant.load[2] : i32
  %3 = arith.index_castui %0 : i32 to index
  %4 = arith.index_castui %1 {stream.alignment = 268435456 : index, stream.values = [0 : index, 268435456 : index]} : i32 to index
  %5 = arith.index_castui %2 {stream.alignment = 67108864 : index, stream.values = [0 : index, 67108864 : index, 134217728 : index, 805306368 : index, 1543503872 : index, 2080374784 : index]} : i32 to index
  %6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>>
  %7 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %workgroup_id_z = hal.interface.workgroup.id[2] : index
  %workgroup_count_z = hal.interface.workgroup.count[2] : index
  %9 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_z]
  %10 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_z]
  %11 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
  %12 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
  %13 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
  %14 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
  scf.for %arg0 = %9 to %c512 step %10 {
    scf.for %arg1 = %11 to %c512 step %12 {
      %15 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg1)
      scf.for %arg2 = %13 to %c512 step %14 {
        %16 = flow.dispatch.tensor.load %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
        %17 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg2)
        %18 = flow.dispatch.tensor.load %8, offsets = [%arg0, %15, %17, 0, 0], sizes = [64, 4, 4, 16, 16], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x32x32x16x16xf32>> -> tensor<64x4x4x16x16xf32>
        %19 = flow.dispatch.tensor.load %6, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<512x512x512xf32>> -> tensor<64x64x64xf32>
        %20 = scf.for %arg3 = %c0 to %c64 step %c1 iter_args(%arg4 = %16) -> (tensor<64x64x64xf32>) {
          %21 = scf.for %arg5 = %c0 to %c64 step %c16 iter_args(%arg6 = %arg4) -> (tensor<64x64x64xf32>) {
            %22 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg5)
            %23 = scf.for %arg7 = %c0 to %c64 step %c16 iter_args(%arg8 = %arg6) -> (tensor<64x64x64xf32>) {
              %24 = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%arg7)
              %25 = vector.transfer_read %18[%arg3, %22, %24, %c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<64x4x4x16x16xf32>, vector<16x16xf32>
              %26 = vector.transfer_read %19[%arg3, %arg5, %arg7], %cst_0 {in_bounds = [true, true]} : tensor<64x64x64xf32>, vector<16x16xf32>
              %27 = arith.addf %25, %26 : vector<16x16xf32>
              %28 = arith.addf %27, %cst : vector<16x16xf32>
              %29 = vector.transfer_write %28, %arg8[%arg3, %arg5, %arg7] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<64x64x64xf32>
              scf.yield %29 : tensor<64x64x64xf32>
            }
            scf.yield %23 : tensor<64x64x64xf32>
          }
          scf.yield %21 : tensor<64x64x64xf32>
        }
        flow.dispatch.tensor.store %20, %7, offsets = [%arg0, %arg1, %arg2], sizes = [64, 64, 64], strides = [1, 1, 1] : tensor<64x64x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<512x512x512xf32>>
      }
    }
  }
  return
}


pzread commented Aug 23, 2023

Here are the steps to reproduce microbenchmarks on the slowest unpack + generic of a batch_matmul in the T5LargeBatch32 model.

Save the sampled MLIR below (verified that it is compiled in the same way as in the model):

// unpack_generic.mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  func.func @unpack_generic(%arg0: tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>>, %arg1: tensor<512x512x512xf32>) -> tensor<512x512x512xf32> {
    %cst = arith.constant 9.99999971E-10 : f32
    %0 = tensor.empty() : tensor<512x512x512xf32>
    %1 = iree_linalg_ext.unset_encoding %arg0 : tensor<512x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT, original_type = tensor<512x512x512xf32>>> -> tensor<512x?x?xf32>
    %extracted_slice = tensor.extract_slice %1[0, 0, 0] [512, 512, 512] [1, 1, 1] : tensor<512x?x?xf32> to tensor<512x512x512xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice, %arg1 : tensor<512x512x512xf32>, tensor<512x512x512xf32>) outs(%0 : tensor<512x512x512xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %3 = arith.addf %in, %in_0 : f32
      %4 = arith.addf %3, %cst : f32
      linalg.yield %4 : f32
    } -> tensor<512x512x512xf32>
    return %2 : tensor<512x512x512xf32>
  }
}

Check out #14542 and run (note that --iree-hal-benchmark-dispatch-repeat-count=48 is used here to mitigate runtime overhead, so the dispatch will be launched the same number of times as in the T5LargeBatch32 model):

iree-compile \
  unpack_generic.mlir \
  -o 'unpack_generic.vmfb' \
  --iree-hal-target-backends=llvm-cpu \
  --iree-input-type=none \
  --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu \
  --iree-llvmcpu-target-cpu=cascadelake \
  --iree-flow-enable-data-tiling \
  --iree-llvmcpu-enable-microkernels \
  --iree-hal-benchmark-dispatch-repeat-count=48

And benchmark with:

iree-benchmark-module \
  --function=unpack_generic \
  --module=unpack_generic.vmfb \
  --input=512x512x512xf32=0 \
  --input=512x512x512xf32=0 \
  --device_allocator=caching \
  --device=local-sync \
  --benchmark_repetitions=5 \
  --benchmark_min_warmup_time=3.0

Here is a sample benchmark output (from c2-standard-60):

----------------------------------------------------------------------------------------------------------
Benchmark                                                Time             CPU   Iterations UserCounters...
----------------------------------------------------------------------------------------------------------
BM_unpack_generic/process_time/real_time             11794 ms        11792 ms            1 items_per_second=0.0847894/s
BM_unpack_generic/process_time/real_time             11824 ms        11822 ms            1 items_per_second=0.0845757/s
BM_unpack_generic/process_time/real_time             11825 ms        11823 ms            1 items_per_second=0.0845683/s
BM_unpack_generic/process_time/real_time             11913 ms        11911 ms            1 items_per_second=0.0839425/s
BM_unpack_generic/process_time/real_time             11817 ms        11815 ms            1 items_per_second=0.0846255/s
BM_unpack_generic/process_time/real_time_mean        11834 ms        11832 ms            5 items_per_second=0.0845003/s
BM_unpack_generic/process_time/real_time_median      11824 ms        11822 ms            5 items_per_second=0.0845757/s
BM_unpack_generic/process_time/real_time_stddev       45.6 ms         45.6 ms            5 items_per_second=324.294u/s
BM_unpack_generic/process_time/real_time_cv           0.39 %          0.39 %             5 items_per_second=0.38%

The full mlir dump after each pass can be found at: https://gist.github.com/pzread/088bcaf67302d814c329ff0b58ef8d9c#file-unpack_generic-dump-mlir

The full ASM dump can be found at:
https://gist.github.com/pzread/c37837756216df9dac9492e5ee4b362c#file-unpack_generic-dump-s


hanhanW commented Aug 23, 2023

@bjacob @MaheshRavishankar FYI.

The unpack (and fusion) performance looks bad. We might need some input from Benoit about what the final code should look like. I will study it this week and circle back to the issue.


hanhanW commented Aug 23, 2023

I have an improved microbenchmark file, adapted from what @pzread has.

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  util.global private @__var1 {noinline} = dense<1.000000e+00> : tensor<512x32x32x16x16xf32>
  util.global private @__var2 {noinline} = dense<1.000000e+00> : tensor<512x512x512xf32>
  util.global private @__var3 {noinline} = dense<2.000000e+00> : tensor<512x512x512xf32>
  func.func @unpack_generic() -> tensor<512x512x512xf32> {
    %cst = arith.constant 9.99999971E-10 : f32
    %0 = tensor.empty() : tensor<512x512x512xf32>
    %ptr___var1 = util.global.address @__var1 : !util.ptr<tensor<512x32x32x16x16xf32>>
    %1 = util.global.load.indirect %ptr___var1 : !util.ptr<tensor<512x32x32x16x16xf32>> -> tensor<512x32x32x16x16xf32>
    %ptr___var2 = util.global.address @__var2 : !util.ptr<tensor<512x512x512xf32>>
    %2 = util.global.load.indirect %ptr___var2 : !util.ptr<tensor<512x512x512xf32>> -> tensor<512x512x512xf32>
    %unpack = tensor.unpack %1 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %0 : tensor<512x32x32x16x16xf32> -> tensor<512x512x512xf32>
    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%unpack, %2 : tensor<512x512x512xf32>, tensor<512x512x512xf32>) outs(%0 : tensor<512x512x512xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %4 = arith.addf %in, %in_0 : f32
      %5 = arith.addf %4, %cst : f32
      linalg.yield %5 : f32
    } -> tensor<512x512x512xf32>
    return %3 : tensor<512x512x512xf32>
  }
  func.func @unpack() -> tensor<512x512x512xf32> {
    %0 = tensor.empty() : tensor<512x512x512xf32>
    %ptr___var1 = util.global.address @__var1 : !util.ptr<tensor<512x32x32x16x16xf32>>
    %1 = util.global.load.indirect %ptr___var1 : !util.ptr<tensor<512x32x32x16x16xf32>> -> tensor<512x32x32x16x16xf32>
    %ptr___var2 = util.global.address @__var2 : !util.ptr<tensor<512x512x512xf32>>
    %2 = util.global.load.indirect %ptr___var2 : !util.ptr<tensor<512x512x512xf32>> -> tensor<512x512x512xf32>
    %unpack = tensor.unpack %1 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %0 : tensor<512x32x32x16x16xf32> -> tensor<512x512x512xf32>
    return %unpack : tensor<512x512x512xf32>
  }
  func.func @generic() -> tensor<512x512x512xf32> {
    %cst = arith.constant 9.99999971E-10 : f32
    %0 = tensor.empty() : tensor<512x512x512xf32>
    %ptr___var2 = util.global.address @__var2 : !util.ptr<tensor<512x512x512xf32>>
    %1 = util.global.load.indirect %ptr___var2 : !util.ptr<tensor<512x512x512xf32>> -> tensor<512x512x512xf32>
    %ptr___var3 = util.global.address @__var3 : !util.ptr<tensor<512x512x512xf32>>
    %2 = util.global.load.indirect %ptr___var3 : !util.ptr<tensor<512x512x512xf32>> -> tensor<512x512x512xf32>
    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %2 : tensor<512x512x512xf32>, tensor<512x512x512xf32>) outs(%0 : tensor<512x512x512xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %4 = arith.addf %in, %in_0 : f32
      %5 = arith.addf %4, %cst : f32
      linalg.yield %5 : f32
    } -> tensor<512x512x512xf32>
    return %3 : tensor<512x512x512xf32>
  }
}

To benchmark, run

iree-compile \
  --output-format=vm-bytecode \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-target-triple=x86_64-pc-linux-gnu \
  --iree-llvmcpu-target-cpu=cascadelake \
  ~/repro.mlir -o /tmp/z.vmfb
iree-benchmark-module \
  --module=/tmp/z.vmfb \
  --device=local-sync \
  --device_allocator=caching \
  --benchmark_repetitions=5 \
  --benchmark_min_warmup_time=3.0

What I got on my local workstation (which has scaled CPU frequency compared to the cloud machine):

----------------------------------------------------------------------------------------------------------
Benchmark                                                Time             CPU   Iterations UserCounters...
----------------------------------------------------------------------------------------------------------
BM_unpack_generic/process_time/real_time               212 ms          211 ms            3 items_per_second=4.72781/s
BM_unpack_generic/process_time/real_time               212 ms          212 ms            3 items_per_second=4.72668/s
BM_unpack_generic/process_time/real_time               211 ms          211 ms            3 items_per_second=4.72909/s
BM_unpack_generic/process_time/real_time               211 ms          211 ms            3 items_per_second=4.73046/s
BM_unpack_generic/process_time/real_time               211 ms          211 ms            3 items_per_second=4.7332/s
BM_unpack_generic/process_time/real_time_mean          211 ms          211 ms            5 items_per_second=4.72945/s
BM_unpack_generic/process_time/real_time_median        211 ms          211 ms            5 items_per_second=4.72909/s
BM_unpack_generic/process_time/real_time_stddev      0.113 ms        0.109 ms            5 items_per_second=2.52669m/s
BM_unpack_generic/process_time/real_time_cv           0.05 %          0.05 %             5 items_per_second=0.05%
BM_unpack/process_time/real_time                       150 ms          150 ms            5 items_per_second=6.65321/s
BM_unpack/process_time/real_time                       150 ms          150 ms            5 items_per_second=6.65747/s
BM_unpack/process_time/real_time                       150 ms          150 ms            5 items_per_second=6.64605/s
BM_unpack/process_time/real_time                       150 ms          150 ms            5 items_per_second=6.65586/s
BM_unpack/process_time/real_time                       150 ms          150 ms            5 items_per_second=6.66738/s
BM_unpack/process_time/real_time_mean                  150 ms          150 ms            5 items_per_second=6.65599/s
BM_unpack/process_time/real_time_median                150 ms          150 ms            5 items_per_second=6.65586/s
BM_unpack/process_time/real_time_stddev              0.174 ms        0.171 ms            5 items_per_second=7.72399m/s
BM_unpack/process_time/real_time_cv                   0.12 %          0.11 %             5 items_per_second=0.12%
BM_generic/process_time/real_time                      132 ms          132 ms            5 items_per_second=7.57299/s
BM_generic/process_time/real_time                      132 ms          132 ms            5 items_per_second=7.57392/s
BM_generic/process_time/real_time                      132 ms          132 ms            5 items_per_second=7.58005/s
BM_generic/process_time/real_time                      132 ms          132 ms            5 items_per_second=7.58086/s
BM_generic/process_time/real_time                      132 ms          132 ms            5 items_per_second=7.57286/s
BM_generic/process_time/real_time_mean                 132 ms          132 ms            5 items_per_second=7.57614/s
BM_generic/process_time/real_time_median               132 ms          132 ms            5 items_per_second=7.57392/s
BM_generic/process_time/real_time_stddev             0.069 ms        0.075 ms            5 items_per_second=3.97283m/s
BM_generic/process_time/real_time_cv                  0.05 %          0.06 %             5 items_per_second=0.05%

The fusion case is not as good as the standalone generic op, which matches the result that @pzread shared with me. It is because we flatten dimensions of Linalg ops in the createCollapseDimensionsPass, which gives CodeGen a different dispatch input. If we disable the pass, it runs at 260 ms; in that case the unpack overhead is hidden and the fusion performance looks fine. The next question is whether we can do better in the fusion case, i.e., can we make it as good as the linearized input?
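
For reference, a minimal hand-written sketch (not the exact pass output) of what the flattened element-wise dispatch roughly looks like; when the unpack is fused in, this collapse cannot happen and codegen sees the 3-D shapes instead:

    // Collapse the contiguous parallel dims of the element-wise op into 1-D
    // (512 * 512 * 512 = 134217728) ...
    %a = tensor.collapse_shape %lhs [[0, 1, 2]] : tensor<512x512x512xf32> into tensor<134217728xf32>
    %b = tensor.collapse_shape %rhs [[0, 1, 2]] : tensor<512x512x512xf32> into tensor<134217728xf32>
    %init = tensor.empty() : tensor<134217728xf32>
    // ... run the add on the flat tensors (simplified to a single addf) ...
    %flat = linalg.generic
        {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
         iterator_types = ["parallel"]}
        ins(%a, %b : tensor<134217728xf32>, tensor<134217728xf32>)
        outs(%init : tensor<134217728xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %sum = arith.addf %in, %in_0 : f32
      linalg.yield %sum : f32
    } -> tensor<134217728xf32>
    // ... and expand back to the original shape at the dispatch boundary.
    %res = tensor.expand_shape %flat [[0, 1, 2]] : tensor<134217728xf32> into tensor<512x512x512xf32>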

MaheshRavishankar commented:

> The fusion case is not as good as the standalone generic op, which matches the result that @pzread shared with me. It is because we flatten dimensions of Linalg ops in the createCollapseDimensionsPass, which gives CodeGen a different dispatch input. If we disable the pass, it runs at 260 ms; in that case the unpack overhead is hidden and the fusion performance looks fine. The next question is whether we can do better in the fusion case, i.e., can we make it as good as the linearized input?

Nice find. Incidentally, I am working on improving and generalizing the createCollapseDimensionsPass. If you post the input before and after this pass and some explanation of what you don't want to happen, I can fix it as I go along.

MaheshRavishankar commented:


I misread what you were saying. So linearization helps performance... so the question is whether you can keep the linearization and still generate good code for unpack + linalg.generic... That seems a bit hard, but looking at the IR might help.


hanhanW commented Aug 23, 2023

I think the distribution messes with the cache. If we disable distribution (with the --iree-codegen-llvm-disable-distribution flag), they all run ~1.6x faster. The unpack overhead is free in this case. Both unpack_generic and generic run at ~130 ms.

BM_unpack_generic/process_time/real_time               132 ms          132 ms            5 items_per_second=7.56303/s
BM_unpack_generic/process_time/real_time               132 ms          132 ms            5 items_per_second=7.56782/s
BM_unpack_generic/process_time/real_time               132 ms          132 ms            5 items_per_second=7.56966/s
BM_unpack_generic/process_time/real_time               132 ms          132 ms            5 items_per_second=7.56825/s
BM_unpack_generic/process_time/real_time               132 ms          132 ms            5 items_per_second=7.57046/s
BM_unpack_generic/process_time/real_time_mean          132 ms          132 ms            5 items_per_second=7.56784/s
BM_unpack_generic/process_time/real_time_median        132 ms          132 ms            5 items_per_second=7.56825/s
BM_unpack_generic/process_time/real_time_stddev      0.051 ms        0.055 ms            5 items_per_second=2.89421m/s
BM_unpack_generic/process_time/real_time_cv           0.04 %          0.04 %             5 items_per_second=0.04%
BM_unpack/process_time/real_time                      97.3 ms         97.3 ms            7 items_per_second=10.2807/s
BM_unpack/process_time/real_time                      97.5 ms         97.5 ms            7 items_per_second=10.2527/s
BM_unpack/process_time/real_time                      97.4 ms         97.4 ms            7 items_per_second=10.2639/s
BM_unpack/process_time/real_time                      97.5 ms         97.4 ms            7 items_per_second=10.2607/s
BM_unpack/process_time/real_time                      97.5 ms         97.5 ms            7 items_per_second=10.2583/s
BM_unpack/process_time/real_time_mean                 97.4 ms         97.4 ms            5 items_per_second=10.2633/s
BM_unpack/process_time/real_time_median               97.5 ms         97.4 ms            5 items_per_second=10.2607/s
BM_unpack/process_time/real_time_stddev              0.100 ms        0.097 ms            5 items_per_second=0.0105549/s
BM_unpack/process_time/real_time_cv                   0.10 %          0.10 %             5 items_per_second=0.10%
BM_generic/process_time/real_time                      129 ms          129 ms            5 items_per_second=7.77294/s
BM_generic/process_time/real_time                      129 ms          129 ms            5 items_per_second=7.76412/s
BM_generic/process_time/real_time                      129 ms          129 ms            5 items_per_second=7.75458/s
BM_generic/process_time/real_time                      129 ms          129 ms            5 items_per_second=7.75781/s
BM_generic/process_time/real_time                      129 ms          129 ms            5 items_per_second=7.77162/s
BM_generic/process_time/real_time_mean                 129 ms          129 ms            5 items_per_second=7.76421/s
BM_generic/process_time/real_time_median               129 ms          129 ms            5 items_per_second=7.76412/s
BM_generic/process_time/real_time_stddev             0.135 ms        0.135 ms            5 items_per_second=8.13638m/s
BM_generic/process_time/real_time_cv                  0.10 %          0.11 %             5 items_per_second=0.10%

@pzread maybe we can try

  1. add tiling of the batch dimension to 1 in the DecomposeBatchMmt4DOpsPass
  2. add the iree-codegen-llvm-disable-distribution flag

This would give us a sense of whether it helps in the single-threaded case or not.


hanhanW commented Aug 23, 2023

Note: the numbers are with createCollapseDimensionsPass. That's why I mentioned that the distribution model is quite broken in this case. (This is a known issue; I probably observed it before and have scoped it to distribution improvements.)

IMO, we can just start with single-threaded, and then figure out how to propagate it to multi-threaded.

MaheshRavishankar commented:

> Note: the numbers are with createCollapseDimensionsPass. That's why I mentioned that the distribution model is quite broken in this case. (This is a known issue; I probably observed it before and have scoped it to distribution improvements.)
>
> IMO, we can just start with single-threaded, and then figure out how to propagate it to multi-threaded.

I am not sure how we can land it if we don't address the multi-threaded case.


hanhanW commented Aug 24, 2023

I meant that we should make sure it is as good as expected in the single-threaded case; we can figure out a decent configuration for multi-threading later. I can dig deeper tomorrow and Friday.


pzread commented Aug 24, 2023

> I think the distribution messes with the cache. If we disable distribution (with the --iree-codegen-llvm-disable-distribution flag), they all run ~1.6x faster. The unpack overhead is free in this case. Both unpack_generic and generic run at ~130 ms.


> @pzread maybe we can try
>
> 1. add tiling of the batch dimension to 1 in the DecomposeBatchMmt4DOpsPass
> 2. add the iree-codegen-llvm-disable-distribution flag
>
> This would give us a sense of whether it helps in the single-threaded case or not.

I think the batch dim of batch_mmt4d has already been tiled to 1. The issue here is that the unpack + generic dispatch uses setElementwiseGenericOpRootConfig to get its distribution tiling config, which in this case is [64, 64, 64] by default. I tried [1, 128, 128] before and got a ~1.3x improvement. So there is something related to the distribution.
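
For reference, that would roughly correspond to a distribution-level entry like the following in the lowering config (hypothetical values, reusing the attribute syntax from the dump above):

    // Distribution tile sizes [1, 128, 128] keep each batch slice inside one
    // workgroup tile; the inner levels here are just copied from the dump.
    lowering_config = #iree_codegen.lowering_config<
        tile_sizes = [[1, 128, 128], [1, 16, 16], [0, 0, 0], [0, 0, 0]]>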


hanhanW commented Aug 24, 2023

I feel that we now rely on the createCollapseDimensionsPass to get good distribution tile sizes for element-wise operations. But we really should fix the logic in distribution, so we can still get decent performance even when the ops are not flattened.

The other possible workaround is to adjust the tile sizes for unpack ops in https://github.com/openxla/iree/blob/main/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp#L1993, which can make sure that a batch is not spread across different workgroups (i.e., we may set the tile sizes to [1, X, Y] in this case).


hanhanW commented Aug 24, 2023

Update on how to repro: we need --iree-opt-const-expr-hoisting=false --iree-opt-const-eval=false in my test case; otherwise the unpack op will be hoisted and handled by const-eval. With the flags, I can still reproduce the result. The overheads are hidden in the fusion case.

Here is the IR dump: https://gist.githubusercontent.com/hanhanW/f0dc1d9aa0e479136a29411070bebe4a/raw

BM_unpack_generic/process_time/real_time               137 ms          137 ms            5 items_per_second=7.32115/s
BM_unpack_generic/process_time/real_time               136 ms          136 ms            5 items_per_second=7.37188/s
BM_unpack_generic/process_time/real_time               136 ms          136 ms            5 items_per_second=7.35245/s
BM_unpack_generic/process_time/real_time               136 ms          136 ms            5 items_per_second=7.34665/s
BM_unpack_generic/process_time/real_time               136 ms          136 ms            5 items_per_second=7.33468/s
BM_unpack_generic/process_time/real_time_mean          136 ms          136 ms            5 items_per_second=7.34536/s
BM_unpack_generic/process_time/real_time_median        136 ms          136 ms            5 items_per_second=7.34665/s
BM_unpack_generic/process_time/real_time_stddev      0.353 ms        0.351 ms            5 items_per_second=0.0190735/s
BM_unpack_generic/process_time/real_time_cv           0.26 %          0.26 %             5 items_per_second=0.26%
BM_unpack/process_time/real_time                       100 ms          100 ms            7 items_per_second=9.99352/s
BM_unpack/process_time/real_time                      99.7 ms         99.7 ms            7 items_per_second=10.0305/s
BM_unpack/process_time/real_time                      99.6 ms         99.6 ms            7 items_per_second=10.0381/s
BM_unpack/process_time/real_time                      99.5 ms         99.5 ms            7 items_per_second=10.0488/s
BM_unpack/process_time/real_time                      99.4 ms         99.4 ms            7 items_per_second=10.0571/s
BM_unpack/process_time/real_time_mean                 99.7 ms         99.7 ms            5 items_per_second=10.0336/s
BM_unpack/process_time/real_time_median               99.6 ms         99.6 ms            5 items_per_second=10.0381/s
BM_unpack/process_time/real_time_stddev              0.245 ms        0.248 ms            5 items_per_second=0.0246142/s
BM_unpack/process_time/real_time_cv                   0.25 %          0.25 %             5 items_per_second=0.25%
BM_generic/process_time/real_time                      135 ms          135 ms            5 items_per_second=7.38957/s
BM_generic/process_time/real_time                      135 ms          135 ms            5 items_per_second=7.39334/s
BM_generic/process_time/real_time                      135 ms          135 ms            5 items_per_second=7.39145/s
BM_generic/process_time/real_time                      135 ms          135 ms            5 items_per_second=7.41466/s
BM_generic/process_time/real_time                      135 ms          135 ms            5 items_per_second=7.40216/s
BM_generic/process_time/real_time_mean                 135 ms          135 ms            5 items_per_second=7.39824/s
BM_generic/process_time/real_time_median               135 ms          135 ms            5 items_per_second=7.39334/s

pzread pushed a commit that referenced this issue Aug 25, 2023
Add pattern to materialize `batch_matmul` with data-tiling encoding to
`batch_mmt4d`

Tracking issue: #14431
pzread pushed a commit that referenced this issue Aug 25, 2023
Populate lowering config for `batch_mmt4d` op in CPU backend.

Tracking issue: #14431

pzread commented Aug 30, 2023

I decided to do a more complete analysis of two models with batch_matmul, BertLargeTF and T5LargeTF; here are the results and analysis. Hopefully this will give us the next steps for performance tuning.

A Google Doc version with better table rendering can be found here.

Scripts and code to reproduce e2e benchmark results can be found in #14542 (see README and scripts under e2e_bench/)

Tl;dr

We benchmarked the batch_matmul data-tiling + ukernel path on the BertLargeTF and T5LargeTF models to check performance. The current results show slight regressions on both models (max +8.8%). Therefore we did an analysis of the results and tried to identify potential improvements as next steps.

E2E Model Benchmarks

Here we chose BertLargeTF (Batch1, Batch32) and T5LargeTF (Batch1, Batch32) to measure the e2e performance of batch_matmul data-tiling + ukernel (referred to as "data-tiling" below).

The benchmarking was performed on c2-standard-60 (cascadelake) and the numbers are averages of 5 repeated runs. The results are shown in the table below:

| Model | Baseline 1 thread (s) | Baseline 4 threads (s) | Baseline 8 threads (s) | Data-tiling 1 thread (s) | Data-tiling 4 threads (s) | Data-tiling 8 threads (s) |
| --- | --- | --- | --- | --- | --- | --- |
| BertLargeTFBatch1 | 3.29 | 0.963 | 0.513 | 3.41 (+3.6%) | 0.99 (+3.22%) | 0.52 (+2.5%) |
| BertLargeTFBatch32 | 80.09 | 20.55 | 11.54 | 83.25 (+4%) | 21.78 (+6%) | 12.02 (+4.2%) |
| T5LargeTFBatch1 | 8.51 | 2.24 | 1.71 | 8.98 (+5.5%) | 2.40 (+6.8%) | 1.75 (+2.4%) |
| T5LargeTFBatch32 | 256.04 | 63.68 | 33.68 | 275.8 (+7.7%) | 68.93 (+8.3%) | 36.63 (+8.8%) |

Baseline is at @407949; the "data-tiling" columns are the batch_matmul data-tiling + ukernel build.

There are slight regressions on both models and all thread counts, but not too much (max +8.8%). The sections below provide detailed analysis based on the 1-thread (local-sync) results.

Note that one reason we only have slight regressions in e2e latencies is that batch_matmul doesn't occupy a large portion of the models (only 10.7% of run time in T5LargeTFBatch32 and 6% in BertLargeTFBatch32).

Breakdown: T5LargeTFBatch32

There are 4 batch_matmul dispatches in the baseline module and 2 batch_matmul dispatches in the data-tiling module due to deduplication (but 4 unpack dispatches). The run time and characteristics of each dispatch are listed in the tables below. The symbol on each row links related dispatches between the two modules; the same symbol means they are the same workload on both sides (originating from the same parts of the source MLIR).

T5LargeTFBatch32 baseline batch_matmul dispatch characteristics:

| Name | Total time (s) | Total time (%) | Launch count | Shape | Fusion | Full runs | Arith ops | Estimated GFlops |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 🟩 dispatch_12_batch_matmul | 6.93 | 2.62 | 196608 | 512x512x512x64 | fused with generic | 24 | 4 | 118.8 |
| 🔶 dispatch_18_batch_matmul | 11.63 | 4.41 | 147456 | 512x512x64x512 | no fusion | 72 | 2 | 106.3 |
| 🟩 dispatch_39_batch_matmul | 4.82 | 1.82 | 196608 | 512x512x512x64 | no fusion | 24 | 2 | 85.5 |
| 🟩 dispatch_755_batch_matmul | 5.06 | 1.92 | 196608 | 512x512x512x64 | fused with generic | 24 | 3 | 122.0 |

“Full runs” is the number of complete batch_matmul runs (each run may issue multiple dispatch launches due to workgroup distribution).

“Arith ops” in a dispatch is estimated from the high-level MLIR ops (matmul = 2, addf = 1, etc.) and is used to estimate GFlops.
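
For example, for 🔶 dispatch_18_batch_matmul above, the estimate is presumably 2 (arith ops) × 512 × 512 × 64 × 512 (shape) × 72 (full runs) / 11.63 s ≈ 106 GFlops, which matches the table.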

T5LargeTFBatch32 data-tiling batch_matmul dispatch characteristics:

| Name | Total time (s) | Total time (%) | Launch count | Encoding and shape | Fusion | Full runs | Arith ops | Estimated GFlops |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 🟩 dispatch_12_set_encoding_LHS | 0.95 | 0.32 | 4608 | LHS 512x512x64 | non-constant, no fusion | 72 | | |
| 🟩 dispatch_13_set_encoding_RHS | 1.48 | 0.51 | 9216 | RHS 512x64x512 | non-constant, no fusion | 72 | | |
| 🟩 dispatch_14_batch_matmul | 10.71 | 3.71 | 36864 | 512x512x512x64 | | 72 | 2 | 115.4 |
| 🟩 dispatch_15_generic | 5.17 | 1.79 | 12288 | RESULT 512x512x512 | fused with generic | 24 | 2 | 79.6 |
| 🟩 dispatch_48_unset_encoding | 4.27 | 1.48 | 12288 | RESULT 512x512x512 | no fusion | 24 | | |
| 🟩 dispatch_908_generic | 4.32 | 1.49 | 12288 | RESULT 512x512x512 | fused with generic | 24 | 1 | 47.6 |
| 🔶 dispatch_21_set_encoding_LHS | 10.16 | 3.52 | 36864 | LHS 512x512x512 | non-constant, no fusion | 72 | | |
| 🔶 dispatch_22_set_encoding_RHS | 1.54 | 0.53 | 73728 | RHS 512x512x64 | non-constant, no fusion | 72 | | |
| 🔶 dispatch_23_batch_matmul | 11.47 | 3.97 | 36864 | 512x512x64x512 | | 72 | 2 | 107.8 |
| 🔶 dispatch_24_unset_encoding | 1.02 | 0.35 | 4608 | RESULT 512x512x64 | no fusion | 72 | | |

Analysis

There are 2 types of batch_matmul: 512x512x512x64 (BxNxMxK) and 512x512x64x512, and both run slower with data-tiling:

| Shape | Baseline | Data-tiling | Regression |
|---|---|---|---|
| 🟩 512x512x512x64 | 16.82s | 26.93s = 2.43 (pack) + 10.71 (ukernel) + 13.77 (unpack + generic) | +60% |
| 🔶 512x512x64x512 | 11.63s | 24.20s = 11.70 (pack) + 11.47 (ukernel) + 1.02 (unpack) | +108% |

Pack and unpack have significant overhead. The 11.70s pack time on 512x512x64x512 looks abnormal and needs further investigation. We also found that the 13.77s unpack time on 512x512x512x64 can be reduced to 8s if --iree-codegen-llvm-disable-distribution is used.

Another way to look at the data is that the estimated GFlops of the baseline batch_matmul dispatches are close to or exceed 100 GFlops. The batch_matmul (batch_mmt4d) ukernel only reaches similar GFlops, which means the ukernel doesn't help much in this case.

Next steps

In this case, there are 4 improvements that can be made:

  • Resolve the abnormal packing time
  • Improve unpack time with better distribution tiling, as we see improvement with the distribution disabled
  • See if ukernel performance can be further improved
  • Consider fusing generic (or even unpack) into ukernel dispatch

Alternatively, if 100 GFlops on the non-data-tiled batch_matmul is considered good enough, we might need to control when data-tiling is applied and skip these cases.

Breakdown: BertLargeTFBatch32

Similar to the analysis on T5LargeTFBatch32, we first analyze the batch_matmul dispatches:

BertLargeTFBatch32 baseline batch_matmul dispatches characteristics

| Name | Total time (s) | Total time (%) | Launch count | Shape | Fusion | Full runs | Arith ops | Estimated GFlops |
|---|---|---|---|---|---|---|---|---|
| 🟩 dispatch_13_batch_matmul | 2.60 | 3.24 | 110592 | 512x384x384x64 | no fusion | 24 | 2 | 89.0 |
| 🔶 dispatch_20_batch_matmul | 2.25 | 2.81 | 36864 | 512x384x64x384 | no fusion | 24 | 2 | 102.6 |

BertLargeTFBatch32 data-tiling batch_matmul dispatches characteristics

| Name | Total time (s) | Total time (%) | Launch count | Encoding and shape | Fusion | Full runs | Arith ops | Estimated GFlops |
|---|---|---|---|---|---|---|---|---|
| 🟩 dispatch_13_set_encoding_LHS | 0.24 | 0.28 | 1152 | LHS, 512x384x64 | non-constant, no fusion | 24 | | |
| 🟩 dispatch_14_set_encoding_RHS | 0.39 | 0.46 | 3072 | RHS, 512x64x384 | non-constant, no fusion | 24 | | |
| 🟩 dispatch_15_batch_matmul | 2.14 | 2.50 | 12288 | 512x384x384x64 | | 24 | 2 | 108.2 |
| 🟩 dispatch_16_unset_encoding | 1.97 | 2.30 | 6912 | RESULT, 512x384x384 | no fusion | 24 | | |
| 🔶 dispatch_23_set_encoding_LHS | 1.68 | 1.96 | 6912 | LHS, 512x384x384 | non-constant, no fusion | 24 | | |
| 🔶 dispatch_24_set_encoding_RHS | 0.38 | 0.45 | 18432 | RHS, 512x384x64 | non-constant, no fusion | 24 | | |
| 🔶 dispatch_25_batch_matmul | 2.18 | 2.54 | 12288 | 512x384x64x384 | | 24 | 2 | 106.2 |
| 🔶 dispatch_26_unset_encoding | 0.25 | 0.29 | 1152 | RESULT, 512x384x64 | no fusion | 24 | | |

Analysis

There are 2 types of batch_matmul: 512x384x384x64 (BxNxMxK) and 512x384x64x384. The table below shows that data-tiling is slower on both types, mainly due to the pack/unpack overhead. Also, in this case there is no generic op fused in to hide the overhead.

| Shape | Baseline | Data-tiling | Regression |
|---|---|---|---|
| 🟩 512x384x384x64 | 2.60s | 4.75s = 0.63s (pack) + 2.14s (ukernel) + 1.97s (unpack) | +82% |
| 🔶 512x384x64x384 | 2.25s | 4.50s = 2.07s (pack) + 2.18s (ukernel) + 0.25s (unpack) | +100% |

Next steps

This does not look like a data-tiling-friendly case, as there is no generic op to fuse with pack/unpack. In the MLIR we see that the predecessors and successors of the batch_matmul are reshape ops, and before and after the reshape ops there are generic ops. A potential improvement is to propagate pack/unpack across the generic ops to pull them in for fusion, as sketched below, but the reshape ops in between will need to be handled.
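Below is a minimal MLIR sketch of that propagation idea, using made-up shapes and tile sizes (a 16x1 LHS tile) rather than IREE's actual materialized encodings; the real IR additionally has tensor.collapse_shape ops in between that the rewrite would have to handle.

```mlir
#map3d = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map5d = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

// Before propagation: an elementwise generic feeds the pack of the
// batch_matmul LHS, so the pack cannot fuse with its producer.
func.func @before(%src: tensor<512x384x64xf32>,
                  %pack_dest: tensor<512x24x64x16x1xf32>) -> tensor<512x24x64x16x1xf32> {
  %init = tensor.empty() : tensor<512x384x64xf32>
  %ele = linalg.generic {indexing_maps = [#map3d, #map3d],
                         iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%src : tensor<512x384x64xf32>) outs(%init : tensor<512x384x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    %0 = arith.addf %in, %in : f32
    linalg.yield %0 : f32
  } -> tensor<512x384x64xf32>
  %packed = tensor.pack %ele inner_dims_pos = [1, 2] inner_tiles = [16, 1]
      into %pack_dest : tensor<512x384x64xf32> -> tensor<512x24x64x16x1xf32>
  return %packed : tensor<512x24x64x16x1xf32>
}

// After propagation: the pack is hoisted above the generic, and the
// elementwise op runs directly on the packed layout, so it can live in the
// same dispatch as the pack or its packed consumer.
func.func @after(%src: tensor<512x384x64xf32>,
                 %pack_dest: tensor<512x24x64x16x1xf32>) -> tensor<512x24x64x16x1xf32> {
  %packed = tensor.pack %src inner_dims_pos = [1, 2] inner_tiles = [16, 1]
      into %pack_dest : tensor<512x384x64xf32> -> tensor<512x24x64x16x1xf32>
  %init = tensor.empty() : tensor<512x24x64x16x1xf32>
  %ele = linalg.generic {indexing_maps = [#map5d, #map5d],
                         iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
      ins(%packed : tensor<512x24x64x16x1xf32>) outs(%init : tensor<512x24x64x16x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %0 = arith.addf %in, %in : f32
    linalg.yield %0 : f32
  } -> tensor<512x24x64x16x1xf32>
  return %ele : tensor<512x24x64x16x1xf32>
}
```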

Microbenchmarks

We sampled the heaviest batch_matmul from T5LargeTFBatch32 to further understand the performance on a single batch_matmul.

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  func.func @batch_matmul_generic(%arg0: tensor<512x512x64xf32>, %arg1: tensor<512x64x512xf32>, %arg2: tensor<512x512x512xf32>) -> tensor<512x512x512xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant dense<9.99999971E-10> : tensor<512x512x512xf32>
    %0 = tensor.empty() : tensor<512x512x512xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x512x512xf32>) -> tensor<512x512x512xf32>
    %2 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x512x64xf32>, tensor<512x64x512xf32>) outs(%1 : tensor<512x512x512xf32>) -> tensor<512x512x512xf32>
    %3 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2, %arg2, %cst_0 : tensor<512x512x512xf32>, tensor<512x512x512xf32>, tensor<512x512x512xf32>) outs(%0 : tensor<512x512x512xf32>) {
    ^bb0(%in: f32, %in_1: f32, %in_2: f32, %out: f32):
      %5 = arith.addf %in, %in_1 : f32
      %6 = arith.addf %5, %in_2 : f32
      linalg.yield %6 : f32
    } -> tensor<512x512x512xf32>
    return %3 : tensor<512x512x512xf32>
  }
}

The dispatch is repeated 50 times with --iree-hal-benchmark-dispatch-repeat-count=50 and we take the averages. The benchmarks run under Tracy to help the analysis, which only adds ~3% overhead in this case.

| | Baseline (s) | batch_matmul data-tiling + ukernel (s) | Regression |
|---|---|---|---|
| batch_matmul + generic op | 0.30 | 0.39 = 0.03 (pack) + 0.15 (ukernel) + 0.21 (unpack + generic) | +31.1% |
| batch_matmul | 0.22 | 0.36 = 0.03 (pack) + 0.15 (ukernel) + 0.18 (unpack) | +64.12% |
| generic op | 0.14 (0.31 with iree-flow-collapse-dimensions disabled) | | |

The results match what we saw in the e2e benchmark breakdown. Note that when benchmarking the generic op alone, by default iree-flow-collapse-dimensions will collapse the 3-D generic op (512x512x512) into a 1-D one (134217728), which is tiled with a size of 4096 at the distribution level and shows a big improvement.
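For illustration, this is roughly what the collapsed elementwise op looks like (a simplified sketch with two inputs instead of the three in the original generic; the actual pass output may differ in details). The 512x512x512 iteration space becomes a single 134217728-element parallel loop, which is easy to tile and distribute:

```mlir
#map1d = affine_map<(d0) -> (d0)>
func.func @collapsed_generic(%a: tensor<134217728xf32>,
                             %b: tensor<134217728xf32>) -> tensor<134217728xf32> {
  %init = tensor.empty() : tensor<134217728xf32>
  // Same elementwise add as before, but over a flattened 1-D iteration space.
  %0 = linalg.generic {indexing_maps = [#map1d, #map1d, #map1d],
                       iterator_types = ["parallel"]}
      ins(%a, %b : tensor<134217728xf32>, tensor<134217728xf32>)
      outs(%init : tensor<134217728xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.addf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<134217728xf32>
  return %0 : tensor<134217728xf32>
}
```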

This suggested that there might be issues in distribution-level tiling, so we benchmarked with --iree-codegen-llvm-disable-distribution to test it. As shown in the table below, disabling distribution improves unpack + generic a lot, and the regression drops to 2%.

| | Baseline (s) | batch_matmul data-tiling + ukernel (s) | Regression |
|---|---|---|---|
| batch_matmul + generic op | 0.30 | 0.31 = 0.03 (pack) + 0.15 (ukernel) + 0.13 (unpack + generic) | +2% |
| batch_matmul | 0.22 | 0.29 = 0.03 (pack) + 0.15 (ukernel) + 0.10 (unpack) | +28.7% |
| generic op | 0.10 | | |

Conclusion

In T5LargeTF and BertLargeTF we identified that there might be performance issues with distribution-level tiling, and that in some cases the pack/unpack ops have no generic ops to fuse with to hide the overhead. Also, in the baseline some batch_matmul dispatches are already running quite well. So the action items could be:

  • Optimize the distribution-level tiling to improve batch_matmul unpack performance
  • Enable more fusions on pack/unpack with generic ops
  • Survey more models and batch_matmul shapes to learn when the data-tiling approach will bring better performance

@pzread
Copy link
Contributor

pzread commented Aug 30, 2023

Hi @MaheshRavishankar @bjacob, I posted an analysis of batch_matmul data-tiling on e2e models: #14431 (comment). I'd be happy to hear any comments 🙂

@MaheshRavishankar
Copy link
Contributor

Hi @MaheshRavishankar @bjacob, I posted an analysis of batch_matmul data-tiling on e2e models: #14431 (comment). I'd be happy to hear any comments 🙂

Great analysis! Lot to unpack there, which I am not going to be able to do right now, but I will come back to this in a couple of weeks.

@bjacob
Copy link
Contributor

bjacob commented Aug 30, 2023

Thanks for the analysis. I find it striking that unpack overhead is so large.

For example, let's look at the final table in the above comment. For batch_matmul, the baseline matmul takes 0.22 seconds, while with data-tiling + ukernel you get 0.29 = 0.03 (pack) + 0.15 (ukernel) + 0.10 (unpack).

Already this shows that the data-tiling+ukernel is having the intended effect of enabling a much faster matmul kernel (0.15) vs the baseline (0.22) --- the matmul kernel part is being sped up by ~ 1.5x here.

The question becomes, how do we ensure that pack/unpack overhead, and other ops in the workload such as that generic op being fused on the unpack side, don't ruin this 1.5x speedup.

These numbers, 0.29 = 0.03 (pack) + 0.15 (ukernel) + 0.10 (unpack), show a reasonably small overhead for pack (0.03) (although that's already a bit on the large side for this shape and we should expect future codegen improvements to further reduce that pack overhead), but the unpack overhead is very large, 0.10, meaning that this is spending almost as much time in unpack as in the matmul kernel. Since the matmul shape here is 512x512x512, this matmul should be doing about K=512 times more work in the matmul kernel than in the unpack. Even allowing for a 10x relative inefficiency of unpack work, writing to memory etc, for such large values as K=512, we should expect unpack overhead to be small.
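As a rough cross-check of that intuition (a back-of-the-envelope estimate, using the microbenchmark shapes B=512, M=N=512, K=64 with f32 data and assuming the single-threaded local-sync configuration used in the breakdowns above):

$$
\text{ukernel: } \frac{2 \cdot 512 \cdot 512 \cdot 512 \cdot 64}{0.15 \cdot 10^9} \approx 115\ \text{GFlop/s},
\qquad
\text{unpack: } \frac{512^3 \cdot 4\ \text{bytes}}{0.10\ \text{s}} \approx 5.4\ \text{GB/s written}.
$$

That is only ~0.5 GB written (plus a similar amount read) in 0.10 s; comparing that effective bandwidth against a plain memcpy of the same size on one core would quickly show how much headroom the unpack codegen is leaving.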

So my guess is there's something really wrong going on in the codegen for unpack. If you disassemble it, something obvious should show.

Next, the generic op too has a surprisingly large cost. Maybe that too is a candidate for improving codegen.

All in all, these measurements show that data-tiling+ukernel is having the intended effect of speeding up the matmul kernel, but work is needed on the rest, particularly the unpack, to reap the e2e benefits.

Note: we have a ukernel for unpack. We just don't use it outside of VMVX at the moment. We could always give it a try here. We want to avoid it because it will prevent fusions, and unpack is a place where we think that fusions are going to be important. But if the unpack codegen is this bad, that is even more important than fusions. It depends on how much you want to unblock this in the short term. Long term, we really should improve codegen so we can keep the fusions and not need an unpack ukernel.

@MaheshRavishankar
Copy link
Contributor

A couple of high-level points after skimming the analysis from CheYu:

  1. I'd like to get an idea of what fusion is missing right now for the unpack. It might be that there is no linalg.generic op to fuse with. I have currently disabled the fusion of pack/unpack with linalg.generic ops that are not elementwise, because I think getting unpack/pack to fuse with elementwise ops needs to become a bit more robust before going further.
  2. It is fundamentally harder to generate code for an unpack fused with an elementwise operation. Fusing unpack with the batch_mmt4d and mmt4d ops is a bit convoluted. It would be better to directly write the result out in the linear layout instead of the packed layout, rather than "fusing" the unpack with the batch_mmt4d/mmt4d op. That might be a bit more long term.

@hanhanW
Copy link
Contributor Author

hanhanW commented Aug 30, 2023

Really great analysis! I think something is wrong in pack/unpack codegen. The number of elements to be packed is the same in the example, but one pack takes 0.63 sec and the other takes 2.07 sec.

| Shape | Baseline | Data-tiling | Regression |
|---|---|---|---|
| 🟩 512x384x384x64 | 2.60s | 4.75s = 0.63s (pack) + 2.14s (ukernel) + 1.97s (unpack) | +82% |
| 🔶 512x384x64x384 | 2.25s | 4.50s = 2.07s (pack) + 2.18s (ukernel) + 0.25s (unpack) | +100% |

@hanhanW
Copy link
Contributor Author

hanhanW commented Oct 26, 2023

Closing the issue because the functionality is complete.
