From f431a8df0c99890d5dbeef48674157aa196d6a3e Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Fri, 28 Jul 2023 10:29:11 +0200 Subject: [PATCH] Switched Structures and structs to OrderedDicts. --- dace/data.py | 40 ++++++++++++++++++++----------- dace/dtypes.py | 26 ++++++++++---------- tests/sdfg/data/structure_test.py | 8 +++++++ 3 files changed, 48 insertions(+), 26 deletions(-) diff --git a/dace/data.py b/dace/data.py index fd7cdaf8e3..b20f9f7db5 100644 --- a/dace/data.py +++ b/dace/data.py @@ -3,8 +3,9 @@ import ctypes import functools +from collections import OrderedDict from numbers import Number -from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple import numpy import sympy as sp @@ -344,40 +345,47 @@ def add(X: dace.float32[10, 10] @ dace.StorageType.GPU_Global): def _arrays_to_json(arrays): if arrays is None: return None - sorted_keys = sorted(arrays.keys()) - return [(k, serialize.to_json(arrays[k])) for k in sorted_keys] + return [(k, serialize.to_json(v)) for k, v in arrays.items()] def _arrays_from_json(obj, context=None): if obj is None: return {} - return {k: serialize.from_json(v, context) for k, v in obj} + return OrderedDict((k, serialize.from_json(v, context)) for k, v in obj) @make_properties class Structure(Data): """ Base class for structures. """ - members = Property(dtype=dict, + members = Property(dtype=OrderedDict, desc="Dictionary of structure members", from_json=_arrays_from_json, to_json=_arrays_to_json) + order = ListProperty(element_type=str, desc="Order of structure members") name = Property(dtype=str, desc="Structure name") def __init__(self, members: Dict[str, Data], + order: List[str] = None, name: str = 'Structure', transient: bool = False, storage: dtypes.StorageType = dtypes.StorageType.Default, location: Dict[str, str] = None, lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, debuginfo: dtypes.DebugInfo = None): + + self.order = order or list(members.keys()) + if set(members.keys()) != set(self.order): + raise ValueError('Order must contain all members of the structure.') + # TODO: Should we make a deep-copy here? - self.members = members or {} + self.members = OrderedDict((k, members[k]) for k in self.order) + for k, v in self.members.items(): v.transient = transient self.name = name - fields_and_types = dict() + fields_and_types = OrderedDict() symbols = set() for k, v in members.items(): if isinstance(v, Structure): @@ -396,13 +404,17 @@ def __init__(self, fields_and_types[k] = dtypes.typeclass(type(v)) else: raise TypeError(f"Attribute {k}'s value {v} has unsupported type: {type(v)}") - for s in symbols: - if str(s) in fields_and_types: - continue - if hasattr(s, "dtype"): - fields_and_types[str(s)] = s.dtype - else: - fields_and_types[str(s)] = dtypes.int32 + + # NOTE: We will not store symbols in the dtype for now, but leaving it as a comment to investigate later. + # NOTE: See discussion about data/object symbols. + # for s in symbols: + # if str(s) in fields_and_types: + # continue + # if hasattr(s, "dtype"): + # fields_and_types[str(s)] = s.dtype + # else: + # fields_and_types[str(s)] = dtypes.int32 + dtype = dtypes.pointer(dtypes.struct(name, **fields_and_types)) shape = (1,) super(Structure, self).__init__(dtype, shape, transient, storage, location, lifetime, debuginfo) diff --git a/dace/dtypes.py b/dace/dtypes.py index 9c483d5df1..678f2f59b0 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -7,6 +7,7 @@ import itertools import numpy import re +from collections import OrderedDict from functools import wraps from typing import Any from dace.config import Config @@ -768,12 +769,11 @@ def fields(self): return self._data def to_json(self): - sorted_keys = sorted(self._data.keys()) return { 'type': 'struct', 'name': self.name, - 'data': [(k, self._data[k].to_json()) for k in sorted_keys], - 'length': [(k, self._length[k]) for k in sorted_keys if k in self._length], + 'data': [(k, v.to_json()) for k, v in self._data.items()], + 'length': [(k, v) for k, v in self._length.items()], 'bytes': self.bytes } @@ -792,19 +792,21 @@ def from_json(json_obj, context=None): return ret def _parse_field_and_types(self, **fields_and_types): - from dace.symbolic import pystr_to_symbolic - self._data = dict() - self._length = dict() + # from dace.symbolic import pystr_to_symbolic + self._data = OrderedDict() + self._length = OrderedDict() self.bytes = 0 for k, v in fields_and_types.items(): if isinstance(v, tuple): t, l = v if not isinstance(t, pointer): raise TypeError("Only pointer types may have a length.") - sym_tokens = pystr_to_symbolic(l).free_symbols - for sym in sym_tokens: - if str(sym) not in fields_and_types.keys(): - raise ValueError(f"Symbol {sym} in {k}'s length {l} is not a field of struct {self.name}") + # TODO: Do we need the free symbols of the length in the struct? + # NOTE: It is needed for the old use of dtype.struct. Are we deprecating that? + # sym_tokens = pystr_to_symbolic(l).free_symbols + # for sym in sym_tokens: + # if str(sym) not in fields_and_types.keys(): + # raise ValueError(f"Symbol {sym} in {k}'s length {l} is not a field of struct {self.name}") self._data[k] = t self._length[k] = l self.bytes += t.bytes @@ -830,7 +832,7 @@ def as_ctypes(self): fields.append((k, v.as_ctypes())) else: fields.append((k, _FFI_CTYPES[v.type])) - fields = sorted(fields, key=lambda f: f[0]) + # fields = sorted(fields, key=lambda f: f[0]) # Create new struct class. struct_class = type("NewStructClass", (ctypes.Structure, ), {"_fields_": fields}) _FFI_CTYPES[self] = struct_class @@ -844,7 +846,7 @@ def emit_definition(self): {typ} }};""".format( name=self.name, - typ='\n'.join([" %s %s;" % (t.ctype, tname) for tname, t in sorted(self._data.items())]), + typ='\n'.join([" %s %s;" % (t.ctype, tname) for tname, t in self._data.items()]), ) diff --git a/tests/sdfg/data/structure_test.py b/tests/sdfg/data/structure_test.py index 02b8f0c174..995aacb2fd 100644 --- a/tests/sdfg/data/structure_test.py +++ b/tests/sdfg/data/structure_test.py @@ -12,6 +12,7 @@ def test_read_structure(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + order=['indptr', 'indices', 'data'], name='CSRMatrix') sdfg = dace.SDFG('csr_to_dense') @@ -68,6 +69,7 @@ def test_write_structure(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + order=['indptr', 'indices', 'data'], name='CSRMatrix') sdfg = dace.SDFG('dense_to_csr') @@ -145,8 +147,10 @@ def test_local_structure(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + order=['indptr', 'indices', 'data'], name='CSRMatrix') tmp_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + order=['indptr', 'indices', 'data'], name='CSRMatrix', transient=True) @@ -254,6 +258,7 @@ def test_local_structure(): def test_read_nested_structure(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + order=['indptr', 'indices', 'data'], name='CSRMatrix') wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper') @@ -315,6 +320,7 @@ def test_write_nested_structure(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + order=['indptr', 'indices', 'data'], name='CSRMatrix') wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper') @@ -396,6 +402,7 @@ def test_direct_read_structure(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + order=['indptr', 'indices', 'data'], name='CSRMatrix') sdfg = dace.SDFG('csr_to_dense_direct') @@ -446,6 +453,7 @@ def test_direct_read_structure(): def test_direct_read_nested_structure(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + order=['indptr', 'indices', 'data'], name='CSRMatrix') wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper')