Skip to content

Commit

Permalink
Schedule tree: fix tests, print empty scopes in a nicer way
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Jul 26, 2023
1 parent 4b43606 commit b4a7984
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 2 additions & 0 deletions dace/sdfg/analysis/schedule_tree/treenodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(self,
child.parent = self

def as_string(self, indent: int = 0):
if not self.children:
return (indent + 1) * INDENTATION + 'pass'
return '\n'.join([child.as_string(indent + 1) for child in self.children])

def preorder_traversal(self) -> Iterator['ScheduleTreeNode']:
Expand Down
14 changes: 9 additions & 5 deletions tests/schedule_tree/schedule_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def main(a: dace.float64[20, 10]):
nest1(a[10:15])
nest1(a[15:])

sdfg = main.to_sdfg()
sdfg = main.to_sdfg(simplify=True)
stree = as_schedule_tree(sdfg)

# Despite two levels of nesting, immediate children are the 4 for loops
Expand Down Expand Up @@ -150,8 +150,10 @@ def test_irreducible_sub_sdfg():
# Add a loop following general block
sdfg.add_loop(e, sdfg.add_state(), None, 'i', '0', 'i < 10', 'i + 1')

# TODO: Missing exit in stateif s2->e
# print(as_schedule_tree(sdfg).as_string())
stree = as_schedule_tree(sdfg)
node_types = [type(n) for n in stree.preorder_traversal()]
assert node_types.count(tn.GBlock) == 1 # Only one gblock
assert node_types[-1] == tn.ForScope # Check that loop was detected


def test_irreducible_in_loops():
Expand All @@ -176,8 +178,10 @@ def test_irreducible_in_loops():
# Avoiding undefined behavior
sdfg.edges_between(l3, l4)[0].data.condition.as_string = 'i >= 5'

# TODO: gblock must cover the greatest common scope its labels are in.
# print(as_schedule_tree(sdfg).as_string())
stree = as_schedule_tree(sdfg)
node_types = [type(n) for n in stree.preorder_traversal()]
assert node_types.count(tn.GBlock) == 1
assert node_types.count(tn.ForScope) == 2


def test_reference():
Expand Down

0 comments on commit b4a7984

Please sign in to comment.