Skip to content

Commit

Permalink
Complete coverage for reference-to-view pass (#1488)
Browse files Browse the repository at this point in the history
Adds a scoped test that completes coverage for the reference-to-view
pass, leading to fixes of issues in the uncovered code.
  • Loading branch information
tbennun authored Dec 22, 2023
1 parent 509ee0f commit b93a4c9
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
7 changes: 4 additions & 3 deletions dace/transformation/passes/reference_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def find_candidates(
# Check if any of the symbols is a scope symbol
entry = state.entry_node(node)
while entry is not None:
if fsyms & entry.new_symbols(sdfg, state, {}):
if fsyms & entry.new_symbols(sdfg, state, {}).keys():
result.remove(cand)
break
entry = state.entry_node(entry)
Expand Down Expand Up @@ -183,11 +183,12 @@ def remove_refsets(

# Modify the state graph as necessary
for e in edges_to_remove:
state.remove_edge_and_connectors(e)
state.remove_memlet_path(e)
for n in nodes_to_remove:
state.remove_node(n)
for e in edges_to_add:
state.add_edge(*e)
if len(state.edges_between(e[0], e[2])) == 0:
state.add_edge(*e)
for n in affected_nodes: # Orphaned nodes
if n in nodes_to_remove:
continue
Expand Down
59 changes: 59 additions & 0 deletions tests/sdfg/reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,61 @@ def test_reference_loop_nonfree_internal_use():
assert np.allclose(ref, A)


@pytest.mark.parametrize(('array_outside_scope', 'depends_on_iterate'), ((False, True), (False, True)))
def test_ref2view_refset_in_scope(array_outside_scope, depends_on_iterate):
sdfg = dace.SDFG('reftest')
sdfg.add_array('A', [20], dace.float64)
sdfg.add_array('B', [20], dace.float64)
sdfg.add_reference('ref', [1], dace.float64)

memlet_string = 'A[i]' if depends_on_iterate else 'A[3]'

state = sdfg.add_state()
me, mx = state.add_map('somemap', dict(i='0:20'))
arr = state.add_access('A')
ref = state.add_access('ref')
write = state.add_write('B')

if array_outside_scope:
state.add_edge_pair(me, ref, arr, dace.Memlet(memlet_string), internal_connector='set')
else:
state.add_nedge(me, arr, dace.Memlet())
state.add_edge(arr, None, ref, 'set', dace.Memlet(memlet_string))

t = state.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1')
state.add_edge(ref, None, t, 'inp', dace.Memlet('ref'))
state.add_edge_pair(mx, t, write, dace.Memlet('B[i]'), internal_connector='out')

# Test sources
sources = FindReferenceSources().apply_pass(sdfg, {})
assert len(sources) == 1 # There is only one SDFG
sources = sources[0]
assert len(sources) == 1
assert sources['ref'] == {dace.Memlet(memlet_string)}

# Test correctness before pass
A = np.random.rand(20)
B = np.random.rand(20)
ref = (A + 1) if depends_on_iterate else (A[3] + 1)
sdfg(A=A, B=B)
assert np.allclose(B, ref)

# Test reference-to-view - should fail to apply
result = Pipeline([ReferenceToView()]).apply_pass(sdfg, {})
if depends_on_iterate:
assert 'ReferenceToView' not in result or not result['ReferenceToView']
else:
assert result['ReferenceToView'] == {'ref'}

# Test correctness after pass
if not depends_on_iterate:
A = np.random.rand(20)
B = np.random.rand(20)
ref = (A + 1) if depends_on_iterate else (A[3] + 1)
sdfg(A=A, B=B)
assert np.allclose(B, ref)


if __name__ == '__main__':
test_unset_reference()
test_reference_branch()
Expand All @@ -603,3 +658,7 @@ def test_reference_loop_nonfree_internal_use():
test_reference_loop_internal_use(False)
test_reference_loop_internal_use(True)
test_reference_loop_nonfree_internal_use()
test_ref2view_refset_in_scope(False, False)
test_ref2view_refset_in_scope(False, True)
test_ref2view_refset_in_scope(True, False)
test_ref2view_refset_in_scope(True, True)

0 comments on commit b93a4c9

Please sign in to comment.