Add Codegen e2e flow
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu committed Aug 1, 2024
1 parent b094185 commit 07265b6
Showing 6 changed files with 306 additions and 28 deletions.
40 changes: 24 additions & 16 deletions lit_tests/kernel/wave/codegen.py
@@ -36,8 +36,19 @@ def launch(func: Callable[[], None]) -> Callable[[], None]:
func()
return func

def codegen_test_context():
return tk.gen.TestLaunchContext(
{
M: 16,
N: 16,
K: 16,
BLOCK_M: 16,
BLOCK_N: 16,
BLOCK_K: 16,
ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
}
)

@launch
def test_read():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
@@ -49,16 +60,13 @@ def test_read():

@tkw.wave(constraints)
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
tkw.read(a)
tkw.read(a, elements_per_thread=4)

a = torch.randn(16, 16, dtype=torch.float16)
with pytest.raises(
NotImplementedError, match="Read: Currently only stub implementation"
):
test(a)
with codegen_test_context():
a = torch.randn(16, 16, dtype=torch.float16)
print(test(a).module_op)


@launch
def test_add():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
@@ -70,17 +78,15 @@ def test_add():

@tkw.wave(constraints)
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
res = a + a
tkw.write(res, a, elements_per_thread=4)

a = torch.randn(16, 16, dtype=torch.float16)
with pytest.raises(
NotImplementedError, match="add: Currently only stub implementation"
):
test(a)
a_reg = tkw.read(a, elements_per_thread=4)
res = a_reg + a_reg

with codegen_test_context():
a = torch.randn(16, 16, dtype=torch.float16)
print(test(a).module_op)

@launch
@pytest.mark.skip(reason="neg: Currently only stub implementation")
def test_neg():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
@@ -103,6 +109,7 @@ def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):


@launch
@pytest.mark.skip(reason="sub: Currently only stub implementation")
def test_sub():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
@@ -125,6 +132,7 @@ def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):


@launch
@pytest.mark.skip(reason="getitem: Currently only stub implementation")
def test_get_item():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
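
Taken together, the updated tests follow one pattern: declare symbolic Memory arguments, enter codegen_test_context() (which binds M, N, K and the tile sizes to 16), and print the compiled module. With those bindings, an argument declared as tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16] corresponds to a concrete 16x16 float16 tensor on the host side, which is presumably why the tests construct torch.randn(16, 16, dtype=torch.float16). A hypothetical sketch of that correspondence (the shape/dtype mapping shown here is an illustrative assumption, not taken from the diff):

import torch

# M = N = 16 and element type f16 imply the host-side input:
a = torch.randn(16, 16, dtype=torch.float16)
assert tuple(a.shape) == (16, 16) and a.dtype == torch.float16
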
1 change: 1 addition & 0 deletions shark_turbine/kernel/compiler/ir.py
@@ -38,6 +38,7 @@
builtin as builtin_d,
flow as flow_d,
func as func_d,
gpu as gpu_d,
math as math_d,
memref as memref_d,
stream as stream_d,
44 changes: 44 additions & 0 deletions shark_turbine/kernel/compiler/kernel_codegen.py
@@ -31,6 +31,7 @@
KernelBufferUsage,
is_kernel_buffer_meta_derived,
)
from ..lang.wave_types import Memory
from ..lang.grid import Grid

from .base import (
@@ -189,6 +190,49 @@ def add_grid(self, grid: Type[Grid]):
)
)

def determine_input_output_buffers(self, graph: fx.Graph):
placeholder_nodes: list[fx.Node] = []
for node in graph.nodes:
if node.op != "placeholder":
continue
placeholder_nodes.append(node)

def only_read_dependencies(node):
return all("read" in user.name for user in node.users)

def only_write_dependencies(node):
if len(node.users) == 0:
return False
return all("write" in user.name for user in node.users)

for node in placeholder_nodes:
index = None
for i, binding in enumerate(self.bindings):
if binding.reference[1] == node:
index = i
break
if index is None:
continue
# TODO: remove this hack; it exists only to make things pass. It is not
# yet clear why the buffer is not correctly determined to have only read
# dependencies, even though that is the case.
usage = KernelBufferUsage.INPUT
if only_read_dependencies(node):
usage = KernelBufferUsage.INPUT

if only_write_dependencies(node):
usage = KernelBufferUsage.OUTPUT

# Create new Memory type with the correct usage
memory_type = self.bindings[index].kernel_buffer_type
self.bindings[index].kernel_buffer_type = Memory[
*memory_type.symbolic_shape,
memory_type.address_space,
memory_type.dtype,
usage,
]
return

def __repr__(self):
parts = []
for b in self.bindings:
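
For illustration, here is a minimal, self-contained sketch of the usage classification that determine_input_output_buffers performs: a placeholder whose users are all read-like nodes is treated as an INPUT, and one whose users are all write-like nodes as an OUTPUT. The read/write stand-ins below are hypothetical (the real pass inspects users of tkw.read/tkw.write nodes in the traced wave kernel):

import torch.fx as fx

# Hypothetical stand-ins for tkw.read/tkw.write; fx.wrap keeps them as
# call_function nodes (named "read"/"write") instead of tracing into them.
@fx.wrap
def read(buf):
    return buf

@fx.wrap
def write(value, buf):
    return buf

def kernel(a, b):
    v = read(a)         # `a` is only ever read   -> INPUT
    return write(v, b)  # `b` is only ever written -> OUTPUT

graph = fx.symbolic_trace(kernel).graph

def only_read_dependencies(node):
    return all("read" in user.name for user in node.users)

def only_write_dependencies(node):
    return len(node.users) > 0 and all("write" in user.name for user in node.users)

for node in graph.nodes:
    if node.op == "placeholder":
        print(node.name, only_read_dependencies(node), only_write_dependencies(node))
# Prints: a True False, then b False True
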
2 changes: 1 addition & 1 deletion shark_turbine/kernel/compiler/vector_codegen.py
@@ -834,7 +834,7 @@ def cast_kernel_buffer(
f"Expected a KernelBuffer (aka. `memref`) but got `{ir_type}`"
)

if not issubclass(py_type, KernelBuffer):
if not (issubclass(py_type, KernelBuffer) or issubclass(py_type, tkl.Memory)):
raise CodegenError(
f"Expected an lvalue of type KernelBuffer but got '{py_type}' for node {node}"
)
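
The relaxed check above lets wave kernel arguments typed as tkl.Memory flow through a code path that previously accepted only KernelBuffer subclasses. A toy standalone illustration of the predicate (the stand-in classes here are hypothetical; the real check uses the actual KernelBuffer and tkl.Memory types):

class KernelBuffer: ...  # stand-in for the legacy buffer base class
class Memory: ...        # stand-in for tkl.Memory

def is_valid_lvalue(py_type: type) -> bool:
    # Mirrors the updated check: accept either buffer family.
    return issubclass(py_type, KernelBuffer) or issubclass(py_type, Memory)

assert is_valid_lvalue(KernelBuffer)
assert is_valid_lvalue(Memory)
assert not is_valid_lvalue(int)
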