Skip to content

Commit

Permalink
Scal2sym: Properly deal with connector name clashes
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Nov 2, 2024
1 parent 3aec5cc commit 9f7ff91
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions dace/transformation/passes/scalar_to_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,33 +319,37 @@ def __init__(self, in_edges: Dict[str, mm.Memlet], out_edges: Dict[str, mm.Memle
"""
self.in_edges = in_edges
self.out_edges = out_edges
self.arrays = {k: sdfg.arrays[v.data] for k, v in in_edges.items() if k is not None}
self.arrays.update({k: sdfg.arrays[v.data] for k, v in out_edges.items() if k is not None})
self.sdfg = sdfg
self.defined = defined_syms
self.connector_names = set(in_edges.keys()) | set(out_edges.keys())
self.in_mapping: Dict[str, Tuple[str, subsets.Range]] = {}
self.out_mapping: Dict[str, Tuple[str, subsets.Range]] = {}
self.do_not_remove: Set[str] = set()
self.latest: DefaultDict[str, int] = collections.defaultdict(int)

def visit_Subscript(self, node: ast.Subscript) -> Any:
# Convert subscript to symbol name
node = self.generic_visit(node)
node_name = astutils.rname(node)
if node_name in self.in_edges:
self.latest[node_name] += 1
new_name = f'{node_name}_{self.latest[node_name]}'
new_name = dt.find_new_name(node_name, self.connector_names)
self.connector_names.add(new_name)

orig_subset = self.in_edges[node_name].subset
subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.sdfg.arrays)[1]))
subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.arrays)[1]))
# Check if range can be collapsed
if _range_is_promotable(subset, self.defined):
self.in_mapping[new_name] = (node_name, subset)
return ast.copy_location(ast.Name(id=new_name, ctx=ast.Load()), node)
else:
self.do_not_remove.add(node_name)
elif node_name in self.out_edges:
self.latest[node_name] += 1
new_name = f'{node_name}_{self.latest[node_name]}'
new_name = dt.find_new_name(node_name, self.connector_names)
self.connector_names.add(new_name)

orig_subset = self.out_edges[node_name].subset
subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.sdfg.arrays)[1]))
subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.arrays)[1]))
# Check if range can be collapsed
if _range_is_promotable(subset, self.defined):
self.out_mapping[new_name] = (node_name, subset)
Expand Down Expand Up @@ -654,6 +658,8 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]:
tasklet_inputs = [e.src for e in state.in_edges(input)]
# Step 2.1
new_state = xfh.state_fission(gr.SubgraphView(state, set([input, node] + tasklet_inputs)))
if state.edges_between(input, node): # Edge still there after fission, remove manually
state.remove_edge_and_connectors(state.edges_between(input, node)[0])
new_isedge: sd.InterstateEdge = new_state.parent_graph.out_edges(new_state)[0]
# Step 2.2
node: nodes.AccessNode = new_state.sink_nodes()[0]
Expand Down

0 comments on commit 9f7ff91

Please sign in to comment.