Move index setting outside of expansion
This PR separates the setting of node indices, both
pre- and post-expansion, from the expansion pass itself.
This separation of concerns should make the overall
design more modular.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Oct 18, 2024
1 parent 594d580 commit a7f7789
Showing 9 changed files with 250 additions and 185 deletions.
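Before the per-file diffs, for orientation, here is a minimal, runnable sketch of the pass ordering this change works toward: index setting runs around expansion instead of being threaded through it as a callback. The pass names (set_node_indices, set_post_expansion_indices) and the StubTrace type are placeholders invented for this sketch, not APIs from this repository; only expand_graph and the overall ordering are taken from the diff below.

    # Illustrative only: stub passes standing in for the real ones so the
    # decoupled ordering can be shown as runnable code.
    from dataclasses import dataclass, field


    @dataclass
    class StubTrace:
        # Stand-in for CapturedTrace; just records which passes ran.
        passes_run: list[str] = field(default_factory=list)


    def set_node_indices(trace: StubTrace, constraints: list) -> None:
        # Pre-expansion: derive per-dimension indices from the constraints
        # (formerly done inside expansion via the node_index_setter callback).
        trace.passes_run.append("set_node_indices")


    def expand_graph(trace: StubTrace, constraints: list) -> None:
        # Expansion proper: clone nodes per dimension and record expanded_dims.
        trace.passes_run.append("expand_graph")


    def set_post_expansion_indices(trace: StubTrace, constraints: list) -> None:
        # Post-expansion: apply the per-copy index offsets.
        trace.passes_run.append("set_post_expansion_indices")


    trace = StubTrace()
    for compiler_pass in (set_node_indices, expand_graph, set_post_expansion_indices):
        compiler_pass(trace, constraints=[])
    assert trace.passes_run == [
        "set_node_indices",
        "expand_graph",
        "set_post_expansion_indices",
    ]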
18 changes: 17 additions & 1 deletion iree/turbine/kernel/ops/wave_ops.py
@@ -498,6 +498,18 @@ def scheduling_parameters(self, value: Any):
raise ValueError("Scheduling parameters must be a dict")
self.fx_node.scheduling_parameters = value

@property
def expanded_dims(self) -> dict[IndexSymbol, int]:
if hasattr(self.fx_node, "expanded_dims"):
return self.fx_node.expanded_dims
return None

@expanded_dims.setter
def expanded_dims(self, value: dict[IndexSymbol, int]):
if not isinstance(value, dict):
raise ValueError("Expanded dims must be a dict")
self.fx_node.expanded_dims = value

def post_expansion(self, constraints: list["Constraint"]) -> None:
"""
Hook for post-expansion operations. This is called after the arguments
@@ -1060,7 +1072,11 @@ def index(self) -> dict[IndexSymbol, IndexSequence]:
custom = get_custom(self.value)
if custom.index is None:
return None
assert isinstance(custom.index, list) and self.res_idx < len(custom.index)
if not isinstance(custom, Reduction):
return custom.index
assert isinstance(custom.index, list) and self.res_idx < len(
custom.indexing_dims
)
return custom.index[self.res_idx]

@index.setter
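The expanded_dims mapping added above records, for each of a node's indexing dimensions, which expanded copy of that dimension the node represents; the expansion pass below stores restricted_dims into it, and _handle_reduction_dim reads it back via dict(user.fx_node.expanded_dims). The following self-contained sketch mirrors that attribute-backed property pattern; DemoNode and DemoOp are stand-ins for fx.Node and CustomOp, and string keys replace IndexSymbol purely for illustration.

    # Stand-ins for fx.Node / CustomOp so the attribute-backed property can
    # run outside the real codebase; keys would be IndexSymbol in practice.
    class DemoNode:
        pass


    class DemoOp:
        def __init__(self, fx_node: DemoNode):
            self.fx_node = fx_node

        @property
        def expanded_dims(self) -> dict | None:
            # Mirrors the diff: an unset attribute reads as None.
            if hasattr(self.fx_node, "expanded_dims"):
                return self.fx_node.expanded_dims
            return None

        @expanded_dims.setter
        def expanded_dims(self, value: dict):
            if not isinstance(value, dict):
                raise ValueError("Expanded dims must be a dict")
            self.fx_node.expanded_dims = value


    op = DemoOp(DemoNode())
    assert op.expanded_dims is None          # unset until expansion runs
    op.expanded_dims = {"M": 1, "N": 0}      # this node is copy (M=1, N=0)
    assert op.expanded_dims == {"M": 1, "N": 0}

Keeping this state on the fx.Node itself, rather than in a side table owned by the expansion pass, is what lets a later index-setting pass recover the per-copy information without re-running expansion.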
193 changes: 15 additions & 178 deletions iree/turbine/kernel/wave/expansion.py
@@ -109,137 +109,6 @@ def is_expandable(arg: Any) -> bool:
return isinstance(arg, CustomOp)


def is_contiguous_dim(
dim: IndexSymbol, symbolic_shape: list[IndexSymbol], vector_shapes: list[int]
) -> bool:
"""
Checks if the given dimension is stored contiguously in memory. This happens if
the dimension is the last one in the symbolic shape or all dimensions after it
are unit dimensions.
"""
is_innermost_dim = dim == symbolic_shape[-1]
dim_index = symbolic_shape.index(dim)
static_shape = [vector_shapes[dim] for dim in symbolic_shape]
all_unit_dims = all(dim == 1 for dim in static_shape[dim_index + 1 :])
return is_innermost_dim or all_unit_dims


def compute_stride(
symbolic_shape: tuple[IndexSymbol, ...],
vector_shapes: dict[IndexSymbol, int],
target_dim: IndexSymbol,
) -> int:
"""
Compute the stride for a given dimension based on the vector shapes.
The stride is the product of the vector shapes of all dimensions that are
not the given dimension.
"""
stride = 1
for dim in reversed(symbolic_shape):
if dim == target_dim:
break
assert dim in vector_shapes, f"Dimension {dim} not found in vector shapes"
stride *= vector_shapes[dim]

try:
stride = int(stride)
except Exception as e:
logger.error(e)
return stride


def set_node_index(
constraints: Sequence[Constraint],
mma_index: dict[IndexSymbol, int],
mma_slices: dict[IndexSymbol, list[fx.Node]],
dim_tile_size: dict[IndexSymbol, int],
custom: CustomOp,
dim_scaling: dict[IndexSymbol, int],
):
"""
Set the index of the node based on the user constraints. In certain
operators (like read, write), there is only a single index associated
with the node (the index to read from, the index to write to). But for
other operators like mma, each operand reads from a different index.
Rather than maintain operand specific indices for operators, we maintain
dimension specific indices for each operator. So for an mma operator that
has a signature of (MxK, NxK) -> MxN, we maintain only 3 mappings for
dimensions M, N and K, but allow each mapping to be piecewise conditioned
on the operand.
"""
hardware_constraint = [get_hardware_constraint(constraints)]
workgroup_constraints = {
c.dim: c for c in constraints if isinstance(c, WorkgroupConstraint)
}
other_constraints = [
c for c in constraints if not isinstance(c, HardwareConstraint)
]
# Apply hardware constraint first since it dictates the stride and size.
sorted_constraints = hardware_constraint + other_constraints

index = {}
# The semantics of elements_per_thread are that it represents the number of
# elements that are loaded contiguously from memory.
elements_per_thread = getattr(custom, "elements_per_thread", None)

for dim in custom.indexing_dims:
index_seq = None
for constraint in sorted_constraints:
if isinstance(constraint, HardwareConstraint):
inputs = None
if dim in mma_index:
inputs = (mma_index[dim], elements_per_thread, None)
else:
# Assumes vector shapes are associated with workgroup dims.
assert (
dim in workgroup_constraints
), f"Dimension {dim} not found in workgroup constraints"
assert (
dim in constraint.vector_shapes
), f"Dimension {dim} not found in vector shapes"
if constraint.vector_shapes[dim] == 0:
continue
inputs = (
workgroup_constraints[dim].workgroup_dim,
(
1
if not is_contiguous_dim(
dim,
custom.indexing_dims,
constraint.vector_shapes,
)
else elements_per_thread
),
compute_stride(
custom.indexing_dims, constraint.vector_shapes, dim
),
)
if elements_per_thread is None:
# Here we end up with a situation where there will be no thread level
# dependence in the dimensional index.
# TODO: Evaluate if this is a valid case.
continue
index_seq = constraint.apply(dim, *inputs, dim in mma_index)
if dim in mma_index:
index_seq = specialize_index_sequence(index_seq, mma_slices, custom)

elif constraint.dim == dim:
if index_seq is None:
index_seq = constraint.apply()
else:
index_seq.start += constraint.apply().start

if index_seq is not None:
if dim in dim_scaling and dim in dim_tile_size:
index_seq.start += dim_scaling[dim] * dim_tile_size[dim]
index.update({dim: index_seq})
else:
index.update({dim: IndexSequence(0, 1, 1)})

setattr(custom.fx_node, "index", index)


def expand_graph(
trace: CapturedTrace,
constraints_or_scaling: Sequence[Constraint] | dict[IndexSymbol, int],
@@ -250,15 +119,8 @@ def expand_graph(
"""
if isinstance(constraints_or_scaling, dict):
dim_scaling = constraints_or_scaling
node_index_setter = lambda *args: None
else:
mma_index, mma_slices = get_mma_dimensional_mapping(
trace, get_hardware_constraint(constraints_or_scaling)
)
dim_scaling, dim_tile_size = get_dim_scaling(constraints_or_scaling, mma_index)
node_index_setter = partial(
set_node_index, constraints_or_scaling, mma_index, mma_slices, dim_tile_size
)
dim_scaling = get_dim_scaling(constraints_or_scaling)

# Start from the back and expand in the corresponding indexing dimensions of a node
# Then proceed to the operands
@@ -294,7 +156,6 @@ def expand_graph(
trace,
expand_dims,
dim_scaling,
node_index_setter,
expansion_context,
)

@@ -304,7 +165,6 @@ def _expand_node(
trace: CapturedTrace,
dim_query: dict[IndexSymbol, int],
dim_scaling: dict[IndexSymbol, int],
node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None],
context: ExpandedNodeMap,
res_idx: int = 0,
) -> CustomOp:
@@ -318,7 +178,6 @@ def _expand_node(
trace,
dim_query,
dim_scaling,
node_index_setter,
context,
res_idx,
).fx_node
@@ -329,9 +188,7 @@ def _expand_node(
logger.debug(f"Already expanded node: {node} in {dim_query}")
return context[(node, get_indexed_dims(dim_query, node), res_idx)]
elif isinstance(node, Reduction):
return _expand_reduction(
node, trace, dim_query, dim_scaling, node_index_setter, context, res_idx
)
return _expand_reduction(node, trace, dim_query, dim_scaling, context, res_idx)
elif isinstance(node, Getitem):
res_idx = node.res_idx
elif isinstance(node, GetResult) and not isinstance(node, Getitem):
@@ -368,11 +225,8 @@ def _expand_node(
if isinstance(node, IterArg):
_expand_node.last_expanded_iter_arg = new_node.fx_node

new_node.fx_node.expanded_dims = restricted_dims
new_node.expanded_dims = restricted_dims
new_node.fx_node.name = get_expanded_name(node, restricted_dims)
node_index_setter(new_node, restricted_dims)

constraints = node_index_setter.args[0]

# Proceed with expansion of the arguments
for i, arg in node.node_args.items():
@@ -382,14 +236,11 @@ def _expand_node(
trace,
restricted_dims,
dim_scaling,
node_index_setter,
context,
res_idx,
)
new_node.update_arg(i, new_arg)

new_node.post_expansion(constraints)

context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node
return new_node

@@ -399,7 +250,6 @@ def _expand_reduction(
trace: CapturedTrace,
dim_query: dict[IndexSymbol, int],
dim_scaling: dict[IndexSymbol, int],
node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None],
context: ExpandedNodeMap,
res_idx: int = 0,
) -> CustomOp:
@@ -442,7 +292,12 @@ def _expand_reduction(
# Proceed with expansion inside the reduction
new_output_args.append(
_expand_node(
arg, trace, dims, dim_scaling, node_index_setter, context, res_idx
arg,
trace,
dims,
dim_scaling,
context,
res_idx,
)
)

@@ -454,7 +309,6 @@ def _expand_reduction(
trace,
dims,
dim_scaling,
node_index_setter,
context,
res_idx,
)
@@ -470,7 +324,6 @@ def _expand_reduction(
output,
trace,
dim_scaling,
node_index_setter,
context,
res_idx,
)
@@ -502,12 +355,9 @@ def _contains(elem, container):
return elem in container


def get_dim_scaling(
constraints: Sequence[Constraint], mma_indices: dict[IndexSymbol, int]
) -> tuple[dict[IndexSymbol, int]]:
def get_dim_scaling(constraints: Sequence[Constraint]) -> dict[IndexSymbol, int]:
"""Get the number of expansions for the dimensions based on the constraints."""
dim_scaling: dict[IndexSymbol, int] = {}
dim_tile_size: dict[IndexSymbol, int] = {}
hardware_constraints: list[HardwareConstraint] = [
constraint
for constraint in constraints
@@ -523,30 +373,20 @@ def get_dim_scaling(
):
hw_cons = hardware_constraints[0]
tile_size = idxc.get_static_value(constraint.tile_size)
if not (
_contains(constraint.dim, mma_indices)
or _contains(constraint.dim, hw_cons.vector_shapes)
):
if not _contains(constraint.dim, hw_cons.vector_shapes):
raise ValueError(
f"Attempting to determine vector shape for unmapped dimension {constraint.dim}"
)

if mma_indices and constraint.dim in mma_indices:
vector_size = hardware_constraints[0].mma_matrix_shapes[
mma_indices[constraint.dim]
]
else:
vector_size = hardware_constraints[0].vector_shapes[constraint.dim]
vector_size = hw_cons.vector_shapes[constraint.dim]

# No dim scaling for dims with 0 vector size.
if vector_size == 0:
continue

wave_count = 1
if isinstance(constraint, WorkgroupConstraint):
wave_count = hardware_constraints[0].waves_per_block[
constraint.workgroup_dim
]
wave_count = hw_cons.waves_per_block[constraint.workgroup_dim]
if tile_size is None or wave_count is None or vector_size is None:
raise ValueError(
"Tile size, wave count and vector size must be statically known"
@@ -559,16 +399,14 @@ def get_dim_scaling(
"Tile size must be divisible by wave count and vector size"
)
dim_scaling[constraint.dim] = tile_size // wave_count // vector_size
dim_tile_size[constraint.dim] = vector_size
return (dim_scaling, dim_tile_size)
return dim_scaling


def _handle_reduction_dim(
reduction: Reduction,
output: Output,
trace: CapturedTrace,
dim_scaling: dict[IndexSymbol, int],
node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None],
context: ExpandedNodeMap,
res_idx: int,
):
@@ -589,7 +427,7 @@ def _handle_reduction_dim(
if isinstance(user, Output):
continue

dims = user.fx_node.expanded_dims
dims = dict(user.fx_node.expanded_dims)
dims[reduction.axis] = scale_idx
# Temporarily replace the loop carried arg here to avoid
# duplicated expansion. Otherwise we have the following situation:
Expand All @@ -611,7 +449,6 @@ def _handle_reduction_dim(
trace,
dims,
dim_scaling,
node_index_setter,
context,
res_idx,
)
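As a worked example of the simplified get_dim_scaling above, with made-up but representative numbers: a dimension tiled at 64 elements per workgroup, with 2 waves along that workgroup dimension and a hardware vector shape of 16, is expanded 64 // 2 // 16 = 2 times, so nodes indexed by that dimension get copies 0 and 1 recorded in their expanded_dims.

    # Hypothetical values for a single dimension under a WorkgroupConstraint;
    # they are illustrative, not taken from a real kernel.
    tile_size = 64    # static tile size resolved from the constraint
    wave_count = 2    # waves_per_block entry for this workgroup dimension
    vector_size = 16  # hardware constraint's vector_shapes[dim]

    dim_scaling = tile_size // wave_count // vector_size
    assert dim_scaling == 2  # two expanded copies along this dimension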