From 7a8a23cebaceb8d975f1453310840c90201f95ce Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Thu, 24 Oct 2024 17:36:00 -0700 Subject: [PATCH] Address comments Signed-off-by: Harsh Menon --- iree/turbine/kernel/wave/expansion.py | 98 +++++++++++++++------------ 1 file changed, 55 insertions(+), 43 deletions(-) diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py index da7a1c0c..b003016c 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -190,19 +190,23 @@ def _expand_node( return context[(node, get_indexed_dims(dim_query, node), res_idx)] elif isinstance(node, MMA): # Handle expansion of MMA nodes whose reduction dim is not the same as the reduction - # dim of the parent reduction op. - if hasattr(node.graph, "parent_op") and node.reduction_dim not in dim_query: + # dim of the parent reduction op or when there is no parent reduction op. + has_parent_op = hasattr(node.graph, "parent_op") + reduction_axes_different = False + if has_parent_op: reduction: Reduction = get_custom(node.graph.parent_op) - if reduction.axis != node.reduction_dim: - return _expand_mma_reduction( - node, - trace, - dim_query, - get_node_dim_scaling(node), - context, - get_node_dim_scaling, - res_idx, - ) + reduction_axes_different = reduction.axis != node.reduction_dim + parallel_dim_query = node.reduction_dim not in dim_query + if (not has_parent_op or reduction_axes_different) and parallel_dim_query: + return _expand_mma_reduction( + node, + trace, + dim_query, + dim_scaling, + context, + get_node_dim_scaling, + res_idx, + ) elif isinstance(node, Reduction): return _expand_reduction( node, trace, dim_query, dim_scaling, context, get_node_dim_scaling, res_idx @@ -248,35 +252,30 @@ def _expand_node( # Proceed with expansion of the arguments for i, arg in node.node_args.items(): - new_arg = None - if not isinstance(arg, Sequence): - if is_expandable(arg): - new_arg = _expand_node( - arg, - trace, - restricted_dims, - get_node_dim_scaling(arg), - context, - get_node_dim_scaling, - res_idx, - ) - new_node.update_arg(i, new_arg) + arg_list = arg + unpack = lambda x: x + if isinstance(arg, list): + if not all(is_expandable(a) for a in arg): + continue else: - new_arg = [] - for subarg in arg: - if is_expandable(subarg): - new_subarg = _expand_node( - subarg, - trace, - restricted_dims, - get_node_dim_scaling(subarg), - context, - get_node_dim_scaling, - res_idx, - ) - new_arg.append(new_subarg.fx_node) - assert len(new_arg) == len(arg), "All subargs must be expanded" - new_node.update_arg(i, new_arg) + arg_list = [arg] + unpack = lambda x: x[0] + if not is_expandable(arg): + continue + + new_args = [] + for subarg in arg_list: + new_subarg = _expand_node( + subarg, + trace, + restricted_dims, + get_node_dim_scaling(subarg), + context, + get_node_dim_scaling, + res_idx, + ) + new_args.append(new_subarg.fx_node) + new_node.update_arg(i, unpack(new_args)) context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node return new_node @@ -400,21 +399,34 @@ def _expand_mma_reduction( """ logger.debug(f"Expanding MMA reduction: {mma} in dims: {dim_query}") + expand_dims = set(mma.indexing_dims) - set([mma.reduction_dim]) + idxc = IndexingContext.current() for dim in mma.indexing_dims: if dim not in dim_scaling and mma.vector_shapes[dim] > 0: tile_size = idxc.get_static_value(dim) dim_scaling[dim] = tile_size // mma.vector_shapes[dim] - expand_dims = set(mma.indexing_dims) - set([mma.reduction_dim]) - # Store the original mma node and accumulator value for expansion. + # When we begin expansion, we have a single mma node with the correct accumulator. + # This node corresponds to the dim query with all 0s and for this we reuse the + # original mma node. For all other queries, we create a new node. + # So say we have parallel dimensions {M, K2} and reduction dimension {K1}. + # For M = 0, K2 = 0, K1 = 0, we use the original mma node. + # For M = 0, K2 = 0, K1 = 1, we create a new node. + # Now, when it is time to expand along new parallel dimensions, we use the original node + # For M = 0, K2 = 1, K1 = 0, we use the original mma node so that the last cloned node's + # accumulator value is not modified. + dim_query_dims = tuple(dim_query.keys()) if not hasattr(_expand_mma_reduction, "acc"): _expand_mma_reduction.acc = {} if not hasattr(_expand_mma_reduction, "mma"): _expand_mma_reduction.mma = {} - if dim_query_dims not in _expand_mma_reduction.mma: + if ( + dim_query_dims not in _expand_mma_reduction.mma + or _expand_mma_reduction.mma[dim_query_dims].graph != mma.graph + ): _expand_mma_reduction.mma[dim_query_dims] = mma _expand_mma_reduction.acc[dim_query_dims] = mma.acc