Add lowerings for mma, register and allocate
This PR adds an mma unit test that lowers to
vector.load/store ops and amdgpu.mfma. It also adds
support for shared memory promotion.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Aug 16, 2024
1 parent 7f9b93e commit 91eea4a
Showing 15 changed files with 370 additions and 56 deletions.
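
For reference, a condensed sketch of the kernel exercised by the new test (mirroring test_mma in lit_tests/kernel/wave/codegen.py below): tkw.read/tkw.write lower to vector.load/vector.store with a round trip through a memref.alloc in workgroup memory, and tkw.mma lowers to amdgpu.mfma. Constraint, symbol, and hyperparameter values are taken from that test; only the host tensor shapes here are sized to the bound symbol values (the checked IR uses memref<64x16xf16> and memref<128x16xf16>).

import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.lang.global_symbols import *
import torch

# Symbolic dimensions and tuning parameters, as defined in the test file.
M, N, K = tkl.sym.M, tkl.sym.N, tkl.sym.K
BLOCK_M, BLOCK_N = tkl.sym.BLOCK_M, tkl.sym.BLOCK_N
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEM_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEM_PER_THREAD
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0

# Distribution constraints: tile M and N across workgroups, split each tile
# across a 2x2 grid of 64-thread waves, and select the 16x16x16 f16 MFMA.
constraints: list[tkw.Constraint] = [
    tkw.WorkgroupConstraint(M, BLOCK_M, 0),
    tkw.WorkgroupConstraint(N, BLOCK_N, 1),
    tkw.WaveConstraint(M, BLOCK_M / 2),
    tkw.WaveConstraint(N, BLOCK_N / 2),
    tkw.HardwareConstraint(
        threads_per_wave=64,
        waves_per_block=(2, 2, 1),
        mma_type=tkw.MMAType.F32_16x16x16_F16,
    ),
]

@tkw.wave(constraints)
def mma(
    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],    # promoted to workgroup memory
    b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],    # promoted to workgroup memory
    c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],  # stays in global memory
):
    c_reg = tkl.Register[M, N, tkl.f32](0.0)  # lowers to a dense<0.0> vector<4xf32> constant
    a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)  # vector.load, then store/load via memref.alloc
    b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    acc = tkw.mma(a_reg, b_reg, c_reg)  # lowers to amdgpu.mfma (m = n = k = 16)
    tkw.write(acc, c, elements_per_thread=STORE_ELEMS_PER_THREAD)  # vector.store to the global memref

# Bind concrete values for the symbols and print the lowered module, as the test does.
with tk.gen.TestLaunchContext(
    {
        M: 64, N: 128, K: 16,
        BLOCK_M: 32, BLOCK_N: 32,
        LOAD_ELEMS_PER_THREAD: 4, STORE_ELEMS_PER_THREAD: 4,
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
    }
):
    a = torch.randn(64, 16, dtype=torch.float16)
    b = torch.randn(128, 16, dtype=torch.float16)
    c = torch.zeros(64, 128, dtype=torch.float32)
    print(mma(a, b, c, canonicalize=True).module_op)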
98 changes: 98 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -5,6 +5,7 @@
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.lang.global_symbols import *
import torch

M = tkl.sym.M
@@ -13,7 +14,10 @@
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEM_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEM_PER_THREAD
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0


def launch(func: Callable[[], None]) -> Callable[[], None]:
@@ -247,6 +251,100 @@ def test(
# CHECK: vector.scatter %[[OUT]][%[[IDX_Y]], %[[IDX_X]]] [%[[OFF]]], %[[MASK]], %[[RES]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xindex>, vector<16xi1>, vector<16xf16>


@run
def test_mma():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(2, 2, 1),
mma_type=tkw.MMAType.F32_16x16x16_F16,
)
]

@tkw.wave(constraints)
def mma(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)
a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
acc = tkw.mma(a_reg, b_reg, c_reg)
tkw.write(acc, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

with tk.gen.TestLaunchContext(
{
M: 64,
N: 128,
K: 16,
BLOCK_M: 32,
BLOCK_N: 32,
LOAD_ELEMS_PER_THREAD: 4,
STORE_ELEMS_PER_THREAD: 4,
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
}
):
a = torch.randn(64, 32, dtype=torch.float16)
b = torch.randn(128, 32, dtype=torch.float16)
c = torch.zeros(64, 128, dtype=torch.float32)
print(mma(a, b, c, canonicalize=True).module_op)

# CHECK: func.func @mma(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding, %[[ARG2:.+]]: !stream.binding) {
# CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
# CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
# CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
# CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
# CHECK-DAG: %[[ACC:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
# CHECK: %[[WG0:.+]] = stream.dispatch.workgroup.id[0] : index
# CHECK: %[[WG1:.+]] = stream.dispatch.workgroup.id[1] : index
# CHECK: %[[TX:.+]] = gpu.thread_id x
# CHECK: %[[TY:.+]] = gpu.thread_id y
# CHECK: %[[R0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x16xf16, strided<[16, 1], offset: ?>>
# CHECK: %[[R1:.+]] = arith.muli %[[WG0]], %[[C32]] : index
# CHECK: %[[R2:.+]] = arith.divsi %[[TX]], %[[C4]] : index
# CHECK: %[[R3:.+]] = arith.addi %[[R2]], %[[R1]] : index
# CHECK: %[[R4:.+]] = vector.load %0[%[[R3]], %[[C0]]] : memref<64x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
# CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
# CHECK: %[[R5:.+]] = arith.muli %[[WG0]], %[[C32]] : index
# CHECK: %[[R6:.+]] = arith.divsi %[[TX]], %[[C4]] : index
# CHECK: %[[R7:.+]] = arith.addi %[[R6]], %[[R5]] : index
# CHECK: vector.store %4, %[[ALLOC]][%[[R7]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK: %[[R8:.+]] = arith.muli %[[WG0]], %[[C32]] : index
# CHECK: %[[R9:.+]] = arith.divsi %[[TX]], %[[C4]] : index
# CHECK: %[[R10:.+]] = arith.addi %[[R9]], %[[R8]] : index
# CHECK: %[[R11:.+]] = vector.load %[[ALLOC]][%[[R10]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK: %[[R12:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, strided<[16, 1], offset: ?>>
# CHECK: %[[R13:.+]] = arith.muli %[[TY]], %[[C16]] : index
# CHECK: %[[R14:.+]] = arith.muli %[[WG1]], %[[C32]] : index
# CHECK: %[[R15:.+]] = arith.addi %[[R14]], %[[R13]] : index
# CHECK: %[[R16:.+]] = vector.load %[[R12]][%[[R15]], %[[C0]]] : memref<128x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
# CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
# CHECK: %[[R17:.+]] = arith.muli %[[TY]], %[[C16]] : index
# CHECK: %[[R18:.+]] = arith.muli %[[WG1]], %[[C32]] : index
# CHECK: %[[R19:.+]] = arith.addi %[[R18]], %[[R17]] : index
# CHECK: vector.store %16, %[[ALLOC_0]][%[[R19]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK: %[[R20:.+]] = arith.muli %[[TY]], %[[C16]] : index
# CHECK: %[[R21:.+]] = arith.muli %[[WG1]], %[[C32]] : index
# CHECK: %[[R22:.+]] = arith.addi %[[R21]], %[[R20]] : index
# CHECK: %[[R23:.+]] = vector.load %[[ALLOC_0]][%[[R22]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK: %[[R24:.+]] = amdgpu.mfma %[[R11]] * %[[R23]] + %[[ACC]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
# CHECK: %[[R25:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, strided<[128, 1], offset: ?>>
# CHECK: %[[R26:.+]] = arith.muli %[[WG0]], %[[C32]] : index
# CHECK: %[[R27:.+]] = arith.divsi %[[TX]], %[[C4]] : index
# CHECK: %[[R28:.+]] = arith.addi %[[R27]], %[[R26]] : index
# CHECK: %[[R29:.+]] = arith.muli %[[TY]], %[[C16]] : index
# CHECK: %[[R30:.+]] = arith.muli %[[WG1]], %[[C32]] : index
# CHECK: %[[R31:.+]] = arith.addi %[[R30]], %[[R29]] : index
# CHECK: vector.store %[[R24]], %[[R25]][%[[R28]], %[[R31]]] : memref<64x128xf32, strided<[128, 1], offset: ?>>, vector<4xf32>


@run
def test_add_float():
constraints: list[tkw.Constraint] = [
20 changes: 10 additions & 10 deletions lit_tests/kernel/wave/expansion.py
@@ -350,25 +350,25 @@ def test_gemm():
# CHECK-SAME: acc=acc_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_0_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_0_0_0 (index = None))
# CHECK-SAME: acc=mma_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_1_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: rhs=read_0_1_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: acc=acc_1_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_1_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_1_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_1_1_0 (index = None))
# CHECK-SAME: acc=mma_1_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_1_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: rhs=read_0_0_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: acc=acc_1_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_1_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_0_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_1_0_0 (index = None))
# CHECK-SAME: acc=mma_1_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: rhs=read_0_1_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: acc=acc_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_1_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_0_1_0 (index = None))
# CHECK-SAME: acc=mma_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1],))

# CHECK-NEXT: -----
@@ -502,25 +502,25 @@ def test_gemm_reduction_expansion_only():
# CHECK-SAME: acc=acc_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_0_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_0_0_0 (index = None))
# CHECK-SAME: acc=mma_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_2 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
# CHECK-SAME: rhs=read_0_0_2 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
# CHECK-SAME: acc=mma_0_0_1 (index = None))
# CHECK-SAME: acc=mma_0_0_1 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_3 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
# CHECK-SAME: rhs=read_0_0_3 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
# CHECK-SAME: acc=mma_0_0_2 (index = None))
# CHECK-SAME: acc=mma_0_0_2 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: rhs=read_0_1_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: acc=acc_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_1_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_0_1_0 (index = None))
# CHECK-SAME: acc=mma_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_0_0_2 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
# CHECK-SAME: rhs=read_0_1_2 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
# CHECK-SAME: acc=mma_0_1_1 (index = None))
# CHECK-SAME: acc=mma_0_1_1 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_0_0_3 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
# CHECK-SAME: rhs=read_0_1_3 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
# CHECK-SAME: acc=mma_0_1_2 (index = None))
# CHECK-SAME: acc=mma_0_1_2 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: output(return_vals=([mma_0_0_3, mma_0_1_3],))

# CHECK-NEXT: -----
14 changes: 7 additions & 7 deletions lit_tests/kernel/wave/promotion.py
@@ -87,14 +87,14 @@ def test_read_write_equal_sizes():
graph: fx.Graph = trace.get_root_graph()
read_node = get_read_nodes(graph)[0]
IndexingContext.current().finalize()
promote_node(read_node, SHARED_ADDRESS_SPACE)
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
print_trace(trace)
# CHECK: %a
# CHECK-NEXT: %c
# CHECK-NEXT: %read
# CHECK-SAME: (%a, 4, None)
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, N), f16, $SHARED_ADDRESS_SPACE)
# CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %write_1
# CHECK-SAME: (%read, %allocate, 4, None)
# CHECK-NEXT: %read_1
@@ -136,14 +136,14 @@ def test_read_write_equal_sizes_different_address_spaces():
):
trace: CapturedTrace = read_write_same_size_different_address_spaces()
IndexingContext.current().finalize()
promote_placeholders(trace)
promote_placeholders(trace, constraints)
print_trace(trace)
# CHECK: %a
# CHECK-NEXT: %c
# CHECK-NEXT: %read
# CHECK-SAME: (%a, 4, None)
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, N), f16, $SHARED_ADDRESS_SPACE)
# CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %write_1
# CHECK-SAME: (%read, %allocate, 4, None)
# CHECK-NEXT: %read_1
@@ -191,7 +191,7 @@ def test_gemm():
graph: fx.Graph = trace.get_subgraph("region_0")
read_nodes = get_read_nodes(graph)
for read_node in read_nodes:
promote_node(read_node, SHARED_ADDRESS_SPACE)
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
hoist_allocs(trace)
IndexingContext.current().finalize()
print_trace(trace)
@@ -201,9 +201,9 @@
# CHECK-NEXT: %c
# CHECK-NEXT: %register
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %allocate_1
# CHECK-SAME: ((N, K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: reduction
# CHECK-NEXT: %write
# CHECK-SAME: (%reduction, %c, 4, None)
4 changes: 3 additions & 1 deletion shark_turbine/kernel/compiler/ir.py
@@ -35,13 +35,15 @@

from iree.compiler.dialects import (
arith as arith_d,
amdgpu as amdgpu_d,
builtin as builtin_d,
flow as flow_d,
func as func_d,
gpu as gpu_d,
math as math_d,
memref as memref_d,
stream as stream_d,
vector as vector_d,
scf as scf_d,
transform as transform_d,
vector as vector_d,
)
11 changes: 10 additions & 1 deletion shark_turbine/kernel/compiler/vector_codegen.py
@@ -790,7 +790,11 @@ def cast_py_value(emitter: ThreadEmitter, value) -> IRProxyValue:
try:
node_values = emitter.lookup_node_values(value)
assert len(node_values) == 1, f"Expected exactly one value for node {value}"
return node_values[0]
return (
node_values[0]
if isinstance(node_values[0], IRProxyValue)
else IRProxyValue(node_values[0])
)
except KeyError:
raise CodegenError(f"Producer node `{value}` has no IR Value")
elif isinstance(value, IndexExpr):
@@ -828,6 +832,11 @@ def cast_kernel_buffer(
value, node = cast_py_lvalue(emitter, kb)
ir_type = value.type
py_type = node.type
if py_type is None:
try:
py_type = ops.wave_ops.get_custom(node).type
except:
raise CodegenError(f"Could not find type for node {node}")

if not MemRefType.isinstance(ir_type):
raise CodegenError(
3 changes: 2 additions & 1 deletion shark_turbine/kernel/lang/wave_types.py
@@ -11,7 +11,6 @@
)

from .kernel_buffer import AddressSpace, KernelBufferMeta, KernelBufferUsage
from ..ops.wave_ops import register
from .._support.dtype import DataType
from .._support.indexing import IndexExpr, IndexSymbol, index_symbol

@@ -101,6 +100,8 @@ class Register(metaclass=KernelBufferMeta):
value: float

def __new__(cls, value: float) -> "Register":
from ..ops.wave_ops import register

return register(cls.symbolic_shape, cls.dtype, value)

def __class_getitem__(