Skip to content

Commit

Permalink
Switched Structures and structs to OrderedDicts.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexnick83 committed Jul 28, 2023
1 parent a98fce0 commit f431a8d
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 26 deletions.
40 changes: 26 additions & 14 deletions dace/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import ctypes
import functools

from collections import OrderedDict
from numbers import Number
from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple

import numpy
import sympy as sp
Expand Down Expand Up @@ -344,40 +345,47 @@ def add(X: dace.float32[10, 10] @ dace.StorageType.GPU_Global):
def _arrays_to_json(arrays):
if arrays is None:
return None
sorted_keys = sorted(arrays.keys())
return [(k, serialize.to_json(arrays[k])) for k in sorted_keys]
return [(k, serialize.to_json(v)) for k, v in arrays.items()]


def _arrays_from_json(obj, context=None):
if obj is None:
return {}
return {k: serialize.from_json(v, context) for k, v in obj}
return OrderedDict((k, serialize.from_json(v, context)) for k, v in obj)


@make_properties
class Structure(Data):
""" Base class for structures. """

members = Property(dtype=dict,
members = Property(dtype=OrderedDict,
desc="Dictionary of structure members",
from_json=_arrays_from_json,
to_json=_arrays_to_json)
order = ListProperty(element_type=str, desc="Order of structure members")
name = Property(dtype=str, desc="Structure name")

def __init__(self,
members: Dict[str, Data],
order: List[str] = None,
name: str = 'Structure',
transient: bool = False,
storage: dtypes.StorageType = dtypes.StorageType.Default,
location: Dict[str, str] = None,
lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope,
debuginfo: dtypes.DebugInfo = None):

self.order = order or list(members.keys())
if set(members.keys()) != set(self.order):
raise ValueError('Order must contain all members of the structure.')

# TODO: Should we make a deep-copy here?
self.members = members or {}
self.members = OrderedDict((k, members[k]) for k in self.order)

for k, v in self.members.items():
v.transient = transient
self.name = name
fields_and_types = dict()
fields_and_types = OrderedDict()
symbols = set()
for k, v in members.items():
if isinstance(v, Structure):
Expand All @@ -396,13 +404,17 @@ def __init__(self,
fields_and_types[k] = dtypes.typeclass(type(v))
else:
raise TypeError(f"Attribute {k}'s value {v} has unsupported type: {type(v)}")
for s in symbols:
if str(s) in fields_and_types:
continue
if hasattr(s, "dtype"):
fields_and_types[str(s)] = s.dtype
else:
fields_and_types[str(s)] = dtypes.int32

# NOTE: We will not store symbols in the dtype for now, but leaving it as a comment to investigate later.
# NOTE: See discussion about data/object symbols.
# for s in symbols:
# if str(s) in fields_and_types:
# continue
# if hasattr(s, "dtype"):
# fields_and_types[str(s)] = s.dtype
# else:
# fields_and_types[str(s)] = dtypes.int32

dtype = dtypes.pointer(dtypes.struct(name, **fields_and_types))
shape = (1,)
super(Structure, self).__init__(dtype, shape, transient, storage, location, lifetime, debuginfo)
Expand Down
26 changes: 14 additions & 12 deletions dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import itertools
import numpy
import re
from collections import OrderedDict
from functools import wraps
from typing import Any
from dace.config import Config
Expand Down Expand Up @@ -768,12 +769,11 @@ def fields(self):
return self._data

def to_json(self):
sorted_keys = sorted(self._data.keys())
return {
'type': 'struct',
'name': self.name,
'data': [(k, self._data[k].to_json()) for k in sorted_keys],
'length': [(k, self._length[k]) for k in sorted_keys if k in self._length],
'data': [(k, v.to_json()) for k, v in self._data.items()],
'length': [(k, v) for k, v in self._length.items()],
'bytes': self.bytes
}

Expand All @@ -792,19 +792,21 @@ def from_json(json_obj, context=None):
return ret

def _parse_field_and_types(self, **fields_and_types):
from dace.symbolic import pystr_to_symbolic
self._data = dict()
self._length = dict()
# from dace.symbolic import pystr_to_symbolic
self._data = OrderedDict()
self._length = OrderedDict()
self.bytes = 0
for k, v in fields_and_types.items():
if isinstance(v, tuple):
t, l = v
if not isinstance(t, pointer):
raise TypeError("Only pointer types may have a length.")
sym_tokens = pystr_to_symbolic(l).free_symbols
for sym in sym_tokens:
if str(sym) not in fields_and_types.keys():
raise ValueError(f"Symbol {sym} in {k}'s length {l} is not a field of struct {self.name}")
# TODO: Do we need the free symbols of the length in the struct?
# NOTE: It is needed for the old use of dtype.struct. Are we deprecating that?
# sym_tokens = pystr_to_symbolic(l).free_symbols
# for sym in sym_tokens:
# if str(sym) not in fields_and_types.keys():
# raise ValueError(f"Symbol {sym} in {k}'s length {l} is not a field of struct {self.name}")
self._data[k] = t
self._length[k] = l
self.bytes += t.bytes
Expand All @@ -830,7 +832,7 @@ def as_ctypes(self):
fields.append((k, v.as_ctypes()))
else:
fields.append((k, _FFI_CTYPES[v.type]))
fields = sorted(fields, key=lambda f: f[0])
# fields = sorted(fields, key=lambda f: f[0])
# Create new struct class.
struct_class = type("NewStructClass", (ctypes.Structure, ), {"_fields_": fields})
_FFI_CTYPES[self] = struct_class
Expand All @@ -844,7 +846,7 @@ def emit_definition(self):
{typ}
}};""".format(
name=self.name,
typ='\n'.join([" %s %s;" % (t.ctype, tname) for tname, t in sorted(self._data.items())]),
typ='\n'.join([" %s %s;" % (t.ctype, tname) for tname, t in self._data.items()]),
)


Expand Down
8 changes: 8 additions & 0 deletions tests/sdfg/data/structure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def test_read_structure():

M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
order=['indptr', 'indices', 'data'],
name='CSRMatrix')

sdfg = dace.SDFG('csr_to_dense')
Expand Down Expand Up @@ -68,6 +69,7 @@ def test_write_structure():

M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
order=['indptr', 'indices', 'data'],
name='CSRMatrix')

sdfg = dace.SDFG('dense_to_csr')
Expand Down Expand Up @@ -145,8 +147,10 @@ def test_local_structure():

M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
order=['indptr', 'indices', 'data'],
name='CSRMatrix')
tmp_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
order=['indptr', 'indices', 'data'],
name='CSRMatrix',
transient=True)

Expand Down Expand Up @@ -254,6 +258,7 @@ def test_local_structure():
def test_read_nested_structure():
M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
order=['indptr', 'indices', 'data'],
name='CSRMatrix')
wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper')

Expand Down Expand Up @@ -315,6 +320,7 @@ def test_write_nested_structure():

M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
order=['indptr', 'indices', 'data'],
name='CSRMatrix')
wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper')

Expand Down Expand Up @@ -396,6 +402,7 @@ def test_direct_read_structure():

M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
order=['indptr', 'indices', 'data'],
name='CSRMatrix')

sdfg = dace.SDFG('csr_to_dense_direct')
Expand Down Expand Up @@ -446,6 +453,7 @@ def test_direct_read_structure():
def test_direct_read_nested_structure():
M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
order=['indptr', 'indices', 'data'],
name='CSRMatrix')
wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper')

Expand Down

0 comments on commit f431a8d

Please sign in to comment.