diff --git a/iree/turbine/kernel/wave/constraints.py b/iree/turbine/kernel/wave/constraints.py index 2bf3e228..83441af0 100644 --- a/iree/turbine/kernel/wave/constraints.py +++ b/iree/turbine/kernel/wave/constraints.py @@ -22,6 +22,12 @@ class MMAType(Enum): F32_32x32x16_F8 = 3 +class MMAOperand(Enum): + M = 0 + N = 1 + K = 2 + + @dataclass class Constraint(ABC): """ @@ -133,7 +139,7 @@ def compute_access_pattern_using_vector_shapes( def apply( self, dim: IndexSymbol, - constraint_index: int, + constraint_index: int | MMAOperand, elements_per_thread: int | IndexSymbol, stride: int, is_mma_dim: bool, @@ -232,11 +238,13 @@ def apply( ] case _: raise ValueError("Unsupported MMA type") - + assert isinstance( + constraint_index, MMAOperand + ), f"Invalid MMA operand {constraint_index}" return IndexSequence( - offset[constraint_index], - size[constraint_index], - stride[constraint_index], + offset[constraint_index.value], + size[constraint_index.value], + stride[constraint_index.value], ) diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 7f417665..ea1b32ae 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -125,12 +125,21 @@ def has_strided_access(node: fx.Node) -> bool: custom.graph.erase_node(operator) +def preprocess_nodes( + constraints: Sequence[Constraint], + mma_index: dict[MMA, dict[IndexSymbol, int]], + mma_slices: dict[MMA, dict[IndexSymbol, list[fx.Node]]], + node: fx.Node, +): + set_vector_shapes(constraints, mma_index, mma_slices, node) + set_node_index(constraints, mma_index, mma_slices, node) + + def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]): mma_index, mma_slices = get_mma_dimensional_mapping( trace, get_hardware_constraint(constraints) ) - trace.walk(partial(set_vector_shapes, constraints, mma_index, mma_slices)) - trace.walk(partial(set_node_index, constraints, mma_index, mma_slices)) + trace.walk(partial(preprocess_nodes, constraints, mma_index, mma_slices)) def compute_stride( diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 8fb95ece..b3f229e5 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -32,6 +32,7 @@ HardwareConstraint, TilingConstraint, MMAType, + MMAOperand, ) import torch.fx as fx import iree.turbine.kernel.lang as tkl @@ -221,9 +222,9 @@ def is_mma(node): k = ((set(lhs_shape) & set(rhs_shape)) - set(acc_shape)).pop() if custom not in mapping: mapping[custom] = {} - mapping[custom][m] = 0 - mapping[custom][n] = 1 - mapping[custom][k] = 2 + mapping[custom][m] = MMAOperand.M + mapping[custom][n] = MMAOperand.N + mapping[custom][k] = MMAOperand.K custom.vector_shapes = { m: hardware_constraint.mma_matrix_shapes[0], n: hardware_constraint.mma_matrix_shapes[1],