Skip to content

Commit

Permalink
Simplify leaf expansion
Browse files Browse the repository at this point in the history
- fix bug in linking up the reduction results

Signed-off-by: Martin Lücke <martin.luecke@ed.ac.uk>
  • Loading branch information
martin-luecke committed Jul 10, 2024
1 parent 233e6fa commit 1e73b0d
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 19 deletions.
7 changes: 3 additions & 4 deletions lit_tests/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,13 @@ def test_gemm():
# CHECK-NEXT: %register_1_0_0
# CHECK-NEXT: %register_0_1_0
# CHECK-NEXT: %reduction
# CHECK-SAME: %register_0_0_0, %register_0_1_0, %register_1_0_0, %register_1_1_0
# CHECK-NEXT: %getresult_1_1_0
# CHECK-NEXT: %getresult_1_0_0
# CHECK-NEXT: %getresult_0_1_0
# CHECK-NEXT: %getresult_0_0_0
# CHECK-NEXT: %write_0_0_0
# TODO: This link-up is not yet correct!
# CHECK-SAME: (%reduction, %c, 4)
# CHECK-SAME: (%get_result_0_0_0, %c, 4)
# CHECK-NEXT: %write_1_1_0
# CHECK-SAME: (%get_result_1_1_0, %c, 4)
# CHECK-NEXT: %write_1_0_0
Expand Down Expand Up @@ -283,8 +283,7 @@ def test_gemm_reduction_expansion_only():
# CHECK-NEXT: %getresult_0_1_0
# CHECK-NEXT: %getresult_0_0_0
# CHECK-NEXT: %write_0_0_0
# TODO: This link-up is not yet correct!
# CHECK-SAME: (%reduction, %c, 4)
# CHECK-SAME: (%get_result_0_0_0, %c, 4)
# CHECK-NEXT: %write_0_1_0
# CHECK-SAME: (%get_result_0_1_0, %c, 4)

Expand Down
11 changes: 10 additions & 1 deletion shark_turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,18 @@ def type(self) -> "Memory":
@define_op("get_result")
@dataclass
class GetResult(CustomOp):
value: fx.Proxy
value: fx.Node
res_idx: int

@property
def type(self) -> "Memory":
return self.value.type

@property
def indexing_dims(self) -> list[IndexSymbol]:
expand_dims: list[IndexSymbol] = []
for user in self.users:
for indexing_dim in user.indexing_dims:
if indexing_dim not in expand_dims:
expand_dims.append(indexing_dim)
return expand_dims
20 changes: 6 additions & 14 deletions shark_turbine/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def expand_graph(
Create a graph that represents the expanded version of the wave function.
The expansion is done in the dimensions specified by the constraints.
"""

if isinstance(constraints_or_scaling, dict):
dim_scaling = constraints_or_scaling
else:
Expand All @@ -114,19 +113,7 @@ def expand_graph(
dim: val for dim, val in zip(dim_scaling.keys(), dim_combination)
}
logger.debug(f"Starting expansion at leaf:{node} in dims:{expand_dims}")
if not expansion_needed(expand_dims, node.indexing_dims):
new_node = node
else:
node.graph.inserting_after(node.fx_node)
new_node = node.copy()
for arg_idx, arg in enumerate(node.node_args):
if is_expandable(arg):
new_arg = _expand_node(
arg, trace, expand_dims, dim_scaling, expansion_context
)
new_node.update_arg(arg_idx, new_arg)
new_node.fx_node.name = get_expanded_name(node, expand_dims)
expansion_context[(node, get_indexed_dims(expand_dims, node))] = new_node
_expand_node(node, trace, expand_dims, dim_scaling, expansion_context)


def _expand_node(
Expand All @@ -143,6 +130,11 @@ def _expand_node(
return context[(node, get_indexed_dims(dim_query, node))]
elif isinstance(node, Reduction):
return _expand_reduction(node, trace, dim_query, dim_scaling, context)
elif isinstance(node, GetResult):
# The presence of a GetResult node indicates that the reduction has already
# been expanded. Simply return the corresponding node.
reduction = get_custom(node.value)
return context[(reduction, get_indexed_dims(dim_query, reduction))]

# Filter out the dimensions that are not indexed by the node
restricted_dims = filter_and_zero_unselected_dims(dim_query, node.indexing_dims)
Expand Down

0 comments on commit 1e73b0d

Please sign in to comment.