Address Ivan's comments
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Aug 17, 2024
1 parent 91eea4a commit b3612ab
Showing 5 changed files with 24 additions and 16 deletions.
5 changes: 3 additions & 2 deletions lit_tests/kernel/wave/codegen.py
@@ -289,12 +289,13 @@ def mma(
STORE_ELEMS_PER_THREAD: 4,
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
-}
+},
+canonicalize=True,
):
a = torch.randn(64, 32, dtype=torch.float16)
b = torch.randn(128, 32, dtype=torch.float16)
c = torch.zeros(64, 128, dtype=torch.float32)
-print(mma(a, b, c, canonicalize=True).module_op)
+print(mma(a, b, c).module_op)

# CHECK: func.func @mma(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding, %[[ARG2:.+]]: !stream.binding) {
# CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
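For context, here is a minimal standalone sketch (plain Python, with made-up Context and kernel names rather than the shark_turbine API) of the pattern the test now relies on: options handed to the launch context once apply to every call made inside the with block, which is why the call site no longer passes canonicalize=True itself.

```python
# Hypothetical sketch of "options on the context, not on the call".
# Context and kernel are illustrative names only, not the shark_turbine API.
_current_options = {}

class Context:
    def __init__(self, bindings, **kwargs):
        self.bindings = bindings
        self.kwargs = kwargs

    def __enter__(self):
        _current_options.update(self.kwargs)
        return self

    def __exit__(self, *exc):
        _current_options.clear()

def kernel(*args, **kwargs):
    # Per-call kwargs are merged with whatever the enclosing context carries.
    merged = {**kwargs, **_current_options}
    return f"kernel{args} with options {merged}"

with Context({"BLOCK": 64}, canonicalize=True):
    print(kernel(1, 2, 3))  # picks up canonicalize=True from the context
```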
5 changes: 4 additions & 1 deletion shark_turbine/kernel/_support/tracing.py
@@ -420,8 +420,9 @@ def test_execute(self, args, kwargs):
class LaunchContext(ABC):
__tk_context_idname__ = "ExecutionContext"

-def __init__(self, constant_bindings: Dict[IndexSymbol, int] = {}):
+def __init__(self, constant_bindings: Dict[IndexSymbol, int] = {}, **kwargs):
self.constant_bindings = constant_bindings
+self.kwargs = kwargs

@staticmethod
def current() -> "LaunchContext":
@@ -464,6 +465,8 @@ def launch(self, launchable: Launchable, args, kwargs):

class TestLaunchContext(LaunchContext):
def launch(self, launchable: Launchable, args, kwargs):
+if self.kwargs:
+    kwargs.update(self.kwargs)
return launchable.test_execute(args, kwargs)
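A small, self-contained illustration (plain dicts, nothing from the repository) of the merge semantics introduced above: because the context's stored kwargs are applied with dict.update on top of the per-call kwargs, a context-level option overrides a per-call option of the same name.

```python
# Sketch of the kwargs.update(self.kwargs) behaviour: context options win.
context_kwargs = {"canonicalize": True}          # stored on the LaunchContext
call_kwargs = {"canonicalize": False, "run": 1}  # passed at the call site

if context_kwargs:
    call_kwargs.update(context_kwargs)

assert call_kwargs == {"canonicalize": True, "run": 1}
```

If per-call values were meant to take precedence instead, the merge direction would have to be flipped; as written, the context has the final say.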


26 changes: 15 additions & 11 deletions shark_turbine/kernel/wave/codegen.py
@@ -485,22 +485,26 @@ def handle_mma(emitter: WaveEmitter, node: fx.Node):
try:
lhs, rhs, acc = node.args
acc = cast_vector(emitter, acc)
-values = [lhs, rhs]
-for i in range(len(values)):
-    values[i] = cast_vector(emitter, values[i])
+values = [cast_vector(emitter, val) for val in [lhs, rhs]]
except ValueError as e:
raise ValidationError("Malformed arguments") from e

vector_type = VectorType(acc.type)

+hardware_constraints = [
+    constraint
+    for constraint in emitter.constraints
+    if isinstance(constraint, HardwareConstraint)
+]
+if not hardware_constraints:
+    raise CodegenError("No hardware constraints found.")

-result = None
-for constraint in emitter.constraints:
-    if isinstance(constraint, HardwareConstraint):
-        m, n, k = constraint.mma_matrix_shapes
-        result = emit_mfma(m, n, k, vector_type, acc, values)
-        break

-if result:
-    emitter.bind_node_proxy(node, IRProxyValue(result))
+for constraint in hardware_constraints:
+    m, n, k = constraint.mma_matrix_shapes
+    result = emit_mfma(m, n, k, vector_type, acc, values)

+emitter.bind_node_proxy(node, IRProxyValue(result))


@handle_op(operator.add)
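The rewritten handle_mma follows a filter-then-validate shape: gather the relevant constraints with a comprehension, raise a clear error when none are present, then iterate only over the matches. Below is a hedged, standalone sketch of that shape; the constraint classes, the emit_mfma stub, and the 16x16x16 shape are made up for illustration and are not the real definitions.

```python
# Standalone sketch of "filter constraints, fail fast, then emit".
from dataclasses import dataclass

class CodegenError(Exception):
    pass

@dataclass
class HardwareConstraint:
    mma_matrix_shapes: tuple  # (m, n, k)

@dataclass
class WorkgroupConstraint:
    dim: int

def emit_mfma(m, n, k):
    return f"mfma {m}x{n}x{k}"

def handle_mma(constraints):
    hardware_constraints = [
        c for c in constraints if isinstance(c, HardwareConstraint)
    ]
    if not hardware_constraints:
        raise CodegenError("No hardware constraints found.")
    result = None
    for c in hardware_constraints:
        m, n, k = c.mma_matrix_shapes
        result = emit_mfma(m, n, k)
    return result

print(handle_mma([WorkgroupConstraint(0), HardwareConstraint((16, 16, 16))]))
```

Note that, as in the diff, the loop keeps only the result of the last matching constraint; with a single HardwareConstraint this matches the old break-based loop, and the new early error replaces the old silent "if result:" guard.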
2 changes: 1 addition & 1 deletion shark_turbine/kernel/wave/constraints.py
@@ -222,7 +222,7 @@ def get_workgroup_distributed_shape(
Given a shape and workgroup constraints, returns the shape
of the tensor after it has been distributed along workgroup dimensions.
"""
-distributed_shape = [s for s in shape]
+distributed_shape = list(shape)
for i, dim in enumerate(shape):
for constraint in constraints:
if isinstance(constraint, WorkgroupConstraint):
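For reference, the two spellings build the same shallow copy; list(shape) is simply the idiomatic one. A quick sanity check, independent of the repository code:

```python
shape = (64, 128, 32)
assert [s for s in shape] == list(shape) == [64, 128, 32]

distributed_shape = list(shape)
distributed_shape[0] = 1          # mutable copy; the original is untouched
assert shape == (64, 128, 32)
```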
2 changes: 1 addition & 1 deletion shark_turbine/kernel/wave/wave.py
@@ -206,7 +206,7 @@ def _trace_and_get_kernel_signature(
emitter.emit(graph.get_root_graph())
emitter.finish()

if "canonicalize" in kwargs and kwargs["canonicalize"]:
if kwargs.get("canonicalize", False):
canonicalize_module(mb.module_op)

return mb, graph
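In a boolean test the two spellings agree: kwargs.get("canonicalize", False) is falsy exactly when the old membership-and-truthiness check was. A quick check with plain dicts:

```python
for kwargs in ({}, {"canonicalize": True}, {"canonicalize": False}):
    old = "canonicalize" in kwargs and kwargs["canonicalize"]
    new = kwargs.get("canonicalize", False)
    assert bool(old) == bool(new)
```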