diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index 005f7faf2d..3b3940f804 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -59,11 +59,6 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Tuple[int, str]]]: sdfg.remove_symbol(sym) result.add(sym) - for e in sdfg.edges(): - for aname in list(e.data.assignments): - if aname in result: - del e.data.assignments[aname] - if self.recursive: # Prune nested SDFGs recursively sid = sdfg.cfg_id @@ -93,10 +88,10 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: result |= symbolic.symbols_in_code(code.as_string) for desc in sdfg.arrays.values(): - result |= set(map(str, desc.used_symbols(False))) + result |= set(map(str, desc.free_symbols)) for state in sdfg.nodes(): - result |= state.used_symbols(False) + result |= state.free_symbols # In addition to the standard free symbols, we are conservative with other tasklet languages by # tokenizing their code. Since this is intersected with `sdfg.symbols`, keywords such as "if" are # ok to include @@ -116,6 +111,6 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: node.ignored_symbols) for e in sdfg.edges(): - result |= e.data.used_symbols(False) + result |= e.data.free_symbols return result