Add CustomOp support for python operators (#54)
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Co-authored-by: Martin Lücke <martin.luecke@ed.ac.uk>
harsh-nod and martin-luecke authored Jul 29, 2024
1 parent c940f14 commit 44c7634
Showing 5 changed files with 301 additions and 1 deletion.
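
At the kernel-author level, this change lets Python operators such as +, -, unary - and indexing be used on values inside a wave kernel: tracing maps them onto the new BinaryPyOp/UnaryPyOp subclasses (Add, Sub, Neg, Getitem), while wave codegen for these ops is still a stub that raises NotImplementedError. A minimal sketch of the traced usage, mirroring test_trace_py_arithmetic below; the tkl import path is an assumption based on the repository layout:

import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw

M = tkl.sym.M
N = tkl.sym.N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE


@tkw.wave_trace_only()
def kernel(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
    reg = tkw.read(a)
    res = reg + reg - reg   # traced as Add and Sub CustomOp nodes
    res = -res              # traced as a Neg CustomOp node
    tkw.write(res, a, elements_per_thread=4)

# Calling kernel() returns the captured trace, as exercised in
# lit_tests/kernel/wave/tracing.py below.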
89 changes: 88 additions & 1 deletion lit_tests/kernel/wave/codegen.py
@@ -7,10 +7,12 @@
import shark_turbine.kernel.wave as tkw
import torch


M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE


@@ -25,6 +27,9 @@ def launch(func: Callable[[], None]) -> Callable[[], None]:
M: 16,
N: 16,
K: 16,
BLOCK_M: 16,
BLOCK_N: 16,
BLOCK_K: 16,
ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
}
):
@@ -37,6 +42,8 @@ def test_read():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]

@tkw.wave(constraints)
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
@@ -49,4 +56,84 @@ def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
test(a)


@launch
def test_add():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]

@tkw.wave(constraints)
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
res = a + a
tkw.write(res, a, elements_per_thread=4)

a = torch.randn(16, 16, dtype=torch.float16)
with pytest.raises(
NotImplementedError, match="add: Currently only stub implementation"
):
test(a)


@launch
def test_neg():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]

@tkw.wave(constraints)
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
res = -a
tkw.write(res, a, elements_per_thread=4)

a = torch.randn(16, 16, dtype=torch.float16)
with pytest.raises(
NotImplementedError, match="neg: Currently only stub implementation"
):
test(a)


@launch
def test_sub():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]

@tkw.wave(constraints)
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
res = a - a
tkw.write(res, a, elements_per_thread=4)

a = torch.randn(16, 16, dtype=torch.float16)
with pytest.raises(
NotImplementedError, match="sub: Currently only stub implementation"
):
test(a)


@launch
def test_get_item():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]

@tkw.wave(constraints)
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
res = a[0]
tkw.write(res, a, elements_per_thread=4)

a = torch.randn(16, 16, dtype=torch.float16)
with pytest.raises(
NotImplementedError, match="getitem: Currently only stub implementation"
):
test(a)


# TODO: Add more tests once we have more than a stub implementation.
60 changes: 60 additions & 0 deletions lit_tests/kernel/wave/expansion.py
@@ -342,6 +342,66 @@ def test_gemm_reduction_expansion_only():
# CHECK-NEXT: -----


@tkw.wave_trace_only()
def py_arithmetic_different_dims(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f32],
):
a_reg = tkw.read(a, elements_per_thread=4)
a_reg = a_reg + a_reg - a_reg
a_reg = -a_reg
tkw.write(a_reg, c, elements_per_thread=4)


@run
def py_arithmetic_different_dims():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WorkgroupConstraint(K, BLOCK_K, 2)]
constraints += [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
]
with tk.gen.TestLaunchContext(
{
BLOCK_M: 32,
BLOCK_N: 32,
BLOCK_K: 32,
}
):
graph = py_arithmetic_different_dims()
IndexingContext.current().finalize()
expand_graph(graph, constraints)
print_trace(graph)
# CHECK: %a
# CHECK-NEXT: %c
# CHECK-NEXT: %read_0_0_0
# CHECK-SAME: (%a, 4)
# CHECK-NEXT: %read_1_0_0
# CHECK-SAME: (%a, 4)
# CHECK-NEXT: %add_0_0_0
# CHECK-SAME: (%read_0_0_0, %read_0_0_0)
# CHECK-NEXT: %add_1_0_0
# CHECK-SAME: (%read_1_0_0, %read_1_0_0)
# CHECK-NEXT: %sub_0_0_0
# CHECK-SAME: (%add_0_0_0, %read_0_0_0)
# CHECK-NEXT: %sub_1_0_0
# CHECK-SAME: (%add_1_0_0, %read_1_0_0)
# CHECK-NEXT: %neg_0_0_0
# CHECK-SAME: (%sub_0_0_0,)
# CHECK-NEXT: %neg_1_0_0
# CHECK-SAME: (%sub_1_0_0,)
# CHECK-NEXT: %write_0_0_0
# CHECK-SAME: (%neg_0_0_0, %c, 4)
# CHECK-NEXT: %write_1_0_1
# CHECK-SAME: (%neg_1_0_0, %c, 4)
# CHECK-NEXT: %write_1_0_0
# CHECK-SAME: (%neg_1_0_0, %c, 4)
# CHECK-NEXT: %write_0_0_1
# CHECK-SAME: (%neg_0_0_0, %c, 4)

# CHECK: -----


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
34 changes: 34 additions & 0 deletions lit_tests/kernel/wave/tracing.py
@@ -83,6 +83,40 @@ def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
# CHECK-NEXT: output


@run
def test_trace_py_arithmetic():
@tkw.wave_trace_only()
def test(A: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
a = tkw.read(A)
res = a + a - a
res = -res
tkw.write(res, A, elements_per_thread=4)

trace = test()
print_trace(trace)
# CHECK: %a
# CHECK-NEXT: %read
# CHECK-SAME: (%a, None)
# CHECK-NEXT: %add
# CHECK-SAME: (%read, %read)
# CHECK-NEXT: %sub
# CHECK-SAME: (%add, %read)
# CHECK-NEXT: %neg
# CHECK-SAME: (%sub,)
# CHECK-NEXT: %write
# CHECK-SAME: (%neg, %a, 4)
# CHECK-NEXT: return None

# Custom format:
# CHECK-NEXT: placeholder
# CHECK-NEXT: read(memory=a
# CHECK-NEXT: add(lhs=read, rhs=read)
# CHECK-NEXT: sub(lhs=add, rhs=read)
# CHECK-NEXT: neg(arg=sub)
# CHECK-NEXT: write(register_=neg, memory=a, elements_per_thread=4)
# CHECK-NEXT: output


@run
def test_trace_read():
@tkw.wave_trace_only()
98 changes: 98 additions & 0 deletions shark_turbine/kernel/ops/wave_ops.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC
from dataclasses import dataclass, field, fields
import operator
import sys
from typing import (
TYPE_CHECKING,
@@ -91,6 +92,57 @@ def new_function(*args: Any, **kwargs: dict[str, Any]):
return decorator


def define_py_op(py_op: Callable) -> Callable[[T], T]:
"""
Register Python's built-in operators as custom ops.
This overloads the operator-specific dunder methods of fx.Proxy, such as
__add__, with a handler in order to control the tracing of the operator and
map it to a dynamically created subclass of UnaryPyOp or BinaryPyOp.
"""
op_name = py_op.__name__

def decorator(cls: T) -> T:
# define new subclass of cls to represent this op
@dataclass
class NewSubclass(cls):
pass

NewSubclass.tkw_op_name = op_name
NewSubclass.__name__ = op_name.capitalize()
NewSubclass.__module__ = cls.__module__
current_module = sys.modules[cls.__module__]
setattr(current_module, NewSubclass.__name__, NewSubclass)

original_handler = None
if hasattr(fx.Proxy, f"__{op_name}__"):
original_handler = getattr(fx.Proxy, f"__{op_name}__")

def new_function(*args: Any, **kwargs: dict[str, Any]):
dispatcher = None
try:
dispatcher = OpDispatcher.current()
except IndexError:
handler = original_handler

if dispatcher:
try:
handler = getattr(dispatcher, f"handle_{op_name}")
except AttributeError:
handler = original_handler

return handler(*args, **kwargs)

if original_handler:
new_function.__name__ = op_name
NewSubclass._tracing_function = new_function
setattr(fx.Proxy, f"__{op_name}__", new_function)

# Return cls unchanged so we can reuse the decorator to register more ops
return cls

return decorator


def get_custom(node: fx.Node) -> "CustomOp":
"""Get the corresponding CustomOp for a given fx.Node."""
if isinstance(node, CustomOp):
@@ -255,6 +307,52 @@ def indexing_dims(self) -> list[IndexSymbol]:
return []


@define_py_op(operator.getitem)
@define_py_op(operator.add)
@define_py_op(operator.sub)
@dataclass
class BinaryPyOp(CustomOp, ABC):
"""
Represents a binary python operator.
"""

lhs: Any
rhs: Any

@property
def indexing_dims(self) -> list[IndexSymbol]:
combined_dims = []
if isinstance(self.lhs, fx.Node):
combined_dims += get_custom(self.lhs).indexing_dims
if isinstance(self.rhs, fx.Node):
combined_dims += get_custom(self.rhs).indexing_dims

unique_dims = list(dict.fromkeys(combined_dims))
return unique_dims

@property
def py_operator(self) -> str:
return self.tkw_op_name


@define_py_op(operator.neg)
@dataclass
class UnaryPyOp(CustomOp, ABC):
"""
Represents a unary python operator.
"""

arg: fx.Node

@property
def indexing_dims(self) -> list[IndexSymbol]:
return get_custom(self.arg).indexing_dims

@property
def py_operator(self) -> str:
return self.tkw_op_name


@final
@dataclass
class Unknown(CustomOp):
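
For reference, the interception in define_py_op above works by saving the existing fx.Proxy dunder (e.g. __add__), installing a replacement that first looks for a handle_<op> method on the active OpDispatcher and otherwise falls back to the saved dunder, and registering a dynamically created CustomOp subclass named after the operator. A simplified, standalone sketch of that pattern using only torch.fx; DemoDispatcher and the module-level _active_dispatcher are hypothetical stand-ins for OpDispatcher.current():

import operator
import torch.fx as fx

_active_dispatcher = None  # stand-in for OpDispatcher.current()


class DemoDispatcher:
    # Hypothetical dispatcher; the real OpDispatcher would build a BinaryPyOp node.
    def handle_add(self, lhs, rhs):
        print(f"intercepted __add__({lhs}, {rhs})")
        return lhs


def patch_proxy_op(py_op):
    name = py_op.__name__                              # e.g. "add"
    original = getattr(fx.Proxy, f"__{name}__", None)  # default fx tracing behavior

    def handler(*args, **kwargs):
        # Route through the active dispatcher when one is installed ...
        if _active_dispatcher is not None:
            custom = getattr(_active_dispatcher, f"handle_{name}", None)
            if custom is not None:
                return custom(*args, **kwargs)
        # ... otherwise fall back to torch.fx's default proxy behavior.
        return original(*args, **kwargs)

    handler.__name__ = f"__{name}__"
    setattr(fx.Proxy, f"__{name}__", handler)


patch_proxy_op(operator.add)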
21 changes: 21 additions & 0 deletions shark_turbine/kernel/wave/codegen.py
@@ -1,3 +1,4 @@
import operator
from typing import Any, Callable, ClassVar, Optional
from dataclasses import dataclass
import torch.fx as fx
@@ -83,6 +84,26 @@ def handle_mma(emitter: WaveEmitter, node: fx.Node):
raise NotImplementedError("MMA: Currently only stub implementation")


@handle_op(operator.add)
def handle_add(emitter: WaveEmitter, node: fx.Node):
raise NotImplementedError("add: Currently only stub implementation")


@handle_op(operator.getitem)
def handle_getitem(emitter: WaveEmitter, node: fx.Node):
raise NotImplementedError("getitem: Currently only stub implementation")


@handle_op(operator.neg)
def handle_neg(emitter: WaveEmitter, node: fx.Node):
raise NotImplementedError("neg: Currently only stub implementation")


@handle_op(operator.sub)
def handle_sub(emitter: WaveEmitter, node: fx.Node):
raise NotImplementedError("sub: Currently only stub implementation")


###############################################################################
# Control Flow ops
###############################################################################
