Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TileAndFuse] Add thread groups for convolution ops #695

Merged
merged 4 commits into from
Sep 4, 2024

Conversation

newling
Copy link
Contributor

@newling newling commented Aug 22, 2024

This works for now because the sizes of dimensions DimZ and LinearDim0 are 1 with our tiling strategy, and so DimY and DimX map to the rows and columns of the AIE array. Follow-up work: pass to collapse scf.forall's with more than 2 induction variables to just 2. So instead of

(i,j,k) in (2,3,5)

for example, could be

(i,l) in (2,15) and then j=l/5 k=l%5.

@newling
Copy link
Contributor Author

newling commented Aug 22, 2024

IR for AIR lowering is below. The compilation fails in air-dependency-canonicalize with

error: 'air.launch' op found non-herd op with core id

@erwei-xilinx is this something that can be fixed (assuming the IR below looks valid) ? Let me know if you'd like the full IR trace.

FWIW this allocation of blocks and threads to convolution works ok with the objectFifo pipeline (which FWIW is currently running out of BDs in a late mlir-aie pass).

// -----// IR Dump Before AMDAIEBridgeToAIR (iree-amdaie-bridge-to-air) //----- //
module {
  func.func @conv_2d_nhwc_hwcf_dispatch_0_conv_2d_nhwc_hwcf_2x12x12x64x3x3x32_i32() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
    %c8 = arith.constant 8 : index
    %c1 = arith.constant 1 : index
    %c32 = arith.constant 32 : index
    %c3 = arith.constant 3 : index
    %c0_i32 = arith.constant 0 : i32
    %c0 = arith.constant 0 : index
    %0 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<2x14x14x32xi32>
    memref.assume_alignment %0, 64 : memref<2x14x14x32xi32>
    %1 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : memref<3x3x32x64xi32>
    memref.assume_alignment %1, 64 : memref<3x3x32x64xi32>
    %2 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>], flags = Indirect>]>) set(0) binding(2) alignment(64) offset(%c0) : memref<2x12x12x64xi32>
    memref.assume_alignment %2, 64 : memref<2x12x12x64xi32>
    scf.forall (%arg0, %arg1, %arg2) = (0, 0, 0) to (12, 12, 64) step (4, 4, 4) {
      %subview = memref.subview %0[0, %arg0, %arg1, 0] [2, 6, 6, 32] [1, 1, 1, 1] : memref<2x14x14x32xi32> to memref<2x6x6x32xi32, strided<[6272, 448, 32, 1], offset: ?>>
      %subview_0 = memref.subview %1[0, 0, 0, %arg2] [3, 3, 32, 4] [1, 1, 1, 1] : memref<3x3x32x64xi32> to memref<3x3x32x4xi32, strided<[6144, 2048, 64, 1], offset: ?>>
      %subview_1 = memref.subview %2[0, %arg0, %arg1, %arg2] [2, 4, 4, 4] [1, 1, 1, 1] : memref<2x12x12x64xi32> to memref<2x4x4x4xi32, strided<[9216, 768, 64, 1], offset: ?>>
      %alloc = memref.alloc() : memref<2x6x6x32xi32, 1 : i32>
      linalg.copy ins(%subview : memref<2x6x6x32xi32, strided<[6272, 448, 32, 1], offset: ?>>) outs(%alloc : memref<2x6x6x32xi32, 1 : i32>)
      %alloc_2 = memref.alloc() : memref<3x3x32x4xi32, 1 : i32>
      linalg.copy ins(%subview_0 : memref<3x3x32x4xi32, strided<[6144, 2048, 64, 1], offset: ?>>) outs(%alloc_2 : memref<3x3x32x4xi32, 1 : i32>)
      %alloc_3 = memref.alloc() : memref<2x4x4x4xi32, 1 : i32>
      scf.forall (%arg3, %arg4, %arg5, %arg6) = (0, 0, 0, 0) to (2, 4, 4, 4) step (1, 1, 4, 4) {
        %subview_4 = memref.subview %alloc_3[%arg3, %arg4, %arg5, %arg6] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x4x4x4xi32, 1 : i32> to memref<1x1x4x4xi32, strided<[64, 16, 4, 1], offset: ?>, 1 : i32>
        %alloc_5 = memref.alloc() : memref<1x1x4x4xi32, 2 : i32>
        linalg.fill ins(%c0_i32 : i32) outs(%alloc_5 : memref<1x1x4x4xi32, 2 : i32>)
        scf.for %arg7 = %c0 to %c3 step %c1 {
          scf.for %arg8 = %c0 to %c3 step %c1 {
            scf.for %arg9 = %c0 to %c32 step %c8 {
              %3 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg4, %arg7]
              %4 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg5, %arg8]
              %subview_6 = memref.subview %alloc[%arg3, %3, %4, %arg9] [1, 1, 4, 8] [1, 1, 1, 1] : memref<2x6x6x32xi32, 1 : i32> to memref<1x1x4x8xi32, strided<[1152, 192, 32, 1], offset: ?>, 1 : i32>
              %subview_7 = memref.subview %alloc_2[%arg7, %arg8, %arg9, %arg6] [1, 1, 8, 4] [1, 1, 1, 1] : memref<3x3x32x4xi32, 1 : i32> to memref<1x1x8x4xi32, strided<[384, 128, 4, 1], offset: ?>, 1 : i32>
              %alloc_8 = memref.alloc() : memref<1x1x4x8xi32, 2 : i32>
              linalg.copy ins(%subview_6 : memref<1x1x4x8xi32, strided<[1152, 192, 32, 1], offset: ?>, 1 : i32>) outs(%alloc_8 : memref<1x1x4x8xi32, 2 : i32>)
              %alloc_9 = memref.alloc() : memref<1x1x8x4xi32, 2 : i32>
              linalg.copy ins(%subview_7 : memref<1x1x8x4xi32, strided<[384, 128, 4, 1], offset: ?>, 1 : i32>) outs(%alloc_9 : memref<1x1x8x4xi32, 2 : i32>)
              %subview_10 = memref.subview %alloc_8[0, 0, 0, 0] [1, 1, 4, 8] [1, 1, 1, 1] : memref<1x1x4x8xi32, 2 : i32> to memref<1x4x8xi32, strided<[32, 8, 1]>, 2 : i32>
              %subview_11 = memref.subview %alloc_9[0, 0, 0, 0] [1, 1, 8, 4] [1, 1, 1, 1] : memref<1x1x8x4xi32, 2 : i32> to memref<1x8x4xi32, strided<[32, 4, 1]>, 2 : i32>
              %subview_12 = memref.subview %alloc_5[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xi32, 2 : i32> to memref<1x4x4xi32, strided<[16, 4, 1]>, 2 : i32>
              linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%subview_10, %subview_11 : memref<1x4x8xi32, strided<[32, 8, 1]>, 2 : i32>, memref<1x8x4xi32, strided<[32, 4, 1]>, 2 : i32>) outs(%subview_12 : memref<1x4x4xi32, strided<[16, 4, 1]>, 2 : i32>)
              memref.dealloc %alloc_8 : memref<1x1x4x8xi32, 2 : i32>
              memref.dealloc %alloc_9 : memref<1x1x8x4xi32, 2 : i32>
            }
          }
        }
        linalg.copy ins(%alloc_5 : memref<1x1x4x4xi32, 2 : i32>) outs(%subview_4 : memref<1x1x4x4xi32, strided<[64, 16, 4, 1], offset: ?>, 1 : i32>)
        memref.dealloc %alloc_5 : memref<1x1x4x4xi32, 2 : i32>
      } {mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>, #gpu.thread<linear_dim_0>]}
      linalg.copy ins(%alloc_3 : memref<2x4x4x4xi32, 1 : i32>) outs(%subview_1 : memref<2x4x4x4xi32, strided<[9216, 768, 64, 1], offset: ?>>)
      memref.dealloc %alloc : memref<2x6x6x32xi32, 1 : i32>
      memref.dealloc %alloc_2 : memref<3x3x32x4xi32, 1 : i32>
      memref.dealloc %alloc_3 : memref<2x4x4x4xi32, 1 : i32>
    } {mapping = [#gpu.block<y>, #gpu.block<x>, #gpu.block<z>]}
    return
  }
}

@newling
Copy link
Contributor Author

newling commented Aug 22, 2024

IR for AIR lowering is below. The compilation fails in air-dependency-canonicalize with

error: 'air.launch' op found non-herd op with core id

@erwei-xilinx is this something that can be fixed (assuming the IR below looks valid) ? Let me know if you'd like the full IR trace.

FWIW this allocation of blocks and threads to convolution works ok with the objectFifo pipeline (which FWIW is currently running out of BDs in a late mlir-aie pass).

// -----// IR Dump Before AMDAIEBridgeToAIR (iree-amdaie-bridge-to-air) //----- //
module {
  func.func @conv_2d_nhwc_hwcf_dispatch_0_conv_2d_nhwc_hwcf_2x12x12x64x3x3x32_i32() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
    %c8 = arith.constant 8 : index
    %c1 = arith.constant 1 : index
    %c32 = arith.constant 32 : index
    %c3 = arith.constant 3 : index
    %c0_i32 = arith.constant 0 : i32
    %c0 = arith.constant 0 : index
    %0 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<2x14x14x32xi32>
    memref.assume_alignment %0, 64 : memref<2x14x14x32xi32>
    %1 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : memref<3x3x32x64xi32>
    memref.assume_alignment %1, 64 : memref<3x3x32x64xi32>
    %2 = hal.interface.binding.subspan layout(<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>], flags = Indirect>]>) set(0) binding(2) alignment(64) offset(%c0) : memref<2x12x12x64xi32>
    memref.assume_alignment %2, 64 : memref<2x12x12x64xi32>
    scf.forall (%arg0, %arg1, %arg2) = (0, 0, 0) to (12, 12, 64) step (4, 4, 4) {
      %subview = memref.subview %0[0, %arg0, %arg1, 0] [2, 6, 6, 32] [1, 1, 1, 1] : memref<2x14x14x32xi32> to memref<2x6x6x32xi32, strided<[6272, 448, 32, 1], offset: ?>>
      %subview_0 = memref.subview %1[0, 0, 0, %arg2] [3, 3, 32, 4] [1, 1, 1, 1] : memref<3x3x32x64xi32> to memref<3x3x32x4xi32, strided<[6144, 2048, 64, 1], offset: ?>>
      %subview_1 = memref.subview %2[0, %arg0, %arg1, %arg2] [2, 4, 4, 4] [1, 1, 1, 1] : memref<2x12x12x64xi32> to memref<2x4x4x4xi32, strided<[9216, 768, 64, 1], offset: ?>>
      %alloc = memref.alloc() : memref<2x6x6x32xi32, 1 : i32>
      linalg.copy ins(%subview : memref<2x6x6x32xi32, strided<[6272, 448, 32, 1], offset: ?>>) outs(%alloc : memref<2x6x6x32xi32, 1 : i32>)
      %alloc_2 = memref.alloc() : memref<3x3x32x4xi32, 1 : i32>
      linalg.copy ins(%subview_0 : memref<3x3x32x4xi32, strided<[6144, 2048, 64, 1], offset: ?>>) outs(%alloc_2 : memref<3x3x32x4xi32, 1 : i32>)
      %alloc_3 = memref.alloc() : memref<2x4x4x4xi32, 1 : i32>
      scf.forall (%arg3, %arg4, %arg5, %arg6) = (0, 0, 0, 0) to (2, 4, 4, 4) step (1, 1, 4, 4) {
        %subview_4 = memref.subview %alloc_3[%arg3, %arg4, %arg5, %arg6] [1, 1, 4, 4] [1, 1, 1, 1] : memref<2x4x4x4xi32, 1 : i32> to memref<1x1x4x4xi32, strided<[64, 16, 4, 1], offset: ?>, 1 : i32>
        %alloc_5 = memref.alloc() : memref<1x1x4x4xi32, 2 : i32>
        linalg.fill ins(%c0_i32 : i32) outs(%alloc_5 : memref<1x1x4x4xi32, 2 : i32>)
        scf.for %arg7 = %c0 to %c3 step %c1 {
          scf.for %arg8 = %c0 to %c3 step %c1 {
            scf.for %arg9 = %c0 to %c32 step %c8 {
              %3 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg4, %arg7]
              %4 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg5, %arg8]
              %subview_6 = memref.subview %alloc[%arg3, %3, %4, %arg9] [1, 1, 4, 8] [1, 1, 1, 1] : memref<2x6x6x32xi32, 1 : i32> to memref<1x1x4x8xi32, strided<[1152, 192, 32, 1], offset: ?>, 1 : i32>
              %subview_7 = memref.subview %alloc_2[%arg7, %arg8, %arg9, %arg6] [1, 1, 8, 4] [1, 1, 1, 1] : memref<3x3x32x4xi32, 1 : i32> to memref<1x1x8x4xi32, strided<[384, 128, 4, 1], offset: ?>, 1 : i32>
              %alloc_8 = memref.alloc() : memref<1x1x4x8xi32, 2 : i32>
              linalg.copy ins(%subview_6 : memref<1x1x4x8xi32, strided<[1152, 192, 32, 1], offset: ?>, 1 : i32>) outs(%alloc_8 : memref<1x1x4x8xi32, 2 : i32>)
              %alloc_9 = memref.alloc() : memref<1x1x8x4xi32, 2 : i32>
              linalg.copy ins(%subview_7 : memref<1x1x8x4xi32, strided<[384, 128, 4, 1], offset: ?>, 1 : i32>) outs(%alloc_9 : memref<1x1x8x4xi32, 2 : i32>)
              %subview_10 = memref.subview %alloc_8[0, 0, 0, 0] [1, 1, 4, 8] [1, 1, 1, 1] : memref<1x1x4x8xi32, 2 : i32> to memref<1x4x8xi32, strided<[32, 8, 1]>, 2 : i32>
              %subview_11 = memref.subview %alloc_9[0, 0, 0, 0] [1, 1, 8, 4] [1, 1, 1, 1] : memref<1x1x8x4xi32, 2 : i32> to memref<1x8x4xi32, strided<[32, 4, 1]>, 2 : i32>
              %subview_12 = memref.subview %alloc_5[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<1x1x4x4xi32, 2 : i32> to memref<1x4x4xi32, strided<[16, 4, 1]>, 2 : i32>
              linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%subview_10, %subview_11 : memref<1x4x8xi32, strided<[32, 8, 1]>, 2 : i32>, memref<1x8x4xi32, strided<[32, 4, 1]>, 2 : i32>) outs(%subview_12 : memref<1x4x4xi32, strided<[16, 4, 1]>, 2 : i32>)
              memref.dealloc %alloc_8 : memref<1x1x4x8xi32, 2 : i32>
              memref.dealloc %alloc_9 : memref<1x1x8x4xi32, 2 : i32>
            }
          }
        }
        linalg.copy ins(%alloc_5 : memref<1x1x4x4xi32, 2 : i32>) outs(%subview_4 : memref<1x1x4x4xi32, strided<[64, 16, 4, 1], offset: ?>, 1 : i32>)
        memref.dealloc %alloc_5 : memref<1x1x4x4xi32, 2 : i32>
      } {mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>, #gpu.thread<linear_dim_0>]}
      linalg.copy ins(%alloc_3 : memref<2x4x4x4xi32, 1 : i32>) outs(%subview_1 : memref<2x4x4x4xi32, strided<[9216, 768, 64, 1], offset: ?>>)
      memref.dealloc %alloc : memref<2x6x6x32xi32, 1 : i32>
      memref.dealloc %alloc_2 : memref<3x3x32x4xi32, 1 : i32>
      memref.dealloc %alloc_3 : memref<2x4x4x4xi32, 1 : i32>
    } {mapping = [#gpu.block<y>, #gpu.block<x>, #gpu.block<z>]}
    return
  }
}

Suggestion from @erwei-xilinx to run canonicalization to reduce 4-d scf.parallel to 2-d scf.parallel works perfectly. We wonder why canonicalization doesn't do the same for scf.forall

Copy link
Contributor

@yzhang93 yzhang93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good starting point. I suggest to put this as a separate utility function.

uint32_t nbIndVars = std::count_if(tileSizesVal.begin(), tileSizesVal.end(),
[](int64_t t) { return t != 0; });

// See mlir::gpu::MappingId enum: there are currently 13 values.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines 189 to 194
if (i == 0)
mapping.push_back(getMappingAttributeForDimension(1));
else if (i == 1)
mapping.push_back(getMappingAttributeForDimension(0));
else
mapping.push_back(getMappingAttributeForDimension(i));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be problematic if we have the batch dimension. Typically, we should put DimZ for the batch dimension.

@newling newling force-pushed the thread_groups_for_conv branch 3 times, most recently from 57ae4ac to 38c7885 Compare August 29, 2024 15:52
@newling newling requested a review from yzhang93 August 29, 2024 15:53
@@ -32,7 +32,25 @@ func.func @conv_2d_nhwc_hwcf(%arg0: tensor<2x14x14x32xbf16>, %arg1: tensor<3x3x3
// TILE-LEVEL-0-SAME: {
// TILE-LEVEL-0: linalg.fill
// TILE-LEVEL-0: linalg.conv_2d_nhwc_hwcf
// TILE-LEVEL-0: }
// TILE-LEVEL-0: } {mapping = [#gpu.block<y>, #gpu.block<x>, #gpu.block<z>]}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion, the order of mapping attributes for block and thread should be corresponding to each other. Although we are not using block attributes at the moment, it's good to keep the attributes in the same order.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, so the logic is actually identical for threads and blocks. What's happening in this example is that the tiling sizes at thread and block levels are different:

For blocks: [0, 4, 4, 4, 0, 0, 0]
For threads: [1, 1, 4, 4, 0, 0, 0]

The logic works implemented in this PR works as follows:

First, assign attributes to dimensions with tile size greater than 1. For threads, that is dimensions 2 and 3.
Second, assign attributes to dimensions with tile size equal to 1. For threads, that is dimensions 0 and 1.

The attributes assigned in the order y then x then z then linear_dim_0, linear_dim_1 etc.

For [1, 1, 4, 4, 0, 0, 0], after step 1 the assigned dimensions are
[none, none, y, x, none, none, none]
and then after step 2 the assigned dimensions are
[z, linear_dim_0, y, x, none, none, none].

And that is why at the thread level we end up with [z, linear_dim_0, y, x].

Copy link
Contributor

@yzhang93 yzhang93 Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation. Yeah, I can follow the steps to set up mapping attributes. It's just not typical that the attributes are different for the same dim in block and thread mapping. I don't have a good solution to solve this other than hardcoding the dimensions. On the other hand, since the block attributes are not used anyway, maybe we could remove these block attributes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tried removing block dimension tiling, but for small matmul examples where the block tile sizes are all 1, there's a problem: scf.forall without tiling attributes get canonicalized away when the iterations space is size 1. i.e. the block-level scf.forall gets removed completely. Ideally we'd be able to work without this outer scf.forall, but currently the passes aren't set up to handle this, I guess. So removing block dimensions for matmuls isn't something we can immediately do.

I could remove them for convolution, but we might have the same issue when we have small convolutions.

So yeah, not sure. Maybe, for now, it's ok to keep the block dimensions as they are for convolution?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see... We should find a way to solve the scf.forall canonicalization problem. Could you add a TODO comment for the block mapping attributes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll add a comment. Thanks for accepting :)

// TILE-LEVEL-1-SAME: {
// TILE-LEVEL-1: linalg.fill
// TILE-LEVEL-1: linalg.conv_2d_nhwc_hwcf
// TILE-LEVEL-1: } {mapping = [#gpu.thread<z>, #gpu.thread<linear_dim_0>, #gpu.thread<y>, #gpu.thread<x>]}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me.

@newling newling enabled auto-merge (squash) September 4, 2024 21:41
@newling newling merged commit 994f6e3 into nod-ai:main Sep 4, 2024
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants