Skip to content

Commit

Permalink
Remove unused global data descriptor shapes from arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Jul 30, 2023
1 parent 60b4045 commit 0040aad
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 11 deletions.
12 changes: 7 additions & 5 deletions dace/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,10 @@ def as_arg(self, with_types=True, for_call=False, name=None):
def free_symbols(self) -> Set[symbolic.SymbolicType]:
""" Returns a set of undefined symbols in this data descriptor. """
result = set()
for s in self.shape:
if isinstance(s, sp.Basic):
result |= set(s.free_symbols)
if self.transient:
for s in self.shape:
if isinstance(s, sp.Basic):
result |= set(s.free_symbols)
return result

def __repr__(self):
Expand Down Expand Up @@ -695,11 +696,12 @@ def free_symbols(self):
for s in self.strides:
if isinstance(s, sp.Expr):
result |= set(s.free_symbols)
if isinstance(self.total_size, sp.Expr):
result |= set(self.total_size.free_symbols)
for o in self.offset:
if isinstance(o, sp.Expr):
result |= set(o.free_symbols)
if self.transient:
if isinstance(self.total_size, sp.Expr):
result |= set(self.total_size.free_symbols)

return result

Expand Down
21 changes: 15 additions & 6 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def new_symbols(self, sdfg, symbols) -> Dict[str, dtypes.typeclass]:
alltypes = symbols

inferred_lhs_symbols = {k: infer_expr_type(v, alltypes) for k, v in self.assignments.items()}

# Symbols in assignment keys are candidate newly defined symbols
lhs_symbols = set()
# Symbols already defined
Expand All @@ -303,7 +303,7 @@ def new_symbols(self, sdfg, symbols) -> Dict[str, dtypes.typeclass]:
# Only add LHS to the set of candidate newly defined symbols if it has not been defined yet
if lhs not in rhs_symbols:
lhs_symbols.add(lhs)

return {k: v for k, v in inferred_lhs_symbols.items() if k in lhs_symbols}

def get_read_memlets(self, arrays: Dict[str, dt.Data]) -> List[mm.Memlet]:
Expand Down Expand Up @@ -593,6 +593,7 @@ def hash_sdfg(self, jsondict: Optional[Dict[str, Any]] = None) -> str:
:param jsondict: If not None, uses given JSON dictionary as input.
:return: The hash (in SHA-256 format).
"""

def keyword_remover(json_obj: Any, last_keyword=""):
# Makes non-unique in SDFG hierarchy v2
# Recursively remove attributes from the SDFG which are not used in
Expand Down Expand Up @@ -1290,14 +1291,22 @@ def free_symbols(self) -> Set[str]:
defined_syms = set()
free_syms = set()

# Start with the set of SDFG free symbols
free_syms |= set(self.symbols.keys())

# Exclude data descriptor names and constants
# Exclude data descriptor names, constants, and shapes of global data descriptors
not_strictly_necessary_global_symbols = set()
for name, desc in self.arrays.items():
defined_syms.add(name)
if not desc.transient:
if symbolic.issymbolic(desc.total_size):
not_strictly_necessary_global_symbols |= set(map(str, desc.total_size.free_symbols))
for s in desc.shape:
if symbolic.issymbolic(s):
not_strictly_necessary_global_symbols |= set(map(str, s.free_symbols))

defined_syms |= set(self.constants_prop.keys())

# Start with the set of SDFG free symbols
free_syms |= set(s for s in self.symbols.keys() if s not in not_strictly_necessary_global_symbols)

# Add free state symbols
used_before_assignment = set()

Expand Down
54 changes: 54 additions & 0 deletions tests/codegen/symbol_arguments_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.

import dace
import numpy as np

N = dace.symbol('N')


def test_global_sizes():

@dace.program
def tester(A: dace.float64[N]):
for i in dace.map[0:10]:
A[i] = 2

sdfg = tester.to_sdfg()
# Since N is not used anywhere, it should not be listed in the arguments
assert 'N' not in sdfg.arglist()

a = np.random.rand(20)
sdfg(a, N=20)
assert np.allclose(a[:10], 2)


def test_global_sizes_used():

@dace.program
def tester(A: dace.float64[N]):
for i in dace.map[0:10]:
with dace.tasklet:
a >> A[i]
a = N

sdfg = tester.to_sdfg()
# N is used in a tasklet
assert 'N' in sdfg.arglist()


def test_global_sizes_multidim():

@dace.program
def tester(A: dace.float64[N, N]):
for i, j in dace.map[0:10, 0:10]:
A[i, j] = 2

sdfg = tester.to_sdfg()
# Here N is implicitly used in the index expression, so it should be in the arguments
assert 'N' in sdfg.arglist()


if __name__ == '__main__':
test_global_sizes()
test_global_sizes_used()
test_global_sizes_multidim()

0 comments on commit 0040aad

Please sign in to comment.