[TKW] Expand reduction dim on ReduceOp + non IterArg
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu committed Oct 26, 2024
1 parent 62d11cc commit 3682533
Showing 2 changed files with 68 additions and 44 deletions.
12 changes: 10 additions & 2 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -409,11 +409,15 @@ def copy(
         new_node.reduction_dim = self.fx_node.reduction_dim
         return get_custom(new_node)
 
-    def replace_all_uses_with(self, new_node: CustomOp | fx.Node):
+    def replace_all_uses_with(
+        self,
+        new_node: CustomOp | fx.Node,
+        filter_fn: Callable[[fx.node], bool] = lambda x: True,
+    ):
         """Replace all uses of the current node with the new node."""
         if isinstance(new_node, CustomOp):
             new_node = new_node.fx_node
-        self.fx_node.replace_all_uses_with(new_node)
+        self.fx_node.replace_all_uses_with(new_node, filter_fn)
 
     def erase(self):
         """Erase the current node from the graph where it exists."""
@@ -1322,6 +1326,10 @@ def num_reduction_dims(self) -> int:
         else:
             return 1
 
+    @property
+    def reduction_dim(self) -> IndexSymbol:
+        return self.dim
+
 
 # TODO: Add support for more shuffle types.
 @define_op("shuffle")
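For context on the new filter_fn parameter: it is forwarded as the per-user callback of torch.fx's Node.replace_all_uses_with, which is asked, for each user node, whether that particular use should be rewritten. The snippet below is an illustrative sketch only (not part of this commit); the toy module M, the node lookup, and the predicate are made up, and the callback is passed positionally exactly as the wrapper above does.

import operator

import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = a * 2  # this use of `a` should be rewired
        c = a - 3  # this use of `a` should stay as-is
        return b + c


gm = fx.symbolic_trace(M())
graph = gm.graph

# Find the node computing `a = x + 1`.
a = next(
    n for n in graph.nodes
    if n.op == "call_function" and n.target is operator.add
)

# Build a replacement node that does not itself use `a`.
with graph.inserting_after(a):
    a_new = graph.call_function(operator.add, (a.args[0], 10))

# Only users whose target is `operator.mul` are switched over to `a_new`;
# the subtraction keeps reading the original `a`.
a.replace_all_uses_with(a_new, lambda user: user.target is operator.mul)

graph.lint()
gm.recompile()
print(gm.code)

Because filter_fn defaults to lambda x: True, existing callers of replace_all_uses_with keep the old replace-every-user behaviour.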
100 changes: 58 additions & 42 deletions iree/turbine/kernel/wave/expansion.py
@@ -566,53 +566,69 @@ def _handle_reduction_dim(
     # TODO: Register iter args with the reduction initially so accessing them is easier
     iter_args: list[CustomOp] = []
     reduction_subgraph = trace.get_subgraph(reduction.subgraph_name)
 
+    # TODO: Add support for case where we process MMA before returning to IterArg.
+    def get_output_index(custom: CustomOp):
+        output_users = [
+            get_custom(user)
+            for user in custom.fx_node.users
+            if isinstance(get_custom(user), Output)
+        ]
+        if len(output_users) != 1:
+            raise NotImplementedError(
+                "NYI: Currently only handle direct and 1:1 MMA -> Output case."
+            )
+        return output_users[0].return_vals[0].index(custom.fx_node)
+
+    # Collect MMA and ReduceOp who's reduction axis matches parent ReductionOp.
+    reduction_root_ops = []
     for node in (get_custom(fx_node) for fx_node in reduction_subgraph.nodes):
-        if isinstance(node, IterArg):
-            iter_args.append(node)
+        if isinstance(node, MMA) and reduction.axis == node.reduction_dim:
+            reduction_root_ops.append(node)
 
     new_outputs = list(reduction.outputs(trace.get_subgraph(reduction.subgraph_name)))
     # Users of the loop carried nodes will be duplicated
-    for idx, carried_node in enumerate(iter_args):
+    for root_op in reduction_root_ops:
         # The initial nodes are expanded in the first dimension, so we start from 1
-        dim_scaling = get_node_dim_scaling(carried_node)
+        dim_scaling = get_node_dim_scaling(root_op)
+        dims = dict(root_op.fx_node.expanded_dims)
+        latest_reduced_op = root_op
+        op_output_index = get_output_index(root_op)
+        if dim_scaling[reduction.axis] <= 1:
+            continue
         for scale_idx in range(1, dim_scaling[reduction.axis]):
-            for user in carried_node.users:
-                if isinstance(user, Output):
-                    continue
-
-                dims = dict(user.fx_node.expanded_dims)
-                dims[reduction.axis] = scale_idx
-                # Temporarily replace the loop carried arg here to avoid
-                # duplicated expansion. Otherwise we have the following situation:
-                # Suppose we have:
-                # mma_0_0_0(..., acc_0_0_0)
-                # mma_0_0_1(..., mma_0_0_0)
-                # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
-                # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
-                # To avoid this we temporarily replace the use of it with a dummy
-                # placeholder which will not trigger further expansion.
-                index = user.get_node_arg_index(carried_node)
-                dummy = Placeholder("dummy").add_to_graph(user.graph)
-                dummy.type = None
-
-                saved_arg = user.node_args[index]
-                user.update_arg(index, dummy)
-                new_node = _expand_node(
-                    user,
-                    trace,
-                    dims,
-                    dim_scaling,
-                    context,
-                    get_node_dim_scaling,
-                    res_idx,
-                )
-
-                # This expansion always happens, user should never be reused
-                assert new_node != user
-                user.update_arg(index, saved_arg)
-                new_node.update_arg(index, user)
-                user.graph.erase_node(dummy)
-                carried_node = user
-                new_outputs[idx] = new_node.fx_node
+            dims[reduction.axis] = scale_idx
+            # Temporarily replace the loop carried arg here to avoid
+            # duplicated expansion. Otherwise we have the following situation:
+            # Suppose we have:
+            # mma_0_0_0(..., acc_0_0_0)
+            # mma_0_0_1(..., mma_0_0_0)
+            # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
+            # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
+            # To avoid this we temporarily replace the use of it with a dummy
+            # placeholder which will not trigger further expansion.
+            dummy = Placeholder("dummy").add_to_graph(latest_reduced_op.graph)
+            dummy.type = None
+
+            saved_acc = latest_reduced_op.acc
+            latest_reduced_op.update_arg("acc", dummy)
+            new_node = _expand_node(
+                latest_reduced_op,
+                trace,
+                dims,
+                dim_scaling,
+                context,
+                get_node_dim_scaling,
+                res_idx,
+            )
+
+            # This expansion always happens, user should never be reused
+            assert new_node != latest_reduced_op
+            # Update MMA_{t} to accumulate on MMA_{t-1}, and then save
+            # current MMA_{t} to outputs for use in next loop.
+            latest_reduced_op.update_arg("acc", saved_acc)
+            new_node.update_arg("acc", latest_reduced_op)
+            latest_reduced_op.graph.erase_node(dummy)
+            latest_reduced_op = new_node
+            new_outputs[op_output_index] = latest_reduced_op.fx_node
     output.update_arg("return_vals", new_outputs)
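To make the rewritten loop easier to follow, here is a small, self-contained toy sketch of the chaining pattern it implements; it is not the wave expansion machinery and every name in it is illustrative. Each copy expanded along the reduction axis accumulates into the previous copy through its acc input, and the last copy in the chain is what ends up in new_outputs. The real pass additionally swaps acc for a dummy Placeholder while calling _expand_node so expansion does not recursively duplicate the previous copy; the toy sidesteps that by cloning directly.

from dataclasses import dataclass


@dataclass
class ToyMMA:
    name: str
    acc: "ToyMMA | None" = None  # accumulator input (previous op in the chain)


def expand_along_reduction(root: ToyMMA, scaling: int) -> ToyMMA:
    """Clone `root` (scaling - 1) times and chain the clones through `acc`."""
    latest = root
    for scale_idx in range(1, scaling):
        clone = ToyMMA(name=f"{root.name}_k{scale_idx}")
        # Mirror of `new_node.update_arg("acc", latest_reduced_op)`:
        # the new copy accumulates on top of the previous one.
        clone.acc = latest
        latest = clone
    return latest  # mirror of `new_outputs[op_output_index] = latest_reduced_op`


last = expand_along_reduction(ToyMMA("mma_0_0_0"), scaling=4)
chain = []
node = last
while node is not None:
    chain.append(node.name)
    node = node.acc
print(" <- ".join(chain))
# prints: mma_0_0_0_k3 <- mma_0_0_0_k2 <- mma_0_0_0_k1 <- mma_0_0_0

The printed chain is the MMA_{t}-accumulates-on-MMA_{t-1} structure described in the new comment above.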
