Skip to content

Commit

Permalink
Minor cleanups and refactoring (#242)
Browse files Browse the repository at this point in the history
This PR address comments from a previous PR, namely
- now only one pass through the graph is required to set the vector
shapes and index
- the MMA index is now specified using an enum, the MMAOperand, which
allows for better verification

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod authored Oct 28, 2024
1 parent 62d11cc commit bf7b686
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
18 changes: 13 additions & 5 deletions iree/turbine/kernel/wave/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class MMAType(Enum):
F32_32x32x16_F8 = 3


class MMAOperand(Enum):
M = 0
N = 1
K = 2


@dataclass
class Constraint(ABC):
"""
Expand Down Expand Up @@ -133,7 +139,7 @@ def compute_access_pattern_using_vector_shapes(
def apply(
self,
dim: IndexSymbol,
constraint_index: int,
constraint_index: int | MMAOperand,
elements_per_thread: int | IndexSymbol,
stride: int,
is_mma_dim: bool,
Expand Down Expand Up @@ -232,11 +238,13 @@ def apply(
]
case _:
raise ValueError("Unsupported MMA type")

assert isinstance(
constraint_index, MMAOperand
), f"Invalid MMA operand {constraint_index}"
return IndexSequence(
offset[constraint_index],
size[constraint_index],
stride[constraint_index],
offset[constraint_index.value],
size[constraint_index.value],
stride[constraint_index.value],
)


Expand Down
13 changes: 11 additions & 2 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,21 @@ def has_strided_access(node: fx.Node) -> bool:
custom.graph.erase_node(operator)


def preprocess_nodes(
constraints: Sequence[Constraint],
mma_index: dict[MMA, dict[IndexSymbol, int]],
mma_slices: dict[MMA, dict[IndexSymbol, list[fx.Node]]],
node: fx.Node,
):
set_vector_shapes(constraints, mma_index, mma_slices, node)
set_node_index(constraints, mma_index, mma_slices, node)


def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
mma_index, mma_slices = get_mma_dimensional_mapping(
trace, get_hardware_constraint(constraints)
)
trace.walk(partial(set_vector_shapes, constraints, mma_index, mma_slices))
trace.walk(partial(set_node_index, constraints, mma_index, mma_slices))
trace.walk(partial(preprocess_nodes, constraints, mma_index, mma_slices))


def compute_stride(
Expand Down
7 changes: 4 additions & 3 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
HardwareConstraint,
TilingConstraint,
MMAType,
MMAOperand,
)
import torch.fx as fx
import iree.turbine.kernel.lang as tkl
Expand Down Expand Up @@ -221,9 +222,9 @@ def is_mma(node):
k = ((set(lhs_shape) & set(rhs_shape)) - set(acc_shape)).pop()
if custom not in mapping:
mapping[custom] = {}
mapping[custom][m] = 0
mapping[custom][n] = 1
mapping[custom][k] = 2
mapping[custom][m] = MMAOperand.M
mapping[custom][n] = MMAOperand.N
mapping[custom][k] = MMAOperand.K
custom.vector_shapes = {
m: hardware_constraint.mma_matrix_shapes[0],
n: hardware_constraint.mma_matrix_shapes[1],
Expand Down

0 comments on commit bf7b686

Please sign in to comment.