From 36825337f657ccd91a50a9741671439c17fa3043 Mon Sep 17 00:00:00 2001
From: Stanley Winata
Date: Sat, 26 Oct 2024 08:27:08 -0700
Subject: [PATCH] [TKW] Expand reduction dim on ReduceOp + non IterArg

Signed-off-by: Stanley Winata
---
 iree/turbine/kernel/ops/wave_ops.py   |  12 +++-
 iree/turbine/kernel/wave/expansion.py | 100 +++++++++++++++-----------
 2 files changed, 68 insertions(+), 44 deletions(-)

diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index 1d225010..b2224c3c 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -409,11 +409,15 @@ def copy(
         new_node.reduction_dim = self.fx_node.reduction_dim
         return get_custom(new_node)
 
-    def replace_all_uses_with(self, new_node: CustomOp | fx.Node):
+    def replace_all_uses_with(
+        self,
+        new_node: CustomOp | fx.Node,
+        filter_fn: Callable[[fx.node], bool] = lambda x: True,
+    ):
         """Replace all uses of the current node with the new node."""
         if isinstance(new_node, CustomOp):
             new_node = new_node.fx_node
-        self.fx_node.replace_all_uses_with(new_node)
+        self.fx_node.replace_all_uses_with(new_node, filter_fn)
 
     def erase(self):
         """Erase the current node from the graph where it exists."""
@@ -1322,6 +1326,10 @@ def num_reduction_dims(self) -> int:
         else:
             return 1
 
+    @property
+    def reduction_dim(self) -> IndexSymbol:
+        return self.dim
+
 
 # TODO: Add support for more shuffle types.
 @define_op("shuffle")
diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py
index b003016c..726b0c96 100644
--- a/iree/turbine/kernel/wave/expansion.py
+++ b/iree/turbine/kernel/wave/expansion.py
@@ -566,53 +566,69 @@ def _handle_reduction_dim(
     # TODO: Register iter args with the reduction initially so accessing them is easier
     iter_args: list[CustomOp] = []
     reduction_subgraph = trace.get_subgraph(reduction.subgraph_name)
+
+    # TODO: Add support for the case where we process MMA before returning to IterArg.
+    def get_output_index(custom: CustomOp):
+        output_users = [
+            get_custom(user)
+            for user in custom.fx_node.users
+            if isinstance(get_custom(user), Output)
+        ]
+        if len(output_users) != 1:
+            raise NotImplementedError(
+                "NYI: Currently only handle direct and 1:1 MMA -> Output case."
+            )
+        return output_users[0].return_vals[0].index(custom.fx_node)
+
+    # Collect MMA and ReduceOp whose reduction axis matches the parent ReductionOp.
+    reduction_root_ops = []
     for node in (get_custom(fx_node) for fx_node in reduction_subgraph.nodes):
-        if isinstance(node, IterArg):
-            iter_args.append(node)
+        if isinstance(node, MMA) and reduction.axis == node.reduction_dim:
+            reduction_root_ops.append(node)
 
     new_outputs = list(reduction.outputs(trace.get_subgraph(reduction.subgraph_name)))
     # Users of the loop carried nodes will be duplicated
-    for idx, carried_node in enumerate(iter_args):
+    for root_op in reduction_root_ops:
         # The initial nodes are expanded in the first dimension, so we start from 1
-        dim_scaling = get_node_dim_scaling(carried_node)
+        dim_scaling = get_node_dim_scaling(root_op)
+        dims = dict(root_op.fx_node.expanded_dims)
+        latest_reduced_op = root_op
+        op_output_index = get_output_index(root_op)
+        if dim_scaling[reduction.axis] <= 1:
+            continue
         for scale_idx in range(1, dim_scaling[reduction.axis]):
-            for user in carried_node.users:
-                if isinstance(user, Output):
-                    continue
-
-                dims = dict(user.fx_node.expanded_dims)
-                dims[reduction.axis] = scale_idx
-                # Temporarily replace the loop carried arg here to avoid
-                # duplicated expansion. Otherwise we have the following situation:
-                # Suppose we have:
-                # mma_0_0_0(..., acc_0_0_0)
-                # mma_0_0_1(..., mma_0_0_0)
-                # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
-                # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
-                # To avoid this we temporarily replace the use of it with a dummy
-                # placeholder which will not trigger further expansion.
-                index = user.get_node_arg_index(carried_node)
-                dummy = Placeholder("dummy").add_to_graph(user.graph)
-                dummy.type = None
-
-                saved_arg = user.node_args[index]
-                user.update_arg(index, dummy)
-                new_node = _expand_node(
-                    user,
-                    trace,
-                    dims,
-                    dim_scaling,
-                    context,
-                    get_node_dim_scaling,
-                    res_idx,
-                )
-
-                # This expansion always happens, user should never be reused
-                assert new_node != user
-                user.update_arg(index, saved_arg)
-                new_node.update_arg(index, user)
-                user.graph.erase_node(dummy)
-                carried_node = user
-                new_outputs[idx] = new_node.fx_node
+            dims[reduction.axis] = scale_idx
+            # Temporarily replace the loop carried arg here to avoid
+            # duplicated expansion. Otherwise we have the following situation:
+            # Suppose we have:
+            # mma_0_0_0(..., acc_0_0_0)
+            # mma_0_0_1(..., mma_0_0_0)
+            # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
+            # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
+            # To avoid this we temporarily replace the use of it with a dummy
+            # placeholder which will not trigger further expansion.
+            dummy = Placeholder("dummy").add_to_graph(latest_reduced_op.graph)
+            dummy.type = None
+
+            saved_acc = latest_reduced_op.acc
+            latest_reduced_op.update_arg("acc", dummy)
+            new_node = _expand_node(
+                latest_reduced_op,
+                trace,
+                dims,
+                dim_scaling,
+                context,
+                get_node_dim_scaling,
+                res_idx,
+            )
+            # This expansion always happens, user should never be reused
+            assert new_node != latest_reduced_op
+            # Update MMA_{t} to accumulate on MMA_{t-1}, and then save the
+            # current MMA_{t} to the outputs for use in the next iteration.
+            latest_reduced_op.update_arg("acc", saved_acc)
+            new_node.update_arg("acc", latest_reduced_op)
+            latest_reduced_op.graph.erase_node(dummy)
+            latest_reduced_op = new_node
+            new_outputs[op_output_index] = latest_reduced_op.fx_node
 
     output.update_arg("return_vals", new_outputs)
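Note (illustration only, not part of the patch): the core change in _handle_reduction_dim above is that, instead of re-expanding the users of each IterArg, it now walks each reduction root op (an MMA whose reduction_dim matches the parent ReductionOp axis) and chains the expanded copies through their "acc" argument, so MMA_{t} accumulates on MMA_{t-1} and the last copy replaces the root op in the reduction's return values. The standalone Python sketch below mimics only that chaining step; FakeMMA and chain_along_reduction_dim are hypothetical stand-ins and do not use the TKW classes, the dummy-placeholder trick, or real graph expansion.

    from dataclasses import dataclass, replace


    @dataclass
    class FakeMMA:
        # Hypothetical stand-in for an MMA node: "acc" is the loop-carried
        # accumulator, either the IterArg name or a previously expanded copy.
        name: str
        acc: object


    def chain_along_reduction_dim(root: FakeMMA, dim_scaling: int) -> FakeMMA:
        # Expand `root` along the reduction axis and chain the accumulators:
        # copy t accumulates on copy t-1 (cf. new_node.update_arg("acc",
        # latest_reduced_op) in the patch). Returns the last copy, which is
        # what ends up in the reduction's return values.
        latest = root  # the first copy keeps the original IterArg accumulator
        for scale_idx in range(1, dim_scaling):
            latest = replace(root, name=f"{root.name}_r{scale_idx}", acc=latest)
        return latest


    root = FakeMMA(name="mma_0_0_0", acc="acc_0_0_0")
    last = chain_along_reduction_dim(root, dim_scaling=3)
    print(last.name, "accumulates on", last.acc.name)    # mma_0_0_0_r2 accumulates on mma_0_0_0_r1
    print(last.acc.acc.name, "keeps", last.acc.acc.acc)  # mma_0_0_0 keeps acc_0_0_0

Chaining through "acc" keeps a single serial accumulation order along the reduction axis, which is why only the final copy needs to be written back into new_outputs at the index returned by get_output_index.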