
Commit

Add indexing to nodes
The primary purpose of this PR is to annotate nodes
with their access patterns, derived from the workgroup, tiling
and MMA constraints. The annotation happens prior to
expansion and is propagated through expansion to the expanded nodes.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Jul 31, 2024
1 parent 9390476 commit ff5a81d
Showing 9 changed files with 448 additions and 28 deletions.
199 changes: 194 additions & 5 deletions lit_tests/kernel/wave/expansion.py

Large diffs are not rendered by default.

50 changes: 48 additions & 2 deletions shark_turbine/kernel/ops/wave_ops.py
@@ -22,6 +22,7 @@
from .._support.dtype import DataType
from .._support.regions import RegionGraph
from .base import OpDispatcher
import shark_turbine.kernel.lang as tkl

T = TypeVar("T", bound=Type[Any])
AccT = TypeVar("AccT")
@@ -171,13 +172,15 @@ class CustomOp(ABC):
fx_node: Optional[fx.Node] = field(default=None, init=False)
tkw_op_name: str = field(default="unknown", init=False)
_tracing_function: Optional[Callable[..., Any]] = field(default=None, init=False)
index: Optional[IndexExpr] = field(default=None, init=False)
index: Optional[dict[IndexSymbol, IndexSequence]] = field(default=None, init=False)

@classmethod
def from_fx_node(cls: Type[CustomOpT], node: fx.Node) -> CustomOpT:
instance = cls(*node.args)
instance.fx_node = node
instance.graph = node.graph
if hasattr(node, "index"):
instance.index = node.index
return instance

def __post_init__(self):
@@ -198,7 +201,14 @@ def __str__(self) -> str:

def custom_string(self, value_map: dict[str, str]) -> str:
# print all variables of the node apart from graph and fx_node
vars_list = [f"{key}={value}" for key, value in vars(self).items()][:-2]
ignore_list = ["fx_node", "graph"]
if self.index is None:
ignore_list += ["index"]
vars_list = [
f"{key}={value}"
for key, value in vars(self).items()
if key not in ignore_list
]
vars_str = ", ".join(vars_list)
return f"{self.tkw_op_name}({vars_str})"

@@ -213,6 +223,7 @@ def add_to_graph(self, region_graph: RegionGraph) -> fx.Node:
)
self.fx_node.tkw_op = self.__class__
self.fx_node.tkw_op_name = self.tkw_op_name
self.fx_node.index = None
return self.fx_node

def _add_proxy_to_graph(self, region_graph: RegionGraph):
@@ -259,6 +270,7 @@ def copy(
graph.inserting_after(self.fx_node)
new_node = graph.node_copy(self.fx_node)
new_node.tkw_op = self
new_node.index = self.fx_node.index
if new_name:
new_node.name = new_name
return get_custom(new_node)
@@ -512,6 +524,40 @@ def rhs_type(self) -> Memory:
def acc_type(self) -> Memory:
return get_custom(self.acc).type

def operand_index(
self, operand_map: dict[IndexSymbol, int], shape: list[IndexExpr]
) -> list[IndexSequence]:
indices: list[IndexSequence] = []
for dim in shape:
indices.append(self.index[dim].subs(operand_map))
return indices

@property
def lhs_index(self) -> list[IndexSequence]:
operand_map = {tkl.sym.MMA_LHS: 1, tkl.sym.MMA_RHS: 0, tkl.sym.MMA_ACC: 0}
return self.operand_index(operand_map, self.lhs_type.symbolic_shape)

@property
def rhs_index(self) -> list[IndexSequence]:
operand_map = {tkl.sym.MMA_LHS: 0, tkl.sym.MMA_RHS: 1, tkl.sym.MMA_ACC: 0}
return self.operand_index(operand_map, self.rhs_type.symbolic_shape)

@property
def acc_index(self) -> list[IndexSequence]:
operand_map = {tkl.sym.MMA_LHS: 0, tkl.sym.MMA_RHS: 0, tkl.sym.MMA_ACC: 1}
if self.acc.type is None:
return None
return self.operand_index(operand_map, self.acc_type.symbolic_shape)

def custom_string(self, value_map: dict[str, str]) -> str:
if self.index is None:
return super().custom_string(value_map)
custom_str = f"{self.tkw_op_name}("
custom_str += f"lhs={self.lhs} (index = {self.lhs_index}), "
custom_str += f"rhs={self.rhs} (index = {self.rhs_index}), "
custom_str += f"acc={self.acc} (index = {self.acc_index}))"
return custom_str
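
To make the operand maps above concrete, here is a minimal standalone sketch (not part of the diff) of how substituting an operand flag into a piecewise index expression selects that operand's access pattern. It uses plain sympy symbols and Eq conditions, rather than the tkl.sym boolean-style symbols the diff uses:

import sympy

# Illustrative only: plain sympy stand-ins for tkl.sym.MMA_ACC and the lane id.
MMA_ACC = sympy.Symbol("MMA_ACC")
lane = sympy.Symbol("lane")

# Offset along M, conditioned on whether the accumulator is being indexed,
# mirroring the F32_16x16x16_F16 case in HardwareConstraint.apply below.
offset_m = sympy.Piecewise(
    (lane % 16, sympy.Eq(MMA_ACC, 0)),
    (4 * sympy.floor(lane / 16), sympy.Eq(MMA_ACC, 1)),
)

print(offset_m.subs({MMA_ACC: 0}))  # Mod(lane, 16): the LHS/RHS view of M
print(offset_m.subs({MMA_ACC: 1}))  # 4*floor(lane/16): the accumulator view of M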


@define_op("read")
@dataclass
72 changes: 59 additions & 13 deletions shark_turbine/kernel/wave/constraints.py
@@ -3,9 +3,11 @@
from enum import Enum
from typing import Optional
import shark_turbine.kernel.lang as tkl
from sympy import ceiling
from sympy import ceiling, Piecewise, floor

from .._support.indexing import IndexExpr, IndexSymbol
from .indexing import IndexSequence
from .distribution_symbols import *


class MMAType(Enum):
@@ -25,8 +27,8 @@ class Constraint(ABC):
"""

@abstractmethod
def apply(self) -> IndexExpr:
"""Apply the constraint and get the resulting index expression."""
def apply(self) -> IndexSequence:
"""Apply the constraint and get the resulting index sequence."""
...


@@ -52,8 +54,13 @@ class HardwareConstraint(Constraint):
mma_type: Optional[MMAType] = MMAType.F32_16x16x16_F16
vector_shapes: Optional[dict[IndexSymbol, int]] = None

def __post_init__(self):
self.LHS = tkl.sym.MMA_LHS
self.RHS = tkl.sym.MMA_RHS
self.ACC = tkl.sym.MMA_ACC

@property
def mma_matrix_shapes(self):
def mma_matrix_shapes(self) -> tuple[int]:
# TODO: Eventually the shapes and indices should be provided by a tool
match self.mma_type:
case MMAType.F32_16x16x16_F16:
@@ -63,8 +70,47 @@ def mma_matrix_shapes(self):
case _:
return ()

def apply(self) -> IndexExpr:
raise NotImplementedError("Not yet implemented")
@property
def threads_per_block(self) -> tuple[int]:
return (
self.waves_per_block[0] * self.threads_per_wave,
) + self.waves_per_block[1:]

@property
def linearized_thread_id(self) -> IndexExpr:
thread_ids = [THREAD_0, THREAD_1, THREAD_2]
threads_per_block = (
[1]
+ [self.threads_per_block[0]]
+ [self.threads_per_block[0] * self.threads_per_block[1]]
)
return sum([x * y for x, y in zip(thread_ids, threads_per_block)])

def apply(self, mma_index: int) -> IndexSequence:
lane = self.linearized_thread_id
match self.mma_type:
# (M x K, N x K) -> M x N
case MMAType.F32_16x16x16_F16:
offset = {
0: Piecewise(
(lane % 16, ~self.ACC), (4 * floor(lane / 16), self.ACC)
), # M
1: lane % 16, # N
2: 4 * floor(lane / 16), # K
}
size = {
0: Piecewise((0, ~self.ACC), (4, self.ACC)), # M
1: 0, # N
2: 4, # K
}
stride = {
0: Piecewise((1, ~self.ACC), (16, self.ACC)), # M
1: 1, # N
2: 1, # K
}
return IndexSequence(
offset[mma_index], size[mma_index], stride[mma_index]
)
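
A quick sanity check on the linearization and lane layout above (values are illustrative, not from the diff): with threads_per_wave = 64 and waves_per_block = (2, 2, 1), threads_per_block is (128, 2, 1). For a non-accumulator operand, each of the 64 lanes covers one row (lane % 16) and four contiguous K elements starting at 4 * (lane // 16), so 64 lanes times 4 elements tile the 16 x 16 fragment exactly once:

def linearize(t0: int, t1: int, t2: int, tpb: tuple[int, int, int]) -> int:
    # T0 + extent(T0) * T1 + extent(T0) * extent(T1) * T2, as in
    # linearized_thread_id above (hypothetical helper, illustration only).
    return t0 + tpb[0] * t1 + tpb[0] * tpb[1] * t2

assert linearize(5, 1, 0, (128, 2, 1)) == 133

# Each lane reads 4 contiguous K elements of a single M row; together the
# wave covers every (m, k) pair of the 16 x 16 operand tile exactly once.
coverage = {(lane % 16, 4 * (lane // 16) + i) for lane in range(64) for i in range(4)}
assert coverage == {(m, k) for m in range(16) for k in range(16)}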


@dataclass
Expand All @@ -81,17 +127,17 @@ class WorkgroupConstraint(Constraint):
tile_size: IndexExpr
workgroup_dim: int

def apply(self) -> IndexExpr:
def apply(self) -> IndexSequence:
match self.workgroup_dim:
case 0:
wg_dim = tkl.sym.WG0
wg_dim = WORKGROUP_0
case 1:
wg_dim = tkl.sym.WG1
wg_dim = WORKGROUP_1
case 2:
wg_dim = tkl.sym.WG2
wg_dim = WORKGROUP_2
case _:
raise ValueError("Invalid workgroup dimension. Expected 0, 1 or 2.")
return wg_dim * self.tile_size
return IndexSequence(wg_dim * self.tile_size, 1)


def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]:
@@ -130,9 +176,9 @@ def iterations(self) -> IndexExpr:
"""
return ceiling(self.dim / self.tile_size)

def apply(self) -> IndexExpr:
def apply(self) -> IndexSequence:
if self.induction_var is None:
raise ValueError(
"Index is being computed without setting induction variable"
)
return self.induction_var * self.tile_size
return IndexSequence(self.induction_var * self.tile_size, 1)
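
Worked numbers for the tiling math above (illustrative values, not from the diff): a dimension of extent 70 tiled by 16 needs ceiling(70 / 16) = 5 iterations, and iteration i of the induction variable starts at offset i * 16:

import sympy

# Hypothetical extent and tile size, chosen only to exercise the formulas.
dim, tile_size = sympy.Integer(70), sympy.Integer(16)
print(sympy.ceiling(dim / tile_size))  # 5, the iteration count

induction_var = sympy.Symbol("ARG0")   # assumed name for the induction variable
print(induction_var * tile_size)       # 16*ARG0, the start of each iteration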
9 changes: 9 additions & 0 deletions shark_turbine/kernel/wave/distribution_symbols.py
@@ -0,0 +1,9 @@
import shark_turbine.kernel.lang as tkl

WORKGROUP_0 = tkl.sym.WG0
WORKGROUP_1 = tkl.sym.WG1
WORKGROUP_2 = tkl.sym.WG2

THREAD_0 = tkl.sym.T0
THREAD_1 = tkl.sym.T1
THREAD_2 = tkl.sym.T2
53 changes: 50 additions & 3 deletions shark_turbine/kernel/wave/expansion.py
@@ -12,6 +12,7 @@
from .._support.indexing import IndexingContext
from ...support.logging import get_logger
from .._support.tracing import CapturedTrace
from .indexing import IndexSequence

logger = get_logger("turbine.wave.expansion")
# This represents a mapping of a node + indexing into the dimensions to the
@@ -87,7 +88,7 @@ def is_expandable(arg: Any) -> bool:
return isinstance(arg, CustomOp)


def get_mma_dimensional_mapping(trace: CapturedTrace) -> dict[fx.Node, int]:
def get_mma_dimensional_mapping(trace: CapturedTrace) -> dict[IndexSymbol, int]:
"""
Given a trace, determine the MMA dimensional mapping for all the
MMA operations in the graph. For example, if we have
@@ -103,7 +104,7 @@ def is_mma(node: fx.Node) -> bool:
return True

mma_nodes = trace.walk(is_mma)
mapping: dict[fx.Node, int] = {}
mapping: dict[IndexSymbol, int] = {}
for node in mma_nodes:
custom: MMA = get_custom(node)
m, n = custom.acc_type.symbolic_shape[-2:]
@@ -120,6 +121,50 @@ def is_mma(node: fx.Node) -> bool:
return mapping
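
For reference (an assumption based on the 0/1/2 keys HardwareConstraint.apply consumes, not stated explicitly in the diff), an MMA with signature (M x K, N x K) -> M x N should yield a mapping of the form:

# Assumed shape of the result; M, N, K are the IndexSymbols of the matmul.
# get_mma_dimensional_mapping(trace) == {M: 0, N: 1, K: 2}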


def set_node_indices(
trace: CapturedTrace,
constraints: Sequence[Constraint],
mma_index: dict[IndexSymbol, int],
):
"""
Set the indices of the nodes based on the user constraints. In certain
operators (like read, write), there is only a single index associated
with the node (the index to read from, the index to write to). But for
other operators like mma, each operand reads from a different index.
Rather than maintain operand-specific indices for operators, we maintain
dimension-specific indices for each operator. So for an mma operator that
has a signature of (MxK, NxK) -> MxN, we maintain only 3 mappings for
dimensions M, N and K, but allow each mapping to be piecewise conditioned
on the operand.
"""

def compute_index(node: fx.Node) -> bool:
custom = get_custom(node)
custom.index = {}
for dim in custom.indexing_dims:
for constraint in constraints:
if (
not isinstance(constraint, HardwareConstraint)
and dim != constraint.dim
):
continue
if not custom.index:
custom.index = {
dim: IndexSequence(0, 1) for dim in custom.indexing_dims
}
if isinstance(constraint, HardwareConstraint):
if dim in mma_index and isinstance(custom, MMA):
custom.index[dim] += constraint.apply(mma_index[dim])
elif dim == constraint.dim:
custom.index[dim] += constraint.apply()
if custom.index:
setattr(custom.fx_node, "index", custom.index)
return False

trace.walk(compute_index)
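
To see what compute_index accumulates, here is a hedged walk-through (dimension names and tile sizes are made up) for an operator indexing an M x K operand, where M is distributed over workgroup dimension 0 with tile size 32 and K is tiled over the induction variable with tile size 16. Each matching constraint contributes an IndexSequence that is added onto the running index for its dimension:

import sympy
from shark_turbine.kernel.wave.indexing import IndexSequence

WG0, ARG0 = sympy.symbols("WG0 ARG0")  # stand-ins for the distribution symbols

# Every dimension starts at offset 0 with size 1, as in compute_index above.
index = {"M": IndexSequence(0, 1), "K": IndexSequence(0, 1)}
index["M"] += IndexSequence(WG0 * 32, 1)   # what WorkgroupConstraint.apply() yields
index["K"] += IndexSequence(ARG0 * 16, 1)  # what TilingConstraint.apply() yields

print(index["M"])  # 32*WG0
print(index["K"])  # 16*ARG0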


def expand_graph(
trace: CapturedTrace,
constraints_or_scaling: Sequence[Constraint] | dict[IndexSymbol, int],
@@ -132,6 +177,7 @@ def expand_graph(
dim_scaling = constraints_or_scaling
else:
mma_index = get_mma_dimensional_mapping(trace)
set_node_indices(trace, constraints_or_scaling, mma_index)
dim_scaling = get_dim_scaling(constraints_or_scaling, mma_index)

# Start from the back and expand in the corresponding indexing dimensions of a node
@@ -142,6 +188,7 @@ for node in (
for node in (
get_custom(fx_node) for fx_node in reversed(list(trace.get_root_graph().nodes))
):

# Expansion begins at the leaf nodes
if node.__class__ not in leaf_nodes:
continue
@@ -276,7 +323,7 @@ def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str:


def get_dim_scaling(
constraints: Sequence[Constraint], mma_indices: dict[IndexExpr, int]
constraints: Sequence[Constraint], mma_indices: dict[IndexSymbol, int]
) -> dict[IndexSymbol, int]:
"""Get the number of expansions for the dimensions based on the constraints."""
dim_scaling: dict[IndexSymbol, int] = {}
43 changes: 43 additions & 0 deletions shark_turbine/kernel/wave/indexing.py
@@ -0,0 +1,43 @@
from dataclasses import dataclass
from typing import Optional, Any
from ...support.logging import get_logger
from .._support.indexing import IndexExpr, IndexSymbol
import sympy

logger = get_logger("turbine.wave.indexing")


@dataclass
class IndexSequence:
start: IndexExpr | int
size: IndexExpr | int
stride: Optional[IndexExpr | int] = 1

def __add__(self, other: Any) -> Any:
if isinstance(other, IndexSequence):
return IndexSequence(
self.start + other.start,
self.size * other.size,
self.stride * other.stride,
)
else:
raise NotImplementedError("IndexSequence addition not implemented!")

def subs(self, map: dict[IndexSymbol, int]):
start = self.start
if isinstance(self.start, IndexExpr):
start = self.start.subs(map)
size = self.size
if isinstance(self.size, IndexExpr):
size = self.size.subs(map)
stride = self.stride
if isinstance(self.stride, IndexExpr):
stride = self.stride.subs(map)
return IndexSequence(start, size, stride)

def __repr__(self):
if isinstance(self.size, sympy.Integer):
self.size = int(self.size)
if isinstance(self.size, int) and self.size <= 1:
return f"{self.start}"
return f"{self.start} : {self.size} : {self.stride}"