Skip to content

Commit

Permalink
Fix infinite loops in memlet path when a scope cycle is added (#1559)
Browse files Browse the repository at this point in the history
Fixes #1558
  • Loading branch information
tbennun authored Apr 26, 2024
1 parent 5d4dfe9 commit f01b937
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
15 changes: 15 additions & 0 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,9 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto

# Prepend incoming edges until reaching the source node
curedge = edge
visited = set()
while not isinstance(curedge.src, (nd.CodeNode, nd.AccessNode)):
visited.add(curedge)
# Trace through scopes using OUT_# -> IN_#
if isinstance(curedge.src, (nd.EntryNode, nd.ExitNode)):
if curedge.src_conn is None:
Expand All @@ -398,10 +400,14 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto
next_edge = next(e for e in state.in_edges(curedge.src) if e.dst_conn == "IN_" + curedge.src_conn[4:])
result.insert(0, next_edge)
curedge = next_edge
if curedge in visited:
raise ValueError('Cycle encountered while reading memlet path')

# Append outgoing edges until reaching the sink node
curedge = edge
visited.clear()
while not isinstance(curedge.dst, (nd.CodeNode, nd.AccessNode)):
visited.add(curedge)
# Trace through scope entry using IN_# -> OUT_#
if isinstance(curedge.dst, (nd.EntryNode, nd.ExitNode)):
if curedge.dst_conn is None:
Expand All @@ -411,6 +417,8 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto
next_edge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == "OUT_" + curedge.dst_conn[3:])
result.append(next_edge)
curedge = next_edge
if curedge in visited:
raise ValueError('Cycle encountered while reading memlet path')

return result

Expand All @@ -434,16 +442,23 @@ def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree:

# Find tree root
curedge = edge
visited = set()
if propagate_forward:
while (isinstance(curedge.src, nd.EntryNode) and curedge.src_conn is not None):
visited.add(curedge)
assert curedge.src_conn.startswith('OUT_')
cname = curedge.src_conn[4:]
curedge = next(e for e in state.in_edges(curedge.src) if e.dst_conn == 'IN_%s' % cname)
if curedge in visited:
raise ValueError('Cycle encountered while reading memlet path')
elif propagate_backward:
while (isinstance(curedge.dst, nd.ExitNode) and curedge.dst_conn is not None):
visited.add(curedge)
assert curedge.dst_conn.startswith('IN_')
cname = curedge.dst_conn[3:]
curedge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == 'OUT_%s' % cname)
if curedge in visited:
raise ValueError('Cycle encountered while reading memlet path')
tree_root = mm.MemletTree(curedge, downwards=propagate_forward)

# Collect children (recursively)
Expand Down
19 changes: 19 additions & 0 deletions tests/sdfg/cycles_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import pytest

import dace
Expand All @@ -13,3 +14,21 @@ def test_cycles():

state.add_edge(access, None, access, None, dace.Memlet.simple("A", "0"))
sdfg.validate()


def test_cycles_memlet_path():
with pytest.raises(ValueError, match="Found cycles.*"):
sdfg = dace.SDFG("foo")
state = sdfg.add_state()
sdfg.add_array("bla", shape=(10, ), dtype=dace.float32)
mentry_3, _ = state.add_map("map_3", dict(i="0:9"))
mentry_3.add_in_connector("IN_0")
mentry_3.add_out_connector("OUT_0")
state.add_edge(mentry_3, "OUT_0", mentry_3, "IN_0", dace.Memlet(data="bla", subset='0:9'))

sdfg.validate()


if __name__ == '__main__':
test_cycles()
test_cycles_memlet_path()

0 comments on commit f01b937

Please sign in to comment.