Add lowerings for mma, register and allocate #86

Merged: 2 commits, Aug 19, 2024
Changes from 1 commit
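
For orientation (an editorial note, not part of the PR itself): the new lowerings are exercised by the test_mma test added to lit_tests/kernel/wave/codegen.py below. Read off its CHECK lines, the op-to-MLIR correspondence is roughly the following; the sketch only reprints what those CHECK lines already show, with the MLIR abbreviated by hand.

# Rough correspondence between the wave ops lowered in this PR and the MLIR they
# produce, summarized from the CHECK lines of the new test_mma test below.
lowerings = {
    "tkl.Register[M, N, tkl.f32](0.0)": "arith.constant dense<0.000000e+00> : vector<4xf32>",
    "allocate (shared-memory promotion)": "memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>",
    "tkw.mma(a_reg, b_reg, c_reg)": "amdgpu.mfma ... {blocks = 1, k = 16, m = 16, n = 16}",
}
for wave_op, mlir_op in lowerings.items():
    print(f"{wave_op}  ->  {mlir_op}")
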
98 changes: 98 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -5,6 +5,7 @@
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.lang.global_symbols import *
import torch

M = tkl.sym.M
@@ -13,7 +14,10 @@
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEM_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEM_PER_THREAD
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0


def launch(func: Callable[[], None]) -> Callable[[], None]:
@@ -247,6 +251,100 @@ def test(
# CHECK: vector.scatter %[[OUT]][%[[IDX_Y]], %[[IDX_X]]] [%[[OFF]]], %[[MASK]], %[[RES]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xindex>, vector<16xi1>, vector<16xf16>


@run
def test_mma():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(2, 2, 1),
mma_type=tkw.MMAType.F32_16x16x16_F16,
)
]

@tkw.wave(constraints)
def mma(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)
a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
acc = tkw.mma(a_reg, b_reg, c_reg)
tkw.write(acc, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

with tk.gen.TestLaunchContext(
{
M: 64,
N: 128,
K: 16,
BLOCK_M: 32,
BLOCK_N: 32,
LOAD_ELEMS_PER_THREAD: 4,
STORE_ELEMS_PER_THREAD: 4,
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
}
):
a = torch.randn(64, 32, dtype=torch.float16)
b = torch.randn(128, 32, dtype=torch.float16)
c = torch.zeros(64, 128, dtype=torch.float32)
print(mma(a, b, c, canonicalize=True).module_op)

# CHECK: func.func @mma(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding, %[[ARG2:.+]]: !stream.binding) {
# CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
# CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
# CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
# CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
# CHECK-DAG: %[[ACC:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
# CHECK: %[[WG0:.+]] = stream.dispatch.workgroup.id[0] : index
# CHECK: %[[WG1:.+]] = stream.dispatch.workgroup.id[1] : index
# CHECK: %[[TX:.+]] = gpu.thread_id x
# CHECK: %[[TY:.+]] = gpu.thread_id y
# CHECK: %[[R0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x16xf16, strided<[16, 1], offset: ?>>
# CHECK: %[[R1:.+]] = arith.muli %[[WG0]], %[[C32]] : index
# CHECK: %[[R2:.+]] = arith.divsi %[[TX]], %[[C4]] : index
# CHECK: %[[R3:.+]] = arith.addi %[[R2]], %[[R1]] : index
# CHECK: %[[R4:.+]] = vector.load %0[%[[R3]], %[[C0]]] : memref<64x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
# CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
# CHECK: %[[R5:.+]] = arith.muli %[[WG0]], %[[C32]] : index
# CHECK: %[[R6:.+]] = arith.divsi %[[TX]], %[[C4]] : index
# CHECK: %[[R7:.+]] = arith.addi %[[R6]], %[[R5]] : index
# CHECK: vector.store %4, %[[ALLOC]][%[[R7]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK: %[[R8:.+]] = arith.muli %[[WG0]], %[[C32]] : index
# CHECK: %[[R9:.+]] = arith.divsi %[[TX]], %[[C4]] : index
# CHECK: %[[R10:.+]] = arith.addi %[[R9]], %[[R8]] : index
# CHECK: %[[R11:.+]] = vector.load %[[ALLOC]][%[[R10]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK: %[[R12:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, strided<[16, 1], offset: ?>>
# CHECK: %[[R13:.+]] = arith.muli %[[TY]], %[[C16]] : index
# CHECK: %[[R14:.+]] = arith.muli %[[WG1]], %[[C32]] : index
# CHECK: %[[R15:.+]] = arith.addi %[[R14]], %[[R13]] : index
# CHECK: %[[R16:.+]] = vector.load %[[R12]][%[[R15]], %[[C0]]] : memref<128x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
# CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
# CHECK: %[[R17:.+]] = arith.muli %[[TY]], %[[C16]] : index
# CHECK: %[[R18:.+]] = arith.muli %[[WG1]], %[[C32]] : index
# CHECK: %[[R19:.+]] = arith.addi %[[R18]], %[[R17]] : index
# CHECK: vector.store %16, %[[ALLOC_0]][%[[R19]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK: %[[R20:.+]] = arith.muli %[[TY]], %[[C16]] : index
# CHECK: %[[R21:.+]] = arith.muli %[[WG1]], %[[C32]] : index
# CHECK: %[[R22:.+]] = arith.addi %[[R21]], %[[R20]] : index
# CHECK: %[[R23:.+]] = vector.load %[[ALLOC_0]][%[[R22]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK: %[[R24:.+]] = amdgpu.mfma %[[R11]] * %[[R23]] + %[[ACC]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
# CHECK: %[[R25:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, strided<[128, 1], offset: ?>>
# CHECK: %[[R26:.+]] = arith.muli %[[WG0]], %[[C32]] : index
# CHECK: %[[R27:.+]] = arith.divsi %[[TX]], %[[C4]] : index
# CHECK: %[[R28:.+]] = arith.addi %[[R27]], %[[R26]] : index
# CHECK: %[[R29:.+]] = arith.muli %[[TY]], %[[C16]] : index
# CHECK: %[[R30:.+]] = arith.muli %[[WG1]], %[[C32]] : index
# CHECK: %[[R31:.+]] = arith.addi %[[R30]], %[[R29]] : index
# CHECK: vector.store %[[R24]], %[[R25]][%[[R28]], %[[R31]]] : memref<64x128xf32, strided<[128, 1], offset: ?>>, vector<4xf32>


@run
def test_add_float():
constraints: list[tkw.Constraint] = [
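
An editorial aside on the hyperparameters in test_mma above: LOAD_ELEMS_PER_THREAD and STORE_ELEMS_PER_THREAD are both 4 because a 16x16 fragment of the F32_16x16x16_F16 intrinsic, distributed over a 64-lane wave, leaves 4 elements per thread; that is exactly the vector<4xf16>/vector<4xf32> width in the CHECK lines. A one-line arithmetic check:

# Hedged arithmetic behind the value 4 used for LOAD/STORE_ELEMS_PER_THREAD above.
threads_per_wave = 64            # from the HardwareConstraint in test_mma
fragment_m, fragment_n = 16, 16  # from MMAType.F32_16x16x16_F16
elems_per_thread = (fragment_m * fragment_n) // threads_per_wave
assert elems_per_thread == 4     # matches vector<4xf16> / vector<4xf32> in the CHECK lines
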
20 changes: 10 additions & 10 deletions lit_tests/kernel/wave/expansion.py
@@ -350,25 +350,25 @@ def test_gemm():
# CHECK-SAME: acc=acc_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_0_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_0_0_0 (index = None))
# CHECK-SAME: acc=mma_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_1_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: rhs=read_0_1_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: acc=acc_1_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_1_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_1_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_1_1_0 (index = None))
# CHECK-SAME: acc=mma_1_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_1_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: rhs=read_0_0_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: acc=acc_1_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_1_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_0_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_1_0_0 (index = None))
# CHECK-SAME: acc=mma_1_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) + 16 : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: rhs=read_0_1_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: acc=acc_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_1_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_0_1_0 (index = None))
# CHECK-SAME: acc=mma_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1],))

# CHECK-NEXT: -----
@@ -502,25 +502,25 @@ def test_gemm_reduction_expansion_only():
# CHECK-SAME: acc=acc_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_0_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_0_0_0 (index = None))
# CHECK-SAME: acc=mma_0_0_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_2 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
# CHECK-SAME: rhs=read_0_0_2 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
# CHECK-SAME: acc=mma_0_0_1 (index = None))
# CHECK-SAME: acc=mma_0_0_1 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_3 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
# CHECK-SAME: rhs=read_0_0_3 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
# CHECK-SAME: acc=mma_0_0_2 (index = None))
# CHECK-SAME: acc=mma_0_0_2 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]))
# CHECK-NEXT: mma(lhs=read_0_0_0 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: rhs=read_0_1_0 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) : 4 : 1])
# CHECK-SAME: acc=acc_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_0_0_1 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: rhs=read_0_1_1 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 16 : 4 : 1])
# CHECK-SAME: acc=mma_0_1_0 (index = None))
# CHECK-SAME: acc=mma_0_1_0 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_0_0_2 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
# CHECK-SAME: rhs=read_0_1_2 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 32 : 4 : 1])
# CHECK-SAME: acc=mma_0_1_1 (index = None))
# CHECK-SAME: acc=mma_0_1_1 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: mma(lhs=read_0_0_3 (index = [$T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
# CHECK-SAME: rhs=read_0_1_3 (index = [$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, 16*$T1 + 16*$T2 + ARGK*BLOCK_K + 4*floor($T0/16) + 48 : 4 : 1])
# CHECK-SAME: acc=mma_0_1_2 (index = None))
# CHECK-SAME: acc=mma_0_1_2 (index = [$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16]))
# CHECK-NEXT: output(return_vals=([mma_0_0_3, mma_0_1_3],))

# CHECK-NEXT: -----
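
All of the expectation updates in this file follow one pattern: a chained mma's accumulator used to print (index = None) and now prints the same index as the accumulator it chains from. A trivial check of that claim, with both strings copied verbatim from the CHECK lines above (editorial illustration only):

# The first string is the index printed for acc_0_0_0, the second is the index now
# printed for mma_0_0_0 when it is reused as an accumulator; they are identical.
acc_0_0_0_index = (
    "[$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, "
    "$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]"
)
mma_0_0_0_acc_index = (
    "[$T0*BLOCK_M/128 + 16*$T1 + 16*$T2 + $WG0*BLOCK_M + 4*floor($T0/16) : 4 : 16, "
    "$T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)]"
)
assert acc_0_0_0_index == mma_0_0_0_acc_index  # the index is propagated instead of dropped
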
14 changes: 7 additions & 7 deletions lit_tests/kernel/wave/promotion.py
@@ -87,14 +87,14 @@ def test_read_write_equal_sizes():
graph: fx.Graph = trace.get_root_graph()
read_node = get_read_nodes(graph)[0]
IndexingContext.current().finalize()
promote_node(read_node, SHARED_ADDRESS_SPACE)
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
print_trace(trace)
# CHECK: %a
# CHECK-NEXT: %c
# CHECK-NEXT: %read
# CHECK-SAME: (%a, 4, None)
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, N), f16, $SHARED_ADDRESS_SPACE)
# CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %write_1
# CHECK-SAME: (%read, %allocate, 4, None)
# CHECK-NEXT: %read_1
@@ -136,14 +136,14 @@ def test_read_write_equal_sizes_different_address_spaces():
):
trace: CapturedTrace = read_write_same_size_different_address_spaces()
IndexingContext.current().finalize()
promote_placeholders(trace)
promote_placeholders(trace, constraints)
print_trace(trace)
# CHECK: %a
# CHECK-NEXT: %c
# CHECK-NEXT: %read
# CHECK-SAME: (%a, 4, None)
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, N), f16, $SHARED_ADDRESS_SPACE)
# CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %write_1
# CHECK-SAME: (%read, %allocate, 4, None)
# CHECK-NEXT: %read_1
@@ -191,7 +191,7 @@ def test_gemm():
graph: fx.Graph = trace.get_subgraph("region_0")
read_nodes = get_read_nodes(graph)
for read_node in read_nodes:
promote_node(read_node, SHARED_ADDRESS_SPACE)
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
hoist_allocs(trace)
IndexingContext.current().finalize()
print_trace(trace)
@@ -201,9 +201,9 @@ def test_gemm():
# CHECK-NEXT: %c
# CHECK-NEXT: %register
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %allocate_1
# CHECK-SAME: ((N, K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: reduction
# CHECK-NEXT: %write
# CHECK-SAME: (%reduction, %c, 4, None)
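
What changed in these promotion expectations (an editorial reading of the diff, not stated in the PR): allocate now prints a second, block-sized shape alongside the full memory shape, e.g. ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE), and promote_node/promote_placeholders now receive the constraints needed to compute it. That block-sized shape is consistent with the 32x16 workgroup memrefs in the codegen test above, where BLOCK_M = 32 and the whole K dimension is 16 (test_mma does not tile K).

# Hedged cross-check against the codegen CHECK lines: with the test_mma hyperparameters,
# a block-sized shared allocation for the "a" operand is 32x16, which is exactly the
# memref.alloc() result printed there.
BLOCK_M, K = 32, 16
print(f"memref<{BLOCK_M}x{K}xf16, #gpu.address_space<workgroup>>")
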
4 changes: 3 additions & 1 deletion shark_turbine/kernel/compiler/ir.py
@@ -35,13 +35,15 @@

from iree.compiler.dialects import (
arith as arith_d,
amdgpu as amdgpu_d,
builtin as builtin_d,
flow as flow_d,
func as func_d,
gpu as gpu_d,
math as math_d,
memref as memref_d,
stream as stream_d,
vector as vector_d,
scf as scf_d,
transform as transform_d,
vector as vector_d,
)
11 changes: 10 additions & 1 deletion shark_turbine/kernel/compiler/vector_codegen.py
@@ -790,7 +790,11 @@ def cast_py_value(emitter: ThreadEmitter, value) -> IRProxyValue:
try:
node_values = emitter.lookup_node_values(value)
assert len(node_values) == 1, f"Expected exactly one value for node {value}"
return node_values[0]
return (
node_values[0]
if isinstance(node_values[0], IRProxyValue)
else IRProxyValue(node_values[0])
)
except KeyError:
raise CodegenError(f"Producer node `{value}` has no IR Value")
elif isinstance(value, IndexExpr):
@@ -828,6 +832,11 @@ def cast_kernel_buffer(
value, node = cast_py_lvalue(emitter, kb)
ir_type = value.type
py_type = node.type
if py_type is None:
try:
py_type = ops.wave_ops.get_custom(node).type
except:
raise CodegenError(f"Could not find type for node {node}")

if not MemRefType.isinstance(ir_type):
raise CodegenError(
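
Both hunks in this file are defensive normalizations: cast_py_value now wraps a bare node value in IRProxyValue when it is not one already, and cast_kernel_buffer falls back to the custom op's type when the fx node carries no type. A minimal stand-alone sketch of the first idiom (toy class, not the real shark_turbine one):

class ProxyValue:  # toy stand-in for IRProxyValue
    def __init__(self, ir_value):
        self.ir_value = ir_value

def as_proxy(value):
    # Mirrors the change to cast_py_value: pass proxies through, wrap raw values.
    return value if isinstance(value, ProxyValue) else ProxyValue(value)

assert isinstance(as_proxy("raw ir value"), ProxyValue)
assert isinstance(as_proxy(ProxyValue("already wrapped")), ProxyValue)
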
3 changes: 2 additions & 1 deletion shark_turbine/kernel/lang/wave_types.py
@@ -11,7 +11,6 @@
)

from .kernel_buffer import AddressSpace, KernelBufferMeta, KernelBufferUsage
from ..ops.wave_ops import register
from .._support.dtype import DataType
from .._support.indexing import IndexExpr, IndexSymbol, index_symbol

@@ -101,6 +100,8 @@ class Register(metaclass=KernelBufferMeta):
value: float

def __new__(cls, value: float) -> "Register":
from ..ops.wave_ops import register

return register(cls.symbolic_shape, cls.dtype, value)

def __class_getitem__(
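
The only functional change here is moving the register import from module scope into Register.__new__; presumably (the diff does not say so) this avoids a circular import between wave_types and wave_ops. A self-contained toy showing the shape of the pattern, with a local function standing in for ops.wave_ops.register:

def register(shape, dtype, value):  # stand-in for ops.wave_ops.register
    return ("register", shape, dtype, value)

class Register:
    symbolic_shape = ("M", "N")
    dtype = "f32"

    def __new__(cls, value: float):
        # In the real code the import happens here, inside __new__, so importing
        # wave_types no longer requires wave_ops to already be importable.
        return register(cls.symbolic_shape, cls.dtype, value)

print(Register(0.0))  # ('register', ('M', 'N'), 'f32', 0.0)
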