Skip to content

Commit

Permalink
Fix three issues related to deepcopying elements (#1446)
Browse files Browse the repository at this point in the history
This PR fixes #1439 and #1443 by adapting fields and the deepcopy
operation for states:
1. Skips derived field `parent` being set if a state is deepcopied on
its own
2. Does not add a new field to AST nodes during preprocessing. That
parent-pointing field outlives preprocessing and ends up copying the
entire original AST for short codeblocks.
3. Does not add a new field to states during state propagation.
  • Loading branch information
tbennun authored Nov 27, 2023
1 parent d157346 commit 4b5e2c2
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 9 deletions.
11 changes: 6 additions & 5 deletions dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1522,14 +1522,15 @@ def __init__(self, globals: Dict[str, Any]):
from mpi4py import MPI
self.globals = globals
self.MPI = MPI
self.parents = {}
self.parent = None

def visit(self, node):
node.parent = self.parent
self.parents[node] = self.parent
self.parent = node
node = super().visit(node)
if isinstance(node, ast.AST):
self.parent = node.parent
self.parent = self.parents[node]
return node

def visit_Name(self, node: ast.Name) -> Union[ast.Name, ast.Attribute]:
Expand All @@ -1540,7 +1541,7 @@ def visit_Name(self, node: ast.Name) -> Union[ast.Name, ast.Attribute]:
lattr = ast.Attribute(ast.Name(id='mpi4py', ctx=ast.Load), attr='MPI')
if obj is self.MPI.COMM_NULL:
newnode = ast.copy_location(ast.Attribute(value=lattr, attr='COMM_NULL'), node)
newnode.parent = node.parent
self.parents[newnode] = self.parents[node]
return newnode
return node

Expand All @@ -1549,10 +1550,10 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
if isinstance(node.attr, str) and node.attr == 'Request':
try:
val = astutils.evalnode(node, self.globals)
if val is self.MPI.Request and not isinstance(node.parent, ast.Attribute):
if val is self.MPI.Request and not isinstance(self.parents[node], ast.Attribute):
newnode = ast.copy_location(
ast.Attribute(value=ast.Name(id='dace', ctx=ast.Load), attr='MPI_Request'), node)
newnode.parent = node.parent
self.parents[newnode] = self.parents[node]
return newnode
except SyntaxError:
pass
Expand Down
11 changes: 7 additions & 4 deletions dace/sdfg/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,12 +565,15 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states):
:param sdfg: The SDFG in which to look.
:param unannotated_cycle_states: List of lists. Each sub-list contains the states of one unannotated cycle.
:return: A dictionary mapping guard states to their condition edges, if applicable
"""

# We import here to avoid cyclic imports.
from dace.transformation.interstate.loop_detection import find_for_loop
from dace.sdfg import utils as sdutils

condition_edges = {}

for cycle in sdfg.find_cycles():
# In each cycle, try to identify a valid loop guard state.
guard = None
Expand Down Expand Up @@ -667,14 +670,15 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states):
for v in loop_states:
v.ranges[itervar] = subsets.Range([rng])
guard.ranges[itervar] = subsets.Range([rng])
guard.condition_edge = sdfg.edges_between(guard, begin)[0]
condition_edges[guard] = sdfg.edges_between(guard, begin)[0]
guard.is_loop_guard = True
guard.itvar = itervar
else:
# There's no guard state, so this cycle marks all states in it as
# dynamically unbounded.
unannotated_cycle_states.append(cycle)

return condition_edges

def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None:
"""
Expand Down Expand Up @@ -760,7 +764,7 @@ def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None:
# Find any valid for loop constructs and annotate the loop ranges. Any other
# cycle should be marked as unannotated.
unannotated_cycle_states = []
_annotate_loop_ranges(sdfg, unannotated_cycle_states)
condition_edges = _annotate_loop_ranges(sdfg, unannotated_cycle_states)
if not concretize_dynamic_unbounded:
# Flatten the list. This keeps the old behavior of propagate_states.
unannotated_cycle_states = [state for cycle in unannotated_cycle_states for state in cycle]
Expand Down Expand Up @@ -869,7 +873,7 @@ def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None:
(outer_itvar, 0, ceiling((outer_stop - outer_start) / outer_stride)))
loop_executions = loop_executions.doit()

loop_state = state.condition_edge.dst
loop_state = condition_edges[state].dst
end_state = (out_edges[0].dst if out_edges[1].dst == loop_state else out_edges[1].dst)

traversal_q.append((end_state, state.executions, proposed_dynamic, itvar_stack))
Expand Down Expand Up @@ -1142,7 +1146,6 @@ def reset_state_annotations(sdfg):
state.executions = 0
state.dynamic_executions = True
state.ranges = {}
state.condition_edge = None
state.is_loop_guard = False
state.itervar = None

Expand Down
9 changes: 9 additions & 0 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,16 @@ def __deepcopy__(self, memo):
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == '_parent': # Skip derivative attributes
continue
setattr(result, k, copy.deepcopy(v, memo))

for k in ('_parent',):
if id(getattr(self, k)) in memo:
setattr(result, k, memo[id(getattr(self, k))])
else:
setattr(result, k, None)

for node in result.nodes():
if isinstance(node, nd.NestedSDFG):
try:
Expand Down
1 change: 1 addition & 0 deletions dace/transformation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def _copy_state(sdfg: SDFG,

state_copy = copy.deepcopy(state)
state_copy._label += '_copy'
state_copy.parent = sdfg
sdfg.add_node(state_copy)

in_conditions = []
Expand Down
16 changes: 16 additions & 0 deletions tests/sdfg/state_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace
from dace.transformation.helpers import find_sdfg_control_flow


def test_read_write_set():
Expand Down Expand Up @@ -42,7 +43,22 @@ def test_read_write_set_y_formation():

assert 'B' not in state.read_and_write_sets()[0]

def test_deepcopy_state():
N = dace.symbol('N')

@dace.program
def double_loop(arr: dace.float32[N]):
for i in range(N):
arr[i] *= 2
for i in range(N):
arr[i] *= 2

sdfg = double_loop.to_sdfg()
find_sdfg_control_flow(sdfg)
sdfg.validate()


if __name__ == '__main__':
test_read_write_set()
test_read_write_set_y_formation()
test_deepcopy_state()
1 change: 1 addition & 0 deletions tests/transformations/loop_to_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,3 +757,4 @@ def internal_write(inp0: dace.int32[10], inp1: dace.int32[10], out: dace.int32[1
test_thread_local_transient_multi_state()
test_nested_loops()
test_internal_write()
test_specialize()

0 comments on commit 4b5e2c2

Please sign in to comment.