From 5bdff31f1593e02e93382b968b7447a62e60d45b Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Tue, 29 Oct 2024 01:07:05 +0100 Subject: [PATCH] Add one more test with a more complicated structure (and to cover a tiny bit more of corner cases). --- tests/transformations/if_extraction_test.py | 98 ++++++++++++++++++++- 1 file changed, 96 insertions(+), 2 deletions(-) diff --git a/tests/transformations/if_extraction_test.py b/tests/transformations/if_extraction_test.py index bc0b6d7452..cb37932849 100644 --- a/tests/transformations/if_extraction_test.py +++ b/tests/transformations/if_extraction_test.py @@ -46,6 +46,65 @@ def make_branched_sdfg_that_does_not_depend_on_loop_var(): return g +def make_branched_sdfg_that_has_intermediate_branchlike_structure(): + """ + Construct an SDFG that has this structure: + initial_state + / \\ + state_1 state_2 + | | + state_3 state_5 + \\ / + state_5 + / \ + state_6 state_7 + \\ / + terminal_state + """ + # First prepare the map-body. + subg = SDFG('body') + subg.add_array('tmp', (1,), dace.float32) + subg.add_symbol('outval', dace.float32) + ifh = subg.add_state('if_head') + if1 = subg.add_state('state_1') + if3 = subg.add_state('state_2') + if2 = subg.add_state('state_3') + if4 = subg.add_state('state_4') + if5 = subg.add_state('state_5') + if6 = subg.add_state('state_6') + if7 = subg.add_state('state_7') + ift = subg.add_state('if_tail') + subg.add_edge(ifh, if1, InterstateEdge(condition='(flag)', assignments={'outval': 1})) + subg.add_edge(ifh, if2, InterstateEdge(condition='(not flag)', assignments={'outval': 2})) + subg.add_edge(if1, if3, InterstateEdge()) + subg.add_edge(if3, if5, InterstateEdge()) + subg.add_edge(if2, if4, InterstateEdge()) + subg.add_edge(if4, if5, InterstateEdge()) + subg.add_edge(if5, if6, InterstateEdge()) + subg.add_edge(if5, if7, InterstateEdge()) + subg.add_edge(if6, ift, InterstateEdge()) + subg.add_edge(if7, ift, InterstateEdge()) + t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval') + tmp = ift.add_access('tmp') + ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]')) + subg.fill_scope_connectors() + + # Then prepare the parent graph. + g = SDFG('prog') + g.add_array('A', (10,), dace.float32) + g.add_symbol('flag', dace.bool) + st0 = g.add_state('outer', is_start_block=True) + en, ex = st0.add_map('map', {'i': '0:10'}) + body = st0.add_nested_sdfg(subg, None, {}, {'tmp'}, symbol_mapping={'flag': 'flag'}) + A = st0.add_access('A') + st0.add_memlet_path(en, body, memlet=Memlet()) + st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]')) + st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]')) + g.fill_scope_connectors() + + return g + + def make_branched_sdfg_that_depends_on_loop_var(): """ Construct a simple SDFG that depends on symbols defined or updated in the outer state, e.g., loop variables. @@ -119,7 +178,42 @@ def test_simple_application(): assert all(np.equal(wantA_2, gotA_2)) -def test_fails_due_to_dependency_on_loop_var(): +def test_extracts_even_with_intermediate_branchlike_structure(): + origA = np.zeros((10,), np.float32) + + g = make_branched_sdfg_that_has_intermediate_branchlike_structure() + g.save(os.path.join('_dacegraphs', 'intermediate_branch-0.sdfg')) + g.validate() + g.compile() + + # Get the expected values. + wantA_1 = deepcopy(origA) + wantA_2 = deepcopy(origA) + g(A=wantA_1, flag=True) + g(A=wantA_2, flag=False) + + # Before, the outer graph had only one nested SDFG. + assert len(g.nodes()) == 1 + + assert g.apply_transformations_repeated([IfExtraction]) == 1 + g.save(os.path.join('_dacegraphs', 'intermediate_branch-1.sdfg')) + + # Get the values from transformed program. + gotA_1 = deepcopy(origA) + gotA_2 = deepcopy(origA) + g(A=gotA_1, flag=True) + g(A=gotA_2, flag=False) + + # But now, the outer graph have four: two copies of the original nested SDFGs and two for branch management. + assert len(g.nodes()) == 4 + assert g.start_state.is_empty() + + # Verify numerically. + assert all(np.equal(wantA_1, gotA_1)) + assert all(np.equal(wantA_2, gotA_2)) + + +def test_no_extraction_due_to_dependency_on_loop_var(): g = make_branched_sdfg_that_depends_on_loop_var() g.save(os.path.join('_dacegraphs', 'dependent-0.sdfg')) @@ -128,4 +222,4 @@ def test_fails_due_to_dependency_on_loop_var(): if __name__ == '__main__': test_simple_application() - test_fails_due_to_dependency_on_loop_var() + test_no_extraction_due_to_dependency_on_loop_var()