diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py
index 5d586aa5..57ce9b25 100644
--- a/lit_tests/kernel/wave/expansion.py
+++ b/lit_tests/kernel/wave/expansion.py
@@ -190,13 +190,13 @@ def test_gemm():
     # CHECK-NEXT: %register_1_0_0
     # CHECK-NEXT: %register_0_1_0
     # CHECK-NEXT: %reduction
+    # CHECK-SAME: %register_0_0_0, %register_0_1_0, %register_1_0_0, %register_1_1_0
     # CHECK-NEXT: %getresult_1_1_0
     # CHECK-NEXT: %getresult_1_0_0
     # CHECK-NEXT: %getresult_0_1_0
     # CHECK-NEXT: %getresult_0_0_0
     # CHECK-NEXT: %write_0_0_0
-    # TODO: This link-up is not yet correct!
-    # CHECK-SAME: (%reduction, %c, 4)
+    # CHECK-SAME: (%get_result_0_0_0, %c, 4)
     # CHECK-NEXT: %write_1_1_0
     # CHECK-SAME: (%get_result_1_1_0, %c, 4)
     # CHECK-NEXT: %write_1_0_0
@@ -283,8 +283,7 @@ def test_gemm_reduction_expansion_only():
     # CHECK-NEXT: %getresult_0_1_0
     # CHECK-NEXT: %getresult_0_0_0
     # CHECK-NEXT: %write_0_0_0
-    # TODO: This link-up is not yet correct!
-    # CHECK-SAME: (%reduction, %c, 4)
+    # CHECK-SAME: (%get_result_0_0_0, %c, 4)
     # CHECK-NEXT: %write_0_1_0
     # CHECK-SAME: (%get_result_0_1_0, %c, 4)

diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py
index 9d857779..f900480b 100644
--- a/shark_turbine/kernel/ops/wave_ops.py
+++ b/shark_turbine/kernel/ops/wave_ops.py
@@ -446,9 +446,18 @@ def type(self) -> "Memory":
 @define_op("get_result")
 @dataclass
 class GetResult(CustomOp):
-    value: fx.Proxy
+    value: fx.Node
     res_idx: int

     @property
     def type(self) -> "Memory":
         return self.value.type
+
+    @property
+    def indexing_dims(self) -> list[IndexSymbol]:
+        expand_dims: list[IndexSymbol] = []
+        for user in self.users:
+            for indexing_dim in user.indexing_dims:
+                if indexing_dim not in expand_dims:
+                    expand_dims.append(indexing_dim)
+        return expand_dims
diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py
index c55e48a2..e2789181 100644
--- a/shark_turbine/kernel/wave/expansion.py
+++ b/shark_turbine/kernel/wave/expansion.py
@@ -91,7 +91,6 @@ def expand_graph(
     Create a graph that represents the expanded version of the wave function.
     The expansion is done in the dimensions specified by the constraints.
     """
-
     if isinstance(constraints_or_scaling, dict):
         dim_scaling = constraints_or_scaling
     else:
@@ -114,19 +113,7 @@ def expand_graph(
             dim: val for dim, val in zip(dim_scaling.keys(), dim_combination)
         }
         logger.debug(f"Starting expansion at leaf:{node} in dims:{expand_dims}")
-        if not expansion_needed(expand_dims, node.indexing_dims):
-            new_node = node
-        else:
-            node.graph.inserting_after(node.fx_node)
-            new_node = node.copy()
-            for arg_idx, arg in enumerate(node.node_args):
-                if is_expandable(arg):
-                    new_arg = _expand_node(
-                        arg, trace, expand_dims, dim_scaling, expansion_context
-                    )
-                    new_node.update_arg(arg_idx, new_arg)
-            new_node.fx_node.name = get_expanded_name(node, expand_dims)
-        expansion_context[(node, get_indexed_dims(expand_dims, node))] = new_node
+        _expand_node(node, trace, expand_dims, dim_scaling, expansion_context)


 def _expand_node(
@@ -143,6 +130,11 @@ def _expand_node(
         return context[(node, get_indexed_dims(dim_query, node))]
     elif isinstance(node, Reduction):
         return _expand_reduction(node, trace, dim_query, dim_scaling, context)
+    elif isinstance(node, GetResult):
+        # The presence of a GetResult node indicates that the reduction has already
+        # been expanded. Simply return the corresponding node.
+        reduction = get_custom(node.value)
+        return context[(reduction, get_indexed_dims(dim_query, reduction))]

     # Filter out the dimensions that are not indexed by the node
     restricted_dims = filter_and_zero_unselected_dims(dim_query, node.indexing_dims)