Skip to content

Commit

Permalink
Replace add_memlet_path() with add_edge().
Browse files Browse the repository at this point in the history
  • Loading branch information
pratyai committed Oct 29, 2024
1 parent e7800f7 commit eb6b838
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions tests/transformations/if_extraction_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import os
from copy import deepcopy
from typing import Dict, Collection

import numpy as np

import dace
from dace import SDFG, InterstateEdge, Memlet
from dace import SDFG, InterstateEdge, Memlet, SDFGState
from dace.transformation.interstate import IfExtraction


def _add_map_with_connectors(st: SDFGState, name: str, ndrange: Dict[str, str],
en_conn_bases: Collection[str] = None, ex_conn_bases: Collection[str] = None):
en, ex = st.add_map(name, ndrange)
if en_conn_bases:
for c in en_conn_bases:
en.add_in_connector(f"IN_{c}")
en.add_out_connector(f"OUT_{c}")
if ex_conn_bases:
for c in ex_conn_bases:
ex.add_in_connector(f"IN_{c}")
ex.add_out_connector(f"OUT_{c}")
return en, ex


def make_branched_sdfg_that_does_not_depend_on_loop_var():
"""
Construct a simple SDFG that does not depend on symbols defined or updated in the outer state, e.g., loop variables.
Expand All @@ -27,20 +42,20 @@ def make_branched_sdfg_that_does_not_depend_on_loop_var():
subg.add_edge(if2, ift, InterstateEdge())
t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval')
tmp = ift.add_access('tmp')
ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]'))
ift.add_edge(t0, '__out', tmp, None, Memlet(expr='tmp[0]'))
subg.fill_scope_connectors()

# Then prepare the parent graph.
g = SDFG('prog')
g.add_array('A', (10,), dace.float32)
g.add_symbol('flag', dace.bool)
st0 = g.add_state('outer', is_start_block=True)
en, ex = st0.add_map('map', {'i': '0:10'})
en, ex = _add_map_with_connectors(st0, 'map', {'i': '0:10'}, [], ['A'])
body = st0.add_nested_sdfg(subg, None, {}, {'tmp'}, symbol_mapping={'flag': 'flag'})
A = st0.add_access('A')
st0.add_memlet_path(en, body, memlet=Memlet())
st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]'))
st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]'))
st0.add_nedge(en, body, Memlet())
st0.add_edge(body, 'tmp', ex, 'IN_A', Memlet(expr='A[i]'))
st0.add_edge(ex, 'OUT_A', A, None, Memlet(expr='A[0:10]'))
g.fill_scope_connectors()

return g
Expand Down Expand Up @@ -86,20 +101,20 @@ def make_branched_sdfg_that_has_intermediate_branchlike_structure():
subg.add_edge(if7, ift, InterstateEdge())
t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval')
tmp = ift.add_access('tmp')
ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]'))
ift.add_edge(t0, '__out', tmp, None, Memlet(expr='tmp[0]'))
subg.fill_scope_connectors()

# Then prepare the parent graph.
g = SDFG('prog')
g.add_array('A', (10,), dace.float32)
g.add_symbol('flag', dace.bool)
st0 = g.add_state('outer', is_start_block=True)
en, ex = st0.add_map('map', {'i': '0:10'})
en, ex = _add_map_with_connectors(st0, 'map', {'i': '0:10'}, [], ['A'])
body = st0.add_nested_sdfg(subg, None, {}, {'tmp'}, symbol_mapping={'flag': 'flag'})
A = st0.add_access('A')
st0.add_memlet_path(en, body, memlet=Memlet())
st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]'))
st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]'))
st0.add_nedge(en, body, Memlet())
st0.add_edge(body, 'tmp', ex, 'IN_A', Memlet(expr='A[i]'))
st0.add_edge(ex, 'OUT_A', A, None, Memlet(expr='A[0:10]'))
g.fill_scope_connectors()

return g
Expand All @@ -123,19 +138,19 @@ def make_branched_sdfg_that_depends_on_loop_var():
subg.add_edge(if2, ift, InterstateEdge())
t0 = ift.add_tasklet('copy', inputs={}, outputs={'__out'}, code='__out = outval')
tmp = ift.add_access('tmp')
ift.add_memlet_path(t0, tmp, src_conn='__out', memlet=Memlet(expr='tmp[0]'))
ift.add_edge(t0, '__out', tmp, None, Memlet(expr='tmp[0]'))
subg.fill_scope_connectors()

# Then prepare the parent graph.
g = SDFG('prog')
g.add_array('A', (10,), dace.float32)
st0 = g.add_state('outer', is_start_block=True)
en, ex = st0.add_map('map', {'i': '0:10'})
en, ex = _add_map_with_connectors(st0, 'map', {'i': '0:10'}, [], ['A'])
body = st0.add_nested_sdfg(subg, None, {}, {'tmp'})
A = st0.add_access('A')
st0.add_memlet_path(en, body, memlet=Memlet())
st0.add_memlet_path(body, ex, src_conn='tmp', dst_conn='IN_A', memlet=Memlet(expr='A[i]'))
st0.add_memlet_path(ex, A, src_conn='OUT_A', memlet=Memlet(expr='A[0:10]'))
st0.add_nedge(en, body, Memlet())
st0.add_edge(body, 'tmp', ex, 'IN_A', Memlet(expr='A[i]'))
st0.add_edge(ex, 'OUT_A', A, None, Memlet(expr='A[0:10]'))
g.fill_scope_connectors()

return g
Expand Down

0 comments on commit eb6b838

Please sign in to comment.