diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 6c80e55b..b63c3145 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -289,12 +289,13 @@ def mma(
             STORE_ELEMS_PER_THREAD: 4,
             ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
             ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
-        }
+        },
+        canonicalize=True,
     ):
         a = torch.randn(64, 32, dtype=torch.float16)
         b = torch.randn(128, 32, dtype=torch.float16)
         c = torch.zeros(64, 128, dtype=torch.float32)
-        print(mma(a, b, c, canonicalize=True).module_op)
+        print(mma(a, b, c).module_op)
 
         # CHECK: func.func @mma(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding, %[[ARG2:.+]]: !stream.binding) {
         # CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
diff --git a/shark_turbine/kernel/_support/tracing.py b/shark_turbine/kernel/_support/tracing.py
index ec609561..d1a41604 100644
--- a/shark_turbine/kernel/_support/tracing.py
+++ b/shark_turbine/kernel/_support/tracing.py
@@ -420,8 +420,9 @@ def test_execute(self, args, kwargs):
 class LaunchContext(ABC):
     __tk_context_idname__ = "ExecutionContext"
 
-    def __init__(self, constant_bindings: Dict[IndexSymbol, int] = {}):
+    def __init__(self, constant_bindings: Dict[IndexSymbol, int] = {}, **kwargs):
         self.constant_bindings = constant_bindings
+        self.kwargs = kwargs
 
     @staticmethod
     def current() -> "LaunchContext":
@@ -464,6 +465,8 @@ def launch(self, launchable: Launchable, args, kwargs):
 
 class TestLaunchContext(LaunchContext):
     def launch(self, launchable: Launchable, args, kwargs):
+        if self.kwargs:
+            kwargs.update(self.kwargs)
         return launchable.test_execute(args, kwargs)
 
 
diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py
index f30ac369..8a1e8ccd 100644
--- a/shark_turbine/kernel/wave/codegen.py
+++ b/shark_turbine/kernel/wave/codegen.py
@@ -485,22 +485,25 @@ def handle_mma(emitter: WaveEmitter, node: fx.Node):
     try:
         lhs, rhs, acc = node.args
         acc = cast_vector(emitter, acc)
-        values = [lhs, rhs]
-        for i in range(len(values)):
-            values[i] = cast_vector(emitter, values[i])
+        values = [cast_vector(emitter, val) for val in [lhs, rhs]]
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e
 
     vector_type = VectorType(acc.type)
-    result = None
-    for constraint in emitter.constraints:
-        if isinstance(constraint, HardwareConstraint):
-            m, n, k = constraint.mma_matrix_shapes
-            result = emit_mfma(m, n, k, vector_type, acc, values)
-            break
-
-    if result:
-        emitter.bind_node_proxy(node, IRProxyValue(result))
+
+    hardware_constraints = [
+        constraint
+        for constraint in emitter.constraints
+        if isinstance(constraint, HardwareConstraint)
+    ]
+    if not hardware_constraints:
+        raise CodegenError("No hardware constraints found.")
+
+    # The MMA shape comes from the first (typically the only) hardware constraint.
+    m, n, k = hardware_constraints[0].mma_matrix_shapes
+    result = emit_mfma(m, n, k, vector_type, acc, values)
+
+    emitter.bind_node_proxy(node, IRProxyValue(result))
 
 
 @handle_op(operator.add)
diff --git a/shark_turbine/kernel/wave/constraints.py b/shark_turbine/kernel/wave/constraints.py
index 3917da13..989171ac 100644
--- a/shark_turbine/kernel/wave/constraints.py
+++ b/shark_turbine/kernel/wave/constraints.py
@@ -222,7 +222,7 @@ def get_workgroup_distributed_shape(
     Given a shape and workgroup constraints, returns the shape of the
     tensor after it has been distributed along workgroup dimensions.
     """
-    distributed_shape = [s for s in shape]
+    distributed_shape = list(shape)
     for i, dim in enumerate(shape):
         for constraint in constraints:
             if isinstance(constraint, WorkgroupConstraint):
diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py
index cc30cce7..fe10ee46 100644
--- a/shark_turbine/kernel/wave/wave.py
+++ b/shark_turbine/kernel/wave/wave.py
@@ -206,7 +206,7 @@ def _trace_and_get_kernel_signature(
         emitter.emit(graph.get_root_graph())
         emitter.finish()
 
-        if "canonicalize" in kwargs and kwargs["canonicalize"]:
+        if kwargs.get("canonicalize", False):
             canonicalize_module(mb.module_op)
 
         return mb, graph