From 91eea4acda595a15b7c8813ea5eace5d94718436 Mon Sep 17 00:00:00 2001
From: Harsh Menon
Date: Fri, 16 Aug 2024 13:37:15 -0700
Subject: [PATCH] Add lowerings for mma, register and allocate

This PR adds an MMA unit test that lowers to vector.load/vector.store
and amdgpu.mfma ops. It also adds support for shared memory promotion.

Signed-off-by: Harsh Menon
---
 lit_tests/kernel/wave/codegen.py              |  98 +++++++++++++++++
 lit_tests/kernel/wave/expansion.py            |  20 ++--
 lit_tests/kernel/wave/promotion.py            |  14 +--
 shark_turbine/kernel/compiler/ir.py           |   4 +-
 .../kernel/compiler/vector_codegen.py         |  11 +-
 shark_turbine/kernel/lang/wave_types.py       |   3 +-
 shark_turbine/kernel/ops/wave_ops.py          |  28 +++--
 shark_turbine/kernel/wave/codegen.py          | 100 ++++++++++++++++--
 shark_turbine/kernel/wave/constraints.py      |  16 +++
 shark_turbine/kernel/wave/expansion.py        |   8 +-
 shark_turbine/kernel/wave/promotion.py        |  42 +++++---
 .../kernel/wave/register_analysis.py          |  32 ++++++
 shark_turbine/kernel/wave/utils.py            |  35 ++++++
 shark_turbine/kernel/wave/wave.py             |  12 ++-
 tests/kernel/wave/wave_gemm_test.py           |   3 +-
 15 files changed, 370 insertions(+), 56 deletions(-)
 create mode 100644 shark_turbine/kernel/wave/register_analysis.py
 create mode 100644 shark_turbine/kernel/wave/utils.py

diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 548eba96..6c80e55b 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -5,6 +5,7 @@
 import shark_turbine.kernel as tk
 import shark_turbine.kernel.lang as tkl
 import shark_turbine.kernel.wave as tkw
+from shark_turbine.kernel.lang.global_symbols import *
 import torch
 
 M = tkl.sym.M
@@ -13,7 +14,10 @@
 BLOCK_M = tkl.sym.BLOCK_M
 BLOCK_N = tkl.sym.BLOCK_N
 BLOCK_K = tkl.sym.BLOCK_K
+LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
+STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD
 ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
 
 
 def launch(func: Callable[[], None]) -> Callable[[], None]:
@@ -247,6 +251,100 @@ def test(
     # CHECK: vector.scatter %[[OUT]][%[[IDX_Y]], %[[IDX_X]]] [%[[OFF]]], %[[MASK]], %[[RES]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xindex>, vector<16xi1>, vector<16xf16>
 
 
+@run
+def test_mma():
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(2, 2, 1),
+            mma_type=tkw.MMAType.F32_16x16x16_F16,
+        )
+    ]
+
+    @tkw.wave(constraints)
+    def mma(
+        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
+        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
+        c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
+    ):
+        c_reg = tkl.Register[M, N, tkl.f32](0.0)
+        a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+        acc = tkw.mma(a_reg, b_reg, c_reg)
+        tkw.write(acc, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
+
+    with tk.gen.TestLaunchContext(
+        {
+            M: 64,
+            N: 128,
+            K: 16,
+            BLOCK_M: 32,
+            BLOCK_N: 32,
+            LOAD_ELEMS_PER_THREAD: 4,
+            STORE_ELEMS_PER_THREAD: 4,
+            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
+        }
+    ):
+        a = torch.randn(64, 16, dtype=torch.float16)
+        b = torch.randn(128, 16, dtype=torch.float16)
+        c = torch.zeros(64, 128, dtype=torch.float32)
+        print(mma(a, b, c, canonicalize=True).module_op)
+
+        # CHECK: func.func @mma(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding, %[[ARG2:.+]]: !stream.binding) {
+        # CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+        # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+        # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+        # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+        # CHECK-DAG: %[[ACC:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+        # CHECK: %[[WG0:.+]] = stream.dispatch.workgroup.id[0] : index
+        # CHECK: %[[WG1:.+]] = stream.dispatch.workgroup.id[1] : index
+        # CHECK: %[[TX:.+]] = gpu.thread_id x
+        # CHECK: %[[TY:.+]] = gpu.thread_id y
+        # CHECK: %[[R0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x16xf16, strided<[16, 1], offset: ?>>
+        # CHECK: %[[R1:.+]] = arith.muli %[[WG0]], %[[C32]] : index
+        # CHECK: %[[R2:.+]] = arith.divsi %[[TX]], %[[C4]] : index
+        # CHECK: %[[R3:.+]] = arith.addi %[[R2]], %[[R1]] : index
+        # CHECK: %[[R4:.+]] = vector.load %[[R0]][%[[R3]], %[[C0]]] : memref<64x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
+        # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
+        # CHECK: %[[R5:.+]] = arith.muli %[[WG0]], %[[C32]] : index
+        # CHECK: %[[R6:.+]] = arith.divsi %[[TX]], %[[C4]] : index
+        # CHECK: %[[R7:.+]] = arith.addi %[[R6]], %[[R5]] : index
+        # CHECK: vector.store %[[R4]], %[[ALLOC]][%[[R7]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
+        # CHECK: %[[R8:.+]] = arith.muli %[[WG0]], %[[C32]] : index
+        # CHECK: %[[R9:.+]] = arith.divsi %[[TX]], %[[C4]] : index
+        # CHECK: %[[R10:.+]] = arith.addi %[[R9]], %[[R8]] : index
+        # CHECK: %[[R11:.+]] = vector.load %[[ALLOC]][%[[R10]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
+        # CHECK: %[[R12:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, strided<[16, 1], offset: ?>>
+        # CHECK: %[[R13:.+]] = arith.muli %[[TY]], %[[C16]] : index
+        # CHECK: %[[R14:.+]] = arith.muli %[[WG1]], %[[C32]] : index
+        # CHECK: %[[R15:.+]] = arith.addi %[[R14]], %[[R13]] : index
+        # CHECK: %[[R16:.+]] = vector.load %[[R12]][%[[R15]], %[[C0]]] : memref<128x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
+        # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
+        # CHECK: %[[R17:.+]] = arith.muli %[[TY]], %[[C16]] : index
+        # CHECK: %[[R18:.+]] = arith.muli %[[WG1]], %[[C32]] : index
+        # CHECK: %[[R19:.+]] = arith.addi %[[R18]], %[[R17]] : index
+        # CHECK: vector.store %[[R16]], %[[ALLOC_0]][%[[R19]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
+        # CHECK: %[[R20:.+]] = arith.muli %[[TY]], %[[C16]] : index
+        # CHECK: %[[R21:.+]] = arith.muli %[[WG1]], %[[C32]] : index
+        # CHECK: %[[R22:.+]] = arith.addi %[[R21]], %[[R20]] : index
+        # CHECK: %[[R23:.+]] = vector.load %[[ALLOC_0]][%[[R22]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
+        # CHECK: %[[R24:.+]] = amdgpu.mfma %[[R11]] * %[[R23]] + %[[ACC]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+        # CHECK: %[[R25:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, strided<[128, 1], offset: ?>>
+        # CHECK: %[[R26:.+]] = arith.muli %[[WG0]], %[[C32]] : index
+        # CHECK: %[[R27:.+]] = arith.divsi %[[TX]], %[[C4]] : index
+        # CHECK: %[[R28:.+]] = arith.addi %[[R27]], %[[R26]] : index
+        # CHECK: %[[R29:.+]] = arith.muli %[[TY]], %[[C16]] : index
+        # CHECK: %[[R30:.+]] = arith.muli %[[WG1]], %[[C32]] : index
+        # CHECK: %[[R31:.+]] = arith.addi %[[R30]], %[[R29]] : index
+        # CHECK: vector.store %[[R24]], %[[R25]][%[[R28]], %[[R31]]] : memref<64x128xf32, strided<[128, 1], offset: ?>>, vector<4xf32>
+
+
 @run
 def test_add_float():
     constraints: list[tkw.Constraint] = [
diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py
index a2a5429c..c8651f7c 100644
--- a/lit_tests/kernel/wave/expansion.py
+++ b/lit_tests/kernel/wave/expansion.py
@@ -350,25 +350,25 @@ def test_gemm():
     # CHECK-SAME: acc=acc_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
     # CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
     # CHECK-SAME: rhs=read_0_0_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
-    # CHECK-SAME: acc=mma_0_0_0 (index = None))
+    # CHECK-SAME: acc=mma_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
     # CHECK-NEXT: mma(lhs=read_1_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
     # CHECK-SAME: rhs=read_0_1_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
     # CHECK-SAME: acc=acc_1_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
     # CHECK-NEXT: mma(lhs=read_1_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
     # CHECK-SAME: rhs=read_0_1_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
-    # CHECK-SAME: acc=mma_1_1_0 (index = None))
+    # CHECK-SAME: acc=mma_1_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
     # CHECK-NEXT: mma(lhs=read_1_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
     # CHECK-SAME: rhs=read_0_0_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
     # CHECK-SAME: acc=acc_1_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
     # CHECK-NEXT: mma(lhs=read_1_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
     # CHECK-SAME: rhs=read_0_0_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
-    # CHECK-SAME: acc=mma_1_0_0 (index = None))
+    # CHECK-SAME: acc=mma_1_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
     # CHECK-NEXT: mma(lhs=read_0_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
     # CHECK-SAME: rhs=read_0_1_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
     # CHECK-SAME: acc=acc_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
     # CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
     # CHECK-SAME: rhs=read_0_1_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
-    # CHECK-SAME: acc=mma_0_1_0 (index = None))
+    # CHECK-SAME: acc=mma_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
     # CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1],))
     # CHECK-NEXT: -----
 
@@ -502,25 +502,25 @@ def test_gemm_reduction_expansion_only():
     # CHECK-SAME: acc=acc_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
     # CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
     # CHECK-SAME: rhs=read_0_0_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
-    # CHECK-SAME: acc=mma_0_0_0 (index = None))
+    # CHECK-SAME: acc=mma_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
     # CHECK-NEXT: mma(lhs=read_0_0_2 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
     # CHECK-SAME: rhs=read_0_0_2 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
-    # CHECK-SAME: acc=mma_0_0_1 (index = None))
+    # CHECK-SAME: acc=mma_0_0_1 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
     # CHECK-NEXT: mma(lhs=read_0_0_3 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
     # CHECK-SAME: rhs=read_0_0_3 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
-    # CHECK-SAME: acc=mma_0_0_2 (index = None))
+    # CHECK-SAME: acc=mma_0_0_2 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
     # CHECK-NEXT: mma(lhs=read_0_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
     # CHECK-SAME: rhs=read_0_1_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
     # CHECK-SAME: acc=acc_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
     # CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
     # CHECK-SAME: rhs=read_0_1_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
-    # CHECK-SAME: acc=mma_0_1_0 (index = None))
+    # CHECK-SAME: acc=mma_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
     # CHECK-NEXT: mma(lhs=read_0_0_2 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
     # CHECK-SAME: rhs=read_0_1_2 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
-    # CHECK-SAME: acc=mma_0_1_1 (index = None))
+    # CHECK-SAME: acc=mma_0_1_1 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
     # CHECK-NEXT: mma(lhs=read_0_0_3 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
     # CHECK-SAME: rhs=read_0_1_3 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
-    # CHECK-SAME: acc=mma_0_1_2 (index = None))
+    # CHECK-SAME: acc=mma_0_1_2 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
     # CHECK-NEXT: output(return_vals=([mma_0_0_3, mma_0_1_3],))
     # CHECK-NEXT: -----
 
diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py
index b46f912c..34b96687 100644
--- a/lit_tests/kernel/wave/promotion.py
+++ b/lit_tests/kernel/wave/promotion.py
@@ -87,14 +87,14 @@ def test_read_write_equal_sizes():
         graph: fx.Graph = trace.get_root_graph()
         read_node = get_read_nodes(graph)[0]
         IndexingContext.current().finalize()
-        promote_node(read_node, SHARED_ADDRESS_SPACE)
+        promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
         print_trace(trace)
         # CHECK: %a
         # CHECK-NEXT: %c
         # CHECK-NEXT: %read
         # CHECK-SAME: (%a, 4, None)
         # CHECK-NEXT: %allocate
-        # CHECK-SAME: ((M, N), f16, $SHARED_ADDRESS_SPACE)
+        # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
         # CHECK-NEXT: %write_1
         # CHECK-SAME: (%read, %allocate, 4, None)
         # CHECK-NEXT: %read_1
@@ -136,14 +136,14 @@ def test_read_write_equal_sizes_different_address_spaces():
     ):
         trace: CapturedTrace = read_write_same_size_different_address_spaces()
         IndexingContext.current().finalize()
-        promote_placeholders(trace)
+        promote_placeholders(trace, constraints)
         print_trace(trace)
         # CHECK: %a
         # CHECK-NEXT: %c
         # CHECK-NEXT: %read
         # CHECK-SAME: (%a, 4, None)
         # CHECK-NEXT: %allocate
-        # CHECK-SAME: ((M, N), f16, $SHARED_ADDRESS_SPACE)
+        # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
         # CHECK-NEXT: %write_1
         # CHECK-SAME: (%read, %allocate, 4, None)
         # CHECK-NEXT: %read_1
@@ -191,7 +191,7 @@ def test_gemm():
         graph: fx.Graph = trace.get_subgraph("region_0")
         read_nodes = get_read_nodes(graph)
         for read_node in read_nodes:
-            promote_node(read_node, SHARED_ADDRESS_SPACE)
+            promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
         hoist_allocs(trace)
         IndexingContext.current().finalize()
         print_trace(trace)
@@ -201,9 +201,9 @@ def test_gemm():
         # CHECK-NEXT: %c
         # CHECK-NEXT: %register
         # CHECK-NEXT: %allocate
-        # CHECK-SAME: ((M, K), f16, $SHARED_ADDRESS_SPACE)
+        # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
         # CHECK-NEXT: %allocate_1
-        # CHECK-SAME: ((N, K), f16, $SHARED_ADDRESS_SPACE)
+        # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
         # CHECK-NEXT: reduction
         # CHECK-NEXT: %write
         # CHECK-SAME: (%reduction, %c, 4, None)
diff --git a/shark_turbine/kernel/compiler/ir.py b/shark_turbine/kernel/compiler/ir.py
index 48d87197..5ede83fa 100644
--- a/shark_turbine/kernel/compiler/ir.py
+++ b/shark_turbine/kernel/compiler/ir.py
@@ -35,6 +35,7 @@
 
 from iree.compiler.dialects import (
     arith as arith_d,
+    amdgpu as amdgpu_d,
     builtin as builtin_d,
     flow as flow_d,
     func as func_d,
@@ -42,6 +43,7 @@
     math as math_d,
     memref as memref_d,
     stream as stream_d,
-    vector as vector_d,
     scf as scf_d,
+    transform as transform_d,
+    vector as vector_d,
 )
diff --git a/shark_turbine/kernel/compiler/vector_codegen.py b/shark_turbine/kernel/compiler/vector_codegen.py
index 775828a1..98eaf1fd 100644
--- a/shark_turbine/kernel/compiler/vector_codegen.py
+++ b/shark_turbine/kernel/compiler/vector_codegen.py
@@ -790,7 +790,11 @@ def cast_py_value(emitter: ThreadEmitter, value) -> IRProxyValue:
         try:
             node_values = emitter.lookup_node_values(value)
             assert len(node_values) == 1, f"Expected exactly one value for node {value}"
-            return node_values[0]
+            return (
+                node_values[0]
+                if isinstance(node_values[0], IRProxyValue)
+                else IRProxyValue(node_values[0])
+            )
         except KeyError:
             raise CodegenError(f"Producer node `{value}` has no IR Value")
     elif isinstance(value, IndexExpr):
@@ -828,6 +832,11 @@ def cast_kernel_buffer(
     value, node = cast_py_lvalue(emitter, kb)
     ir_type = value.type
     py_type = node.type
+    if py_type is None:
+        try:
+            py_type = ops.wave_ops.get_custom(node).type
+        except Exception as e:
+            raise CodegenError(f"Could not find type for node {node}") from e
 
     if not MemRefType.isinstance(ir_type):
         raise CodegenError(
diff --git a/shark_turbine/kernel/lang/wave_types.py b/shark_turbine/kernel/lang/wave_types.py
index d4ffc4bb..c84bdbdd 100644
--- a/shark_turbine/kernel/lang/wave_types.py
+++ b/shark_turbine/kernel/lang/wave_types.py
@@ -11,7 +11,6 @@
 )
 
 from .kernel_buffer import AddressSpace, KernelBufferMeta, KernelBufferUsage
-from ..ops.wave_ops import register
 from .._support.dtype import DataType
 from .._support.indexing import IndexExpr, IndexSymbol, index_symbol
 
@@ -101,6 +100,8 @@ class Register(metaclass=KernelBufferMeta):
     value: float
 
     def __new__(cls, value: float) -> "Register":
+        from ..ops.wave_ops import register
+
         return register(cls.symbolic_shape, cls.dtype, value)
 
     def __class_getitem__(
diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py
index 603f367d..be75e2e4 100644
--- a/shark_turbine/kernel/ops/wave_ops.py
+++ b/shark_turbine/kernel/ops/wave_ops.py
@@ -17,8 +17,7 @@
 )
 import torch.fx as fx
 
-if TYPE_CHECKING:
-    from ..lang.wave_types import Memory, Register
+from ..lang.wave_types import Memory, Register, IndexMapping
 from .._support.indexing import IndexExpr, IndexSymbol, IndexSequence
 from .._support.dtype import DataType
 from .._support.regions import RegionGraph
@@ -339,7 +338,7 @@ def index(self, value: Any):
                 assert isinstance(
                     key, IndexSequence
                 ), f"Expected IndexSequence, got {key}"
-                if not hasattr(self.fx_node, "index"):
+                if not hasattr(self.fx_node, "index") or self.fx_node.index is None:
                     self.fx_node.index = {}
                 self.fx_node.index[dim] = key
         else:
@@ -502,6 +501,7 @@ class Allocate(CustomOp):
     """
 
     shape: tuple[IndexExpr]
+    distributed_shape: tuple[IndexExpr]
     dtype: DataType
     address_space: AddressSpace
 
@@ -509,6 +509,10 @@ def indexing_dims(self) -> list[IndexSymbol]:
     def indexing_dims(self) -> list[IndexSymbol]:
         return list(self.shape)
 
+    @property
+    def type(self) -> "Memory":
+        return Memory[*self.shape, self.address_space, self.dtype]
+
 
 @define_op("register")
 @dataclass
@@ -521,6 +525,10 @@ def indexing_dims(self) -> list[IndexSymbol]:
     def indexing_dims(self) -> list[IndexSymbol]:
         return list(self.shape)
 
+    @property
+    def type(self) -> "Register":
+        return Register[*self.shape, self.dtype]
+
 
 @define_op("mma")
 @dataclass
@@ -551,6 +559,10 @@ def rhs_type(self) -> Memory:
     def acc_type(self) -> Memory:
         return get_custom(self.acc).type
 
+    @property
+    def type(self) -> Memory:
+        return self.acc_type
+
     def operand_index(
         self, operand_map: dict[IndexSymbol, int], shape: list[IndexExpr]
     ) -> list[IndexSequence]:
@@ -572,7 +584,7 @@ def rhs_index(self) -> list[IndexSequence]:
     @property
     def acc_index(self) -> list[IndexSequence]:
         operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 1}
-        if self.acc.type is None:
+        if self.acc_type is None:
             return None
         return self.operand_index(operand_map, self.acc_type.symbolic_shape)
 
@@ -598,11 +610,11 @@ def indexing_dims(self) -> list[IndexSymbol]:
         if self.mapping is not None:
             return list(self.mapping.output_shape)
         # TODO: This could contain ints.
-        return list(self.memory.type.symbolic_shape)
+        return list(self.type.symbolic_shape)
 
     @property
     def type(self) -> "Memory":
-        return self.memory.type
+        return get_custom(self.memory).type
 
 
 @define_op("reduction")
@@ -663,11 +675,11 @@ def indexing_dims(self) -> list[IndexSymbol]:
         if self.mapping is not None:
             return list(self.mapping.input_shape)
         # TODO: This could contain ints.
-        return list(self.memory.type.symbolic_shape)
+        return list(self.type.symbolic_shape)
 
     @property
     def type(self) -> "Memory":
-        return self.memory.type
+        return get_custom(self.memory).type
 
 
 @define_op("get_result")
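Note on the `type` accessors added to `MMA` above: `lhs` is an `[M, K]` register, `rhs` is `[N, K]` (the K-contiguous, transposed layout used throughout this patch), and the accumulator is `[M, N]`. A minimal torch sketch of the contraction those types imply (`mma_reference` is a hypothetical helper for illustration, not part of this patch):

    import torch

    def mma_reference(a, b, c):
        # a: [M, K] f16, b: [N, K] f16 (K-contiguous), c: [M, N] f32 accumulator.
        return c + a.to(torch.float32) @ b.to(torch.float32).T

    a = torch.randn(16, 16, dtype=torch.float16)
    b = torch.randn(16, 16, dtype=torch.float16)
    c = torch.zeros(16, 16, dtype=torch.float32)
    assert mma_reference(a, b, c).shape == (16, 16)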
diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py
index 84a46a82..f30ac369 100644
--- a/shark_turbine/kernel/wave/codegen.py
+++ b/shark_turbine/kernel/wave/codegen.py
@@ -5,7 +5,9 @@
 import torch.fx as fx
 
 from ..compiler.ir import (
+    Attribute,
     DenseElementsAttr,
+    FloatAttr,
     IndexType,
     InsertionPoint,
     IntegerAttr,
@@ -17,9 +19,11 @@
     ShapedType,
     Value,
     VectorType,
+    amdgpu_d,
     arith_d,
     func_d,
     gpu_d,
+    memref_d,
     stream_d,
     vector_d,
 )
@@ -27,7 +31,7 @@
 # TK infrastructure imports.
 from shark_turbine.kernel.lang.global_symbols import *
 
-from ..ops.wave_ops import write, register, mma, read, reduction, get_custom
+from ..ops.wave_ops import write, register, mma, read, reduction, get_custom, allocate
 from ..lang.wave_types import IndexMapping
 from ..compiler.base import CodegenError, ValidationError, NDEBUG
 from ..compiler.kernel_codegen import BoundKernelSignature
@@ -40,6 +44,7 @@
     cast_py_value,
     cast_vector,
 )
+from .constraints import Constraint, HardwareConstraint, MMAType
 
 # Indexing imports.
 from .._support.indexing import IndexingContext, IndexExpr, IndexSequence
@@ -51,6 +56,7 @@ class WaveEmitter:
 
     root_sig: BoundKernelSignature
     trace: CapturedTrace
+    constraints: list[Constraint]
    ip: InsertionPoint = None
     OP_HANDLERS: ClassVar[dict[str, Callable[["WaveEmitter", fx.Node], None]]] = {}
     _node_values: ClassVar[dict[fx.Node, List[IRProxyValue]]] = {}
@@ -102,6 +108,7 @@ def lookup_node_values(self, node: fx.Node) -> List[Value]:
         if values is None:
             values = [self.root_sig.resolve_by_reference(("node", node))]
             self._node_values[node] = values
+        values = [v.ir_value if isinstance(v, IRProxyValue) else v for v in values]
         return values
 
     def bind_node_proxy(self, node: fx.Node, proxy: IRProxyValue):
@@ -203,6 +210,14 @@ def gen_sympy_index(emitter: WaveEmitter, expr: sympy.Expr) -> OpResult:
     return stack[0]
 
 
+def get_constant_attr(value: Any, element_type: IrType) -> Attribute:
+    if _is_integer_like_type(element_type):
+        return IntegerAttr.get(element_type, int(value))
+    if _is_float_type(element_type):
+        return FloatAttr.get(element_type, float(value))
+    raise CodegenError(f"Cannot create a constant attribute for type `{element_type}`")
+
+
 def handle_op(op: Callable[..., Any]):
     def decorator(
         f: Callable[[WaveEmitter, fx.Node], None]
@@ -220,7 +235,36 @@ def decorator(
 
 @handle_op(register)
 def handle_register(emitter: WaveEmitter, node: fx.Node):
-    raise NotImplementedError("Register: Currently only stub implementation")
+    try:
+        shape, dtype, value = node.args
+    except ValueError as e:
+        raise ValidationError("Malformed arguments") from e
+    if hasattr(node, "thread_shape"):
+        shape = [node.thread_shape]
+    vector_shape = cast_py_literal(emitter, shape)
+    element_type = IrType.parse(dtype.ir_type_asm())
+    vector_type = VectorType.get(vector_shape, element_type)
+    register = arith_d.ConstantOp(
+        vector_type,
+        DenseElementsAttr.get_splat(
+            vector_type, get_constant_attr(value, element_type)
+        ),
+    ).result
+    emitter.bind_node_proxy(node, IRProxyValue(register))
+
+
+@handle_op(allocate)
+def handle_allocate(emitter: WaveEmitter, node: fx.Node):
+    try:
+        shape, distributed_shape, dtype, address_space = node.args
+    except ValueError as e:
+        raise ValidationError("Malformed arguments") from e
+    memref_shape = cast_py_literal(emitter, distributed_shape)
+    element_type = IrType.parse(dtype.ir_type_asm())
+    address_space = Attribute.parse("#gpu.address_space<workgroup>")
+    memref_type = MemRefType.get(memref_shape, element_type, None, address_space)
+    alloc = memref_d.alloc(memref_type, [], [])
+    emitter.bind_node_proxy(node, IRProxyValue(alloc))
 
 
 def _get_start_indices(
@@ -352,7 +396,8 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
             is_read=True,
         )
 
-        zero = arith_d.ConstantOp(vector_type.element_type, 0)
+        zero = int(0) if _is_integer_like_type(element_type) else float(0)
+        zero = arith_d.ConstantOp(vector_type.element_type, zero)
         passthru = vector_d.splat(vector_type, zero)
 
         result = vector_d.gather(
@@ -373,11 +418,13 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
     kb_dest, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, memory)
     insert_vector = cast_vector(emitter, register, element_type=kb_ir_type.element_type)
     insert_type = VectorType(insert_vector.type)
+    vector_shape = cast_py_literal(emitter, (elements_per_thread,))
 
     # TODO: Support elements_per_thread size mismatch and broadcasting
-    assert tuple(insert_type.shape) == (
-        elements_per_thread,
-    ), f"Shape doesn't match: {tuple(insert_type.shape)} and {(elements_per_thread,)}"
+
+    assert (
+        tuple(insert_type.shape) == vector_shape
+    ), f"Shape doesn't match: {tuple(insert_type.shape)} and {(vector_shape)}"
 
     if not hasattr(node, "index"):
         raise ValidationError("codegen expected read to have index attr.")
@@ -412,9 +459,48 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
 ###############################################################################
 
 
+def emit_mfma(
+    m: int, n: int, k: int, vector_type: VectorType, acc: Value, values: list[Value]
+):
+    m = get_constant_attr(m, IntegerType.get_signless(32))
+    n = get_constant_attr(n, IntegerType.get_signless(32))
+    k = get_constant_attr(k, IntegerType.get_signless(32))
+    blocks = get_constant_attr(1, IntegerType.get_signless(32))
+
+    result = amdgpu_d.mfma(
+        dest_d=vector_type,
+        m=m,
+        n=n,
+        k=k,
+        blocks=blocks,
+        source_a=values[0],
+        source_b=values[1],
+        dest_c=acc,
+    )
+    return result
+
+
 @handle_op(mma)
 def handle_mma(emitter: WaveEmitter, node: fx.Node):
-    raise NotImplementedError("MMA: Currently only stub implementation")
+    try:
+        lhs, rhs, acc = node.args
+        acc = cast_vector(emitter, acc)
+        values = [lhs, rhs]
+        for i in range(len(values)):
+            values[i] = cast_vector(emitter, values[i])
+    except ValueError as e:
+        raise ValidationError("Malformed arguments") from e
+
+    vector_type = VectorType(acc.type)
+    result = None
+    for constraint in emitter.constraints:
+        if isinstance(constraint, HardwareConstraint):
+            m, n, k = constraint.mma_matrix_shapes
+            result = emit_mfma(m, n, k, vector_type, acc, values)
+            break
+
+    if result:
+        emitter.bind_node_proxy(node, IRProxyValue(result))
 
 
 @handle_op(operator.add)
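`handle_register` above lowers `tkl.Register` to a splat `arith.constant`, and `handle_allocate` emits a `memref.alloc` in the workgroup address space. The splat pattern in isolation, as a standalone sketch (assumes the `iree.compiler` Python bindings; illustration only, not part of this patch):

    from iree.compiler.ir import (
        Context,
        DenseElementsAttr,
        F32Type,
        FloatAttr,
        InsertionPoint,
        Location,
        Module,
        VectorType,
    )
    from iree.compiler.dialects import arith

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f32 = F32Type.get()
            vec4 = VectorType.get([4], f32)
            # Same pattern as handle_register: splat the scalar into a vector.
            splat = DenseElementsAttr.get_splat(vec4, FloatAttr.get(f32, 0.0))
            arith.ConstantOp(vec4, splat)
        print(module)  # arith.constant dense<0.000000e+00> : vector<4xf32>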
+ """ + distributed_shape = [s for s in shape] + for i, dim in enumerate(shape): + for constraint in constraints: + if isinstance(constraint, WorkgroupConstraint): + if dim == constraint.dim: + distributed_shape[i] = constraint.tile_size + return tuple(distributed_shape) diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py index c49ca119..62f8b9c5 100644 --- a/shark_turbine/kernel/wave/expansion.py +++ b/shark_turbine/kernel/wave/expansion.py @@ -74,7 +74,7 @@ def get_indexed_dims( """ if isinstance(nodeOrDims, CustomOp): nodeOrDims = nodeOrDims.indexing_dims - return tuple((key, all_dims[key]) for key in nodeOrDims) + return tuple((key, all_dims[key]) for key in nodeOrDims if key in all_dims) def get_last(node_list: fx.graph._node_list) -> fx.Node: # type: ignore @@ -173,8 +173,11 @@ def set_node_index( index_seq.start += constraint.apply().start if index_seq is not None: - index_seq.start += dim_scaling[dim] * dim_tile_size[dim] + if dim in dim_scaling and dim in dim_tile_size: + index_seq.start += dim_scaling[dim] * dim_tile_size[dim] custom.index = {dim: index_seq} + else: + custom.index = {dim: IndexSequence(0, 1, 1)} setattr(custom.fx_node, "index", custom.index) @@ -464,6 +467,7 @@ def _handle_reduction_dim( # placeholder which will not trigger further expansion. index = user.node_args.index(carried_node) dummy = Placeholder("dummy").add_to_graph(user.graph) + dummy.type = None saved_arg = user.node_args[index] user.update_arg(index, dummy) diff --git a/shark_turbine/kernel/wave/promotion.py b/shark_turbine/kernel/wave/promotion.py index 2b36fa62..1e203971 100644 --- a/shark_turbine/kernel/wave/promotion.py +++ b/shark_turbine/kernel/wave/promotion.py @@ -3,15 +3,16 @@ from .._support.indexing import IndexingContext from ..ops.wave_ops import * from ..lang.global_symbols import * +from .constraints import Constraint, get_workgroup_distributed_shape logger = get_logger("turbine.wave.promotion") def apply_promotion_pattern(custom_node: Read | Write, allocate_node: Allocate): match custom_node: - case Read( - memory, elements_per_thread - ) if memory.type.address_space != allocate_node.address_space: + case Read(memory, elements_per_thread) if get_custom( + memory + ).type.address_space != allocate_node.address_space: promoted_read = Read( allocate_node.fx_node, elements_per_thread ).add_to_graph(custom_node.graph) @@ -20,11 +21,11 @@ def apply_promotion_pattern(custom_node: Read | Write, allocate_node: Allocate): Write( custom_node.fx_node, allocate_node.fx_node, elements_per_thread ).add_to_graph(custom_node.graph) - case _: - logger.error(f"Attempted to promoted unsupported operator {custom_node}") -def promote_node(node: Read | Write, address_space: IndexSymbol): +def promote_node( + node: Read | Write, address_space: IndexSymbol, constraints: list[Constraint] +): """Promotes the given operand in the provided graph to the specified address space. 
diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py
index c49ca119..62f8b9c5 100644
--- a/shark_turbine/kernel/wave/expansion.py
+++ b/shark_turbine/kernel/wave/expansion.py
@@ -74,7 +74,7 @@ def get_indexed_dims(
     """
     if isinstance(nodeOrDims, CustomOp):
         nodeOrDims = nodeOrDims.indexing_dims
-    return tuple((key, all_dims[key]) for key in nodeOrDims)
+    return tuple((key, all_dims[key]) for key in nodeOrDims if key in all_dims)
 
 
 def get_last(node_list: fx.graph._node_list) -> fx.Node:  # type: ignore
@@ -173,8 +173,11 @@ def set_node_index(
                 index_seq.start += constraint.apply().start
 
     if index_seq is not None:
-        index_seq.start += dim_scaling[dim] * dim_tile_size[dim]
+        if dim in dim_scaling and dim in dim_tile_size:
+            index_seq.start += dim_scaling[dim] * dim_tile_size[dim]
         custom.index = {dim: index_seq}
+    else:
+        custom.index = {dim: IndexSequence(0, 1, 1)}
 
     setattr(custom.fx_node, "index", custom.index)
 
@@ -464,6 +467,7 @@ def _handle_reduction_dim(
             # placeholder which will not trigger further expansion.
             index = user.node_args.index(carried_node)
             dummy = Placeholder("dummy").add_to_graph(user.graph)
+            dummy.type = None
             saved_arg = user.node_args[index]
             user.update_arg(index, dummy)
 
diff --git a/shark_turbine/kernel/wave/promotion.py b/shark_turbine/kernel/wave/promotion.py
index 2b36fa62..1e203971 100644
--- a/shark_turbine/kernel/wave/promotion.py
+++ b/shark_turbine/kernel/wave/promotion.py
@@ -3,15 +3,16 @@
 from .._support.indexing import IndexingContext
 from ..ops.wave_ops import *
 from ..lang.global_symbols import *
+from .constraints import Constraint, get_workgroup_distributed_shape
 
 logger = get_logger("turbine.wave.promotion")
 
 
 def apply_promotion_pattern(custom_node: Read | Write, allocate_node: Allocate):
     match custom_node:
-        case Read(
-            memory, elements_per_thread
-        ) if memory.type.address_space != allocate_node.address_space:
+        case Read(memory, elements_per_thread) if get_custom(
+            memory
+        ).type.address_space != allocate_node.address_space:
             promoted_read = Read(
                 allocate_node.fx_node, elements_per_thread
             ).add_to_graph(custom_node.graph)
@@ -20,11 +21,11 @@ def apply_promotion_pattern(custom_node: Read | Write, allocate_node: Allocate):
             Write(
                 custom_node.fx_node, allocate_node.fx_node, elements_per_thread
             ).add_to_graph(custom_node.graph)
-        case _:
-            logger.error(f"Attempted to promoted unsupported operator {custom_node}")
 
 
-def promote_node(node: Read | Write, address_space: IndexSymbol):
+def promote_node(
+    node: Read | Write, address_space: IndexSymbol, constraints: list[Constraint]
+):
     """Promotes the given operand in the provided graph
     to the specified address space.
@@ -35,20 +36,29 @@ def promote_node(node: Read | Write, address_space: IndexSymbol):
     """
     assert isinstance(node, Read) or isinstance(node, Write)
     with node.graph.inserting_before(node.fx_node.next):
+        workgroup_distributed_shape = get_workgroup_distributed_shape(
+            node.type.symbolic_shape, constraints
+        )
         allocate_node = Allocate(
-            node.type.symbolic_shape, node.type.dtype, address_space
+            node.type.symbolic_shape,
+            workgroup_distributed_shape,
+            node.type.dtype,
+            address_space,
         )
         allocate_node.add_to_graph(node.graph)
         apply_promotion_pattern(node, allocate_node)
 
 
-def promote_placeholders(graph: CapturedTrace):
-    for node in graph.get_root_graph().nodes:
+def promote_placeholders(graph: CapturedTrace, constraints: list[Constraint]):
+    read_or_write_nodes = graph.walk(
+        lambda node: isinstance(get_custom(node), Read)
+        or isinstance(get_custom(node), Write)
+    )
+    for node in read_or_write_nodes:
         custom = get_custom(node)
-        if isinstance(custom, Read) or isinstance(custom, Write):
-            if not custom.type:
-                continue
-            idxc = IndexingContext.current()
-            address_space = custom.type.address_space.subs(idxc.subs)
-            if address_space == SHARED_ADDRESS_SPACE:
-                promote_node(custom, address_space)
+        if not custom.type:
+            continue
+        idxc = IndexingContext.current()
+        address_space = custom.type.address_space.subs(idxc.subs)
+        if address_space == SHARED_ADDRESS_SPACE:
+            promote_node(custom, address_space, constraints)
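The net effect of `promote_node` on a `Read` from global memory, shown schematically with illustrative SSA names (this mirrors the %read/%allocate/%write_1/%read_1 sequence checked in the promotion lit tests above):

    # before promotion:
    #   %read = read(%a)            # %a lives in the global address space
    #
    # after promotion:
    #   %read     = read(%a)        # original global load is kept
    #   %allocate = allocate((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
    #   %write_1  = write(%read, %allocate, ...)  # stage the tile into shared memory
    #   %read_1   = read(%allocate, ...)          # former users of %read now use %read_1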
+ """ + register_nodes = trace.walk(lambda node: isinstance(get_custom(node), NewRegister)) + if not register_nodes: + return + for node in register_nodes: + custom_node = get_custom(node) + for user in node.users.keys(): + custom_user = get_custom(user) + if isinstance(custom_user, MMA): + arg_index = user.args.index(node) + if arg_index == 0: + custom_node.fx_node.thread_shape = custom_user.lhs_index[0].size + if arg_index == 1: + custom_node.fx_node.thread_shape = custom_user.rhs_index[0].size + if arg_index == 2: + custom_node.fx_node.thread_shape = custom_user.acc_index[0].size + else: + raise NotImplementedError( + f"Register shape propagation not implemented for {user}" + ) diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py new file mode 100644 index 00000000..d7088e63 --- /dev/null +++ b/shark_turbine/kernel/wave/utils.py @@ -0,0 +1,35 @@ +from ..compiler.ir import ( + builtin_d, + InsertionPoint, + Location, + Operation, + transform_d, + UnitAttr, +) + +from iree.compiler.dialects.transform import ( + interpreter as transform_interpreter, + any_op_t, +) + + +def canonicalize_module(module: Operation): + with module.context, Location.unknown(): + transform_module = builtin_d.Module.create() + transform_module_op = module.operation + transform_module_op.attributes["transform.with_named_sequence"] = UnitAttr.get() + with InsertionPoint(transform_module.body): + named_sequence = transform_d.NamedSequenceOp( + "__transform_main", [any_op_t()], [] + ) + with InsertionPoint(named_sequence.body): + target = named_sequence.body.arguments[0] + apply_patterns = transform_d.ApplyPatternsOp(target) + with InsertionPoint(apply_patterns.regions[0].blocks[0]): + transform_d.apply_patterns_canonicalization() + transform_d.YieldOp([target]) + transform_interpreter.apply_named_sequence( + module, + transform_module.body.operations[0], + transform_module, + ) diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py index 26f5b007..cc30cce7 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/shark_turbine/kernel/wave/wave.py @@ -17,10 +17,12 @@ from .expansion import expand_graph from .promotion import promote_placeholders from .hoisting import hoist_allocs +from .utils import canonicalize_module from ..lang import Grid, IndexMapping from ..lang.global_symbols import * from ..ops import wave_ops from ..ops.wave_ops import Reduction, CustomOp, get_custom +from .register_analysis import determine_register_shape from .._support.indexing import IndexingContext, IndexExpr import shark_turbine.kernel.lang as tkl from .._support.tracing import ( @@ -173,12 +175,15 @@ def _trace_and_get_kernel_signature( idxc.finalize() # Promote the placeholders to the appropriate address space. - promote_placeholders(graph) + promote_placeholders(graph, self.constraints) hoist_allocs(graph) # Expansion expand_graph(graph, self.constraints) + # Register analysis to determine register shapes. 
diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py
new file mode 100644
index 00000000..d7088e63
--- /dev/null
+++ b/shark_turbine/kernel/wave/utils.py
@@ -0,0 +1,35 @@
+from ..compiler.ir import (
+    builtin_d,
+    InsertionPoint,
+    Location,
+    Operation,
+    transform_d,
+    UnitAttr,
+)
+
+from iree.compiler.dialects.transform import (
+    interpreter as transform_interpreter,
+    any_op_t,
+)
+
+
+def canonicalize_module(module: Operation):
+    with module.context, Location.unknown():
+        transform_module = builtin_d.Module.create()
+        transform_module_op = module.operation
+        transform_module_op.attributes["transform.with_named_sequence"] = UnitAttr.get()
+        with InsertionPoint(transform_module.body):
+            named_sequence = transform_d.NamedSequenceOp(
+                "__transform_main", [any_op_t()], []
+            )
+            with InsertionPoint(named_sequence.body):
+                target = named_sequence.body.arguments[0]
+                apply_patterns = transform_d.ApplyPatternsOp(target)
+                with InsertionPoint(apply_patterns.regions[0].blocks[0]):
+                    transform_d.apply_patterns_canonicalization()
+                transform_d.YieldOp([target])
+        transform_interpreter.apply_named_sequence(
+            module,
+            transform_module.body.operations[0],
+            transform_module,
+        )
diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py
index 26f5b007..cc30cce7 100644
--- a/shark_turbine/kernel/wave/wave.py
+++ b/shark_turbine/kernel/wave/wave.py
@@ -17,10 +17,12 @@
 from .expansion import expand_graph
 from .promotion import promote_placeholders
 from .hoisting import hoist_allocs
+from .utils import canonicalize_module
 from ..lang import Grid, IndexMapping
 from ..lang.global_symbols import *
 from ..ops import wave_ops
 from ..ops.wave_ops import Reduction, CustomOp, get_custom
+from .register_analysis import determine_register_shape
 from .._support.indexing import IndexingContext, IndexExpr
 import shark_turbine.kernel.lang as tkl
 from .._support.tracing import (
@@ -173,12 +175,15 @@ def _trace_and_get_kernel_signature(
         idxc.finalize()
 
         # Promote the placeholders to the appropriate address space.
-        promote_placeholders(graph)
+        promote_placeholders(graph, self.constraints)
         hoist_allocs(graph)
 
         # Expansion
         expand_graph(graph, self.constraints)
 
+        # Register analysis to determine register shapes.
+        determine_register_shape(graph)
+
         self.grid_type.dims = [1, 1, 1]
         for constraint in self.workgroup_constraints:
             self.grid_type.dims[constraint.workgroup_dim] = (
@@ -197,10 +202,13 @@ def _trace_and_get_kernel_signature(
         exe = dispatch_codegen.StreamExecutable(mb, name=entrypoint_name)
         dispatch_entrypoint = exe.define_entrypoint(entrypoint_name, kernel_sig, grid)
 
-        emitter = WaveEmitter(dispatch_entrypoint, graph)
+        emitter = WaveEmitter(dispatch_entrypoint, graph, self.constraints)
         emitter.emit(graph.get_root_graph())
         emitter.finish()
 
+        if "canonicalize" in kwargs and kwargs["canonicalize"]:
+            canonicalize_module(mb.module_op)
+
         return mb, graph
 
     def test_execute(self, args, kwargs):
diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py
index 21b12c2f..8542bc64 100644
--- a/tests/kernel/wave/wave_gemm_test.py
+++ b/tests/kernel/wave/wave_gemm_test.py
@@ -77,7 +77,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
         K: 256,
     }
     with pytest.raises(
-        NotImplementedError, match="Currently only stub implementation"
+        NotImplementedError,
+        match="Register shape propagation not implemented for reduction",
     ):
         with tk.gen.TestLaunchContext(hyperparams):
             a = torch.randn(64, 256, dtype=torch.float16)
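Taken together, `_trace_and_get_kernel_signature` now runs the following pass order (a schematic summary of the wave.py hunk above, not additional code):

    # promote_placeholders(graph, self.constraints)  # stage shared-memory reads/writes
    # hoist_allocs(graph)                            # pull allocs out of the reduction
    # expand_graph(graph, self.constraints)          # unroll ops across waves/tiles
    # determine_register_shape(graph)                # per-thread register sizing (new)
    # WaveEmitter(dispatch_entrypoint, graph, self.constraints).emit(...)  # MLIR lowering
    # canonicalize_module(mb.module_op)              # optional, gated on canonicalize=True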