From b4a7984ec94ff1b5eb00cb474bc7d1fb766fc425 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 25 Jul 2023 23:07:03 -0700 Subject: [PATCH] Schedule tree: fix tests, print empty scopes in a nicer way --- dace/sdfg/analysis/schedule_tree/treenodes.py | 2 ++ tests/schedule_tree/schedule_test.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index da56dc16aa..b96be06832 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -55,6 +55,8 @@ def __init__(self, child.parent = self def as_string(self, indent: int = 0): + if not self.children: + return (indent + 1) * INDENTATION + 'pass' return '\n'.join([child.as_string(indent + 1) for child in self.children]) def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index 4eabf50497..6d41420856 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -73,7 +73,7 @@ def main(a: dace.float64[20, 10]): nest1(a[10:15]) nest1(a[15:]) - sdfg = main.to_sdfg() + sdfg = main.to_sdfg(simplify=True) stree = as_schedule_tree(sdfg) # Despite two levels of nesting, immediate children are the 4 for loops @@ -150,8 +150,10 @@ def test_irreducible_sub_sdfg(): # Add a loop following general block sdfg.add_loop(e, sdfg.add_state(), None, 'i', '0', 'i < 10', 'i + 1') - # TODO: Missing exit in stateif s2->e - # print(as_schedule_tree(sdfg).as_string()) + stree = as_schedule_tree(sdfg) + node_types = [type(n) for n in stree.preorder_traversal()] + assert node_types.count(tn.GBlock) == 1 # Only one gblock + assert node_types[-1] == tn.ForScope # Check that loop was detected def test_irreducible_in_loops(): @@ -176,8 +178,10 @@ def test_irreducible_in_loops(): # Avoiding undefined behavior sdfg.edges_between(l3, l4)[0].data.condition.as_string = 'i >= 5' - # TODO: gblock must cover the greatest common scope its labels are in. - # print(as_schedule_tree(sdfg).as_string()) + stree = as_schedule_tree(sdfg) + node_types = [type(n) for n in stree.preorder_traversal()] + assert node_types.count(tn.GBlock) == 1 + assert node_types.count(tn.ForScope) == 2 def test_reference():