diff --git a/dace/codegen/compiled_sdfg.py b/dace/codegen/compiled_sdfg.py index d0d29cfa1e..9ee0772eeb 100644 --- a/dace/codegen/compiled_sdfg.py +++ b/dace/codegen/compiled_sdfg.py @@ -452,9 +452,10 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: # GPU scalars are pointers, so this is fine if atype.storage != dtypes.StorageType.GPU_Global: raise TypeError('Passing an array to a scalar (type %s) in argument "%s"' % (atype.dtype.ctype, a)) - elif not isinstance(atype, dt.Array) and not isinstance(atype.dtype, dtypes.callback) and not isinstance( - arg, - (atype.dtype.type, sp.Basic)) and not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype): + elif (not isinstance(atype, (dt.Array, dt.Structure)) and + not isinstance(atype.dtype, dtypes.callback) and + not isinstance(arg, (atype.dtype.type, sp.Basic)) and + not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)): if isinstance(arg, int) and atype.dtype.type == np.int64: pass elif isinstance(arg, float) and atype.dtype.type == np.float64: @@ -472,7 +473,7 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: else: warnings.warn(f'Casting scalar argument "{a}" from {type(arg).__name__} to {atype.dtype.type}') arglist[i] = atype.dtype.type(arg) - elif (isinstance(atype, dt.Array) and isinstance(arg, np.ndarray) + elif (isinstance(atype, dt.Array) and isinstance(arg, np.ndarray) and not isinstance(atype, dt.StructArray) and atype.dtype.as_numpy_dtype() != arg.dtype): # Make exception for vector types if (isinstance(atype.dtype, dtypes.vector) and atype.dtype.vtype.as_numpy_dtype() == arg.dtype): @@ -521,7 +522,7 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: # Construct init args, which only consist of the symbols symbols = self._free_symbols initargs = tuple( - actype(arg) if (not isinstance(arg, ctypes._SimpleCData)) else arg + actype(arg) if not isinstance(arg, ctypes._SimpleCData) else arg for arg, actype, atype, aname in 
callparams if aname in symbols) # Replace arrays with their base host/device pointers @@ -531,7 +532,8 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: try: newargs = tuple( - actype(arg) if (not isinstance(arg, ctypes._SimpleCData)) else arg for arg, actype, atype in newargs) + actype(arg) if not isinstance(arg, (ctypes._SimpleCData)) else arg + for arg, actype, atype in newargs) except TypeError: # Pinpoint bad argument for i, (arg, actype, _) in enumerate(newargs): diff --git a/dace/codegen/dispatcher.py b/dace/codegen/dispatcher.py index be032556a0..359d3a5853 100644 --- a/dace/codegen/dispatcher.py +++ b/dace/codegen/dispatcher.py @@ -505,11 +505,11 @@ def get_copy_dispatcher(self, src_node, dst_node, edge, sdfg, state): dst_is_data = True # Skip copies to/from views where edge matches - if src_is_data and isinstance(src_node.desc(sdfg), dt.View): + if src_is_data and isinstance(src_node.desc(sdfg), (dt.StructureView, dt.View)): e = sdutil.get_view_edge(state, src_node) if e is edge: return None - if dst_is_data and isinstance(dst_node.desc(sdfg), dt.View): + if dst_is_data and isinstance(dst_node.desc(sdfg), (dt.StructureView, dt.View)): e = sdutil.get_view_edge(state, dst_node) if e is edge: return None diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 264311a45c..d3d4f50ccd 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -370,6 +370,10 @@ def make_const(expr: str) -> str: # Register defined variable dispatcher.defined_vars.add(pointer_name, defined_type, typedef, allow_shadowing=True) + # NOTE: `expr` may only be a name or a sequence of names and dots. The latter indicates nested data and structures. + # NOTE: Since structures are implemented as pointers, we replace dots with arrows. 
+ expr = expr.replace('.', '->') + return (typedef + ref, pointer_name, expr) diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 9bca137d51..0464672390 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -55,10 +55,29 @@ def __init__(self, frame_codegen, sdfg): # Keep track of generated NestedSDG, and the name of the assigned function self._generated_nested_sdfg = dict() - # Keeps track of generated connectors, so we know how to access them in - # nested scopes + # NOTE: Multi-nesting with StructArrays must be further investigated. + def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''): + for k, v in struct.members.items(): + if isinstance(v, data.Structure): + _visit_structure(v, args, f'{prefix}.{k}') + elif isinstance(v, data.StructArray): + _visit_structure(v.stype, args, f'{prefix}.{k}') + elif isinstance(v, data.Data): + args[f'{prefix}.{k}'] = v + + # Keeps track of generated connectors, so we know how to access them in nested scopes + arglist = dict(self._frame.arglist) for name, arg_type in self._frame.arglist.items(): - if isinstance(arg_type, data.Scalar): + if isinstance(arg_type, data.Structure): + desc = sdfg.arrays[name] + _visit_structure(arg_type, arglist, name) + elif isinstance(arg_type, data.StructArray): + desc = sdfg.arrays[name] + desc = desc.stype + _visit_structure(desc, arglist, name) + + for name, arg_type in arglist.items(): + if isinstance(arg_type, (data.Scalar, data.Structure)): # GPU global memory is only accessed via pointers # TODO(later): Fix workaround somehow if arg_type.storage is dtypes.StorageType.GPU_Global: @@ -195,9 +214,21 @@ def allocate_view(self, sdfg: SDFG, dfg: SDFGState, state_id: int, node: nodes.A ancestor=0, is_write=is_write) if not declared: - declaration_stream.write(f'{atype} {aname};', sdfg, state_id, node) ctypedef = dtypes.pointer(nodedesc.dtype).ctype self._dispatcher.declared_arrays.add(aname, DefinedType.Pointer, ctypedef) + if 
isinstance(nodedesc, data.StructureView): + for k, v in nodedesc.members.items(): + if isinstance(v, data.Data): + ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype + defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer + self._dispatcher.declared_arrays.add(f"{name}.{k}", defined_type, ctypedef) + self._dispatcher.defined_vars.add(f"{name}.{k}", defined_type, ctypedef) + # TODO: Find a better way to do this (the issue is with pointers of pointers) + if atype.endswith('*'): + atype = atype[:-1] + if value.startswith('&'): + value = value[1:] + declaration_stream.write(f'{atype} {aname};', sdfg, state_id, node) allocation_stream.write(f'{aname} = {value};', sdfg, state_id, node) def allocate_reference(self, sdfg: SDFG, dfg: SDFGState, state_id: int, node: nodes.AccessNode, @@ -268,16 +299,19 @@ def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, d name = node.data alloc_name = cpp.ptr(name, nodedesc, sdfg, self._frame) name = alloc_name + # NOTE: `expr` may only be a name or a sequence of names and dots. The latter indicates nested data and + # NOTE: structures. Since structures are implemented as pointers, we replace dots with arrows. 
+ alloc_name = alloc_name.replace('.', '->') if nodedesc.transient is False: return # Check if array is already allocated - if self._dispatcher.defined_vars.has(alloc_name): + if self._dispatcher.defined_vars.has(name): return # Check if array is already declared - declared = self._dispatcher.declared_arrays.has(alloc_name) + declared = self._dispatcher.declared_arrays.has(name) define_var = self._dispatcher.defined_vars.add if nodedesc.lifetime in (dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External): @@ -290,7 +324,18 @@ def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, d if not isinstance(nodedesc.dtype, dtypes.opaque): arrsize_bytes = arrsize * nodedesc.dtype.bytes - if isinstance(nodedesc, data.View): + if isinstance(nodedesc, data.Structure) and not isinstance(nodedesc, data.StructureView): + declaration_stream.write(f"{nodedesc.ctype} {name} = new {nodedesc.dtype.base_type};\n") + define_var(name, DefinedType.Pointer, nodedesc.ctype) + for k, v in nodedesc.members.items(): + if isinstance(v, data.Data): + ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype + defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer + self._dispatcher.declared_arrays.add(f"{name}.{k}", defined_type, ctypedef) + self.allocate_array(sdfg, dfg, state_id, nodes.AccessNode(f"{name}.{k}"), v, function_stream, + declaration_stream, allocation_stream) + return + if isinstance(nodedesc, (data.StructureView, data.View)): return self.allocate_view(sdfg, dfg, state_id, node, function_stream, declaration_stream, allocation_stream) if isinstance(nodedesc, data.Reference): return self.allocate_reference(sdfg, dfg, state_id, node, function_stream, declaration_stream, @@ -455,7 +500,7 @@ def deallocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, dtypes.AllocationLifetime.External) self._dispatcher.declared_arrays.remove(alloc_name, is_global=is_global) - 
if isinstance(nodedesc, (data.Scalar, data.View, data.Stream, data.Reference)): + if isinstance(nodedesc, (data.Scalar, data.StructureView, data.View, data.Stream, data.Reference)): return elif (nodedesc.storage == dtypes.StorageType.CPU_Heap or (nodedesc.storage == dtypes.StorageType.Register and symbolic.issymbolic(arrsize, sdfg.constants))): @@ -1139,6 +1184,9 @@ def memlet_definition(self, if not types: types = self._dispatcher.defined_vars.get(ptr, is_global=True) var_type, ctypedef = types + # NOTE: `expr` may only be a name or a sequence of names and dots. The latter indicates nested data and + # NOTE: structures. Since structures are implemented as pointers, we replace dots with arrows. + ptr = ptr.replace('.', '->') if fpga.is_fpga_array(desc): decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces") diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 56419b9701..9ee5c2ef17 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -153,15 +153,23 @@ def generate_fileheader(self, sdfg: SDFG, global_stream: CodeIOStream, backend: for _, arrname, arr in sdfg.arrays_recursive(): if arr is not None: datatypes.add(arr.dtype) + + def _emit_definitions(dtype: dtypes.typeclass, wrote_something: bool) -> bool: + if isinstance(dtype, dtypes.pointer): + wrote_something = _emit_definitions(dtype._typeclass, wrote_something) + elif isinstance(dtype, dtypes.struct): + for field in dtype.fields.values(): + wrote_something = _emit_definitions(field, wrote_something) + if hasattr(dtype, 'emit_definition'): + if not wrote_something: + global_stream.write("", sdfg) + global_stream.write(dtype.emit_definition(), sdfg) + return wrote_something # Emit unique definitions wrote_something = False for typ in datatypes: - if hasattr(typ, 'emit_definition'): - if not wrote_something: - global_stream.write("", sdfg) - wrote_something = True - 
global_stream.write(typ.emit_definition(), sdfg) + wrote_something = _emit_definitions(typ, wrote_something) if wrote_something: global_stream.write("", sdfg) @@ -741,7 +749,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): instances = access_instances[sdfg.sdfg_id][name] # A view gets "allocated" everywhere it appears - if isinstance(desc, data.View): + if isinstance(desc, (data.StructureView, data.View)): for s, n in instances: self.to_allocate[s].append((sdfg, s, n, False, True, False)) self.to_allocate[s].append((sdfg, s, n, False, False, True)) diff --git a/dace/data.py b/dace/data.py index d492d06258..3b571e6537 100644 --- a/dace/data.py +++ b/dace/data.py @@ -1,10 +1,11 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import copy as cp import ctypes import functools -import re + +from collections import OrderedDict from numbers import Number -from typing import Any, Dict, Optional, Sequence, Set, Tuple +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union import numpy import sympy as sp @@ -17,8 +18,8 @@ import dace.dtypes as dtypes from dace import serialize, symbolic from dace.codegen import cppunparse -from dace.properties import (CodeProperty, DebugInfoProperty, DictProperty, EnumProperty, ListProperty, Property, - ReferenceProperty, ShapeProperty, SubsetProperty, SymbolicProperty, TypeClassProperty, +from dace.properties import (DebugInfoProperty, DictProperty, EnumProperty, ListProperty, NestedDataClassProperty, + OrderedDictProperty, Property, ShapeProperty, SymbolicProperty, TypeClassProperty, make_properties) @@ -354,6 +355,157 @@ def add(X: dace.float32[10, 10] @ dace.StorageType.GPU_Global): return new_desc +def _arrays_to_json(arrays): + if arrays is None: + return None + return [(k, serialize.to_json(v)) for k, v in arrays.items()] + + +def _arrays_from_json(obj, context=None): + if obj is None: + return {} 
+ return OrderedDict((k, serialize.from_json(v, context)) for k, v in obj) + + +@make_properties +class Structure(Data): + """ Base class for structures. """ + + members = OrderedDictProperty(default=OrderedDict(), + desc="Dictionary of structure members", + from_json=_arrays_from_json, + to_json=_arrays_to_json) + name = Property(dtype=str, desc="Structure type name") + + def __init__(self, + members: Union[Dict[str, Data], List[Tuple[str, Data]]], + name: str = 'Structure', + transient: bool = False, + storage: dtypes.StorageType = dtypes.StorageType.Default, + location: Dict[str, str] = None, + lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, + debuginfo: dtypes.DebugInfo = None): + + self.members = OrderedDict(members) + for k, v in self.members.items(): + v.transient = transient + + self.name = name + fields_and_types = OrderedDict() + symbols = set() + for k, v in self.members.items(): + if isinstance(v, Structure): + symbols |= v.free_symbols + fields_and_types[k] = (v.dtype, str(v.total_size)) + elif isinstance(v, Array): + symbols |= v.free_symbols + fields_and_types[k] = (dtypes.pointer(v.dtype), str(_prod(v.shape))) + elif isinstance(v, Scalar): + symbols |= v.free_symbols + fields_and_types[k] = v.dtype + elif isinstance(v, (sp.Basic, symbolic.SymExpr)): + symbols |= v.free_symbols + fields_and_types[k] = symbolic.symtype(v) + elif isinstance(v, (int, numpy.integer)): + fields_and_types[k] = dtypes.typeclass(type(v)) + else: + raise TypeError(f"Attribute {k}'s value {v} has unsupported type: {type(v)}") + + # NOTE: We will not store symbols in the dtype for now, but leaving it as a comment to investigate later. + # NOTE: See discussion about data/object symbols. 
+ # for s in symbols: + # if str(s) in fields_and_types: + # continue + # if hasattr(s, "dtype"): + # fields_and_types[str(s)] = s.dtype + # else: + # fields_and_types[str(s)] = dtypes.int32 + + dtype = dtypes.pointer(dtypes.struct(name, **fields_and_types)) + shape = (1,) + super(Structure, self).__init__(dtype, shape, transient, storage, location, lifetime, debuginfo) + + @staticmethod + def from_json(json_obj, context=None): + if json_obj['type'] != 'Structure': + raise TypeError("Invalid data type") + + # Create dummy object + ret = Structure({}) + serialize.set_properties_from_json(ret, json_obj, context=context) + + return ret + + @property + def total_size(self): + return -1 + + @property + def offset(self): + return [0] + + @property + def start_offset(self): + return 0 + + @property + def strides(self): + return [1] + + @property + def free_symbols(self) -> Set[symbolic.SymbolicType]: + """ Returns a set of undefined symbols in this data descriptor. """ + result = set() + for k, v in self.members.items(): + result |= v.free_symbols + return result + + def __repr__(self): + return f"{self.name} ({', '.join([f'{k}: {v}' for k, v in self.members.items()])})" + + def as_arg(self, with_types=True, for_call=False, name=None): + if self.storage is dtypes.StorageType.GPU_Global: + return Array(self.dtype, [1]).as_arg(with_types, for_call, name) + if not with_types or for_call: + return name + return self.dtype.as_arg(name) + + def __getitem__(self, s): + """ This is syntactic sugar that allows us to define an array type + with the following syntax: ``Structure[N,M]`` + :return: A ``data.StructArray`` data descriptor. + """ + if isinstance(s, list) or isinstance(s, tuple): + return StructArray(self, tuple(s)) + return StructArray(self, (s, )) + + +@make_properties +class StructureView(Structure): + """ + Data descriptor that acts as a reference (or view) of another structure. 
+ """ + + @staticmethod + def from_json(json_obj, context=None): + if json_obj['type'] != 'StructureView': + raise TypeError("Invalid data type") + + # Create dummy object + ret = StructureView({}) + serialize.set_properties_from_json(ret, json_obj, context=context) + + return ret + + def validate(self): + super().validate() + + # We ensure that allocation lifetime is always set to Scope, since the + # view is generated upon "allocation" + if self.lifetime != dtypes.AllocationLifetime.Scope: + raise ValueError('Only Scope allocation lifetime is supported for Views') + + @make_properties class Scalar(Data): """ Data descriptor of a scalar value. """ @@ -920,6 +1072,56 @@ def free_symbols(self): return self.used_symbols(all_symbols=True) +@make_properties +class StructArray(Array): + """ Array of Structures. """ + + stype = NestedDataClassProperty(allow_none=True, default=None) + + def __init__(self, + stype: Structure, + shape, + transient=False, + allow_conflicts=False, + storage=dtypes.StorageType.Default, + location=None, + strides=None, + offset=None, + may_alias=False, + lifetime=dtypes.AllocationLifetime.Scope, + alignment=0, + debuginfo=None, + total_size=-1, + start_offset=None, + optional=None, + pool=False): + + self.stype = stype + if stype: + dtype = stype.dtype + else: + dtype = dtypes.int8 + super(StructArray, self).__init__(dtype, shape, transient, allow_conflicts, storage, location, strides, offset, + may_alias, lifetime, alignment, debuginfo, total_size, start_offset, optional, pool) + + @classmethod + def from_json(cls, json_obj, context=None): + # Create dummy object + ret = cls(None, ()) + serialize.set_properties_from_json(ret, json_obj, context=context) + + # Default shape-related properties + if not ret.offset: + ret.offset = [0] * len(ret.shape) + if not ret.strides: + # Default strides are C-ordered + ret.strides = [_prod(ret.shape[i + 1:]) for i in range(len(ret.shape))] + if ret.total_size == 0: + ret.total_size = _prod(ret.shape) + + 
return ret + + @make_properties class View(Array): """ diff --git a/dace/dtypes.py b/dace/dtypes.py index cbbc4125c1..f0bac23958 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ A module that contains various DaCe type definitions. """ from __future__ import print_function import ctypes @@ -7,6 +7,7 @@ import itertools import numpy import re +from collections import OrderedDict from functools import wraps from typing import Any from dace.config import Config @@ -657,6 +658,8 @@ def from_json(json_obj, context=None): def as_ctypes(self): """ Returns the ctypes version of the typeclass. """ + if isinstance(self._typeclass, struct): + return ctypes.POINTER(self._typeclass.as_ctypes()) return ctypes.POINTER(_FFI_CTYPES[self.type]) def as_numpy_dtype(self): @@ -772,10 +775,8 @@ def to_json(self): return { 'type': 'struct', 'name': self.name, - 'data': {k: v.to_json() - for k, v in self._data.items()}, - 'length': {k: v - for k, v in self._length.items()}, + 'data': [(k, v.to_json()) for k, v in self._data.items()], + 'length': [(k, v) for k, v in self._length.items()], 'bytes': self.bytes } @@ -787,23 +788,28 @@ def from_json(json_obj, context=None): import dace.serialize # Avoid import loop ret = struct(json_obj['name']) - ret._data = {k: json_to_typeclass(v, context) for k, v in json_obj['data'].items()} - ret._length = {k: v for k, v in json_obj['length'].items()} + ret._data = {k: json_to_typeclass(v, context) for k, v in json_obj['data']} + ret._length = {k: v for k, v in json_obj['length']} ret.bytes = json_obj['bytes'] return ret def _parse_field_and_types(self, **fields_and_types): - self._data = dict() - self._length = dict() + # from dace.symbolic import pystr_to_symbolic + self._data = OrderedDict() + self._length = OrderedDict() self.bytes = 0 for k, v in fields_and_types.items(): if 
isinstance(v, tuple): t, l = v if not isinstance(t, pointer): raise TypeError("Only pointer types may have a length.") - if l not in fields_and_types.keys(): - raise ValueError("Length {} not a field of struct {}".format(l, self.name)) + # TODO: Do we need the free symbols of the length in the struct? + # NOTE: It is needed for the old use of dtype.struct. Are we deprecating that? + # sym_tokens = pystr_to_symbolic(l).free_symbols + # for sym in sym_tokens: + # if str(sym) not in fields_and_types.keys(): + # raise ValueError(f"Symbol {sym} in {k}'s length {l} is not a field of struct {self.name}") self._data[k] = t self._length[k] = l self.bytes += t.bytes @@ -815,16 +821,24 @@ def _parse_field_and_types(self, **fields_and_types): def as_ctypes(self): """ Returns the ctypes version of the typeclass. """ + if self in _FFI_CTYPES: + return _FFI_CTYPES[self] # Populate the ctype fields for the struct class. fields = [] for k, v in self._data.items(): if isinstance(v, pointer): - fields.append((k, ctypes.c_void_p)) # ctypes.POINTER(_FFI_CTYPES[v.type]))) + if isinstance(v._typeclass, struct): + fields.append((k, ctypes.POINTER(v._typeclass.as_ctypes()))) + else: + fields.append((k, ctypes.c_void_p)) + elif isinstance(v, struct): + fields.append((k, v.as_ctypes())) else: fields.append((k, _FFI_CTYPES[v.type])) - fields = sorted(fields, key=lambda f: f[0]) # Create new struct class. struct_class = type("NewStructClass", (ctypes.Structure, ), {"_fields_": fields}) + # NOTE: Each call to `type` returns a different class, so we need to cache it to ensure uniqueness. 
+ _FFI_CTYPES[self] = struct_class return struct_class def as_numpy_dtype(self): @@ -835,7 +849,7 @@ def emit_definition(self): {typ} }};""".format( name=self.name, - typ='\n'.join([" %s %s;" % (t.ctype, tname) for tname, t in sorted(self._data.items())]), + typ='\n'.join([" %s %s;" % (t.ctype, tname) for tname, t in self._data.items()]), ) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 009f45ca10..9643d51c1f 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -986,7 +986,7 @@ def _argminmax(pv: ProgramVisitor, reduced_shape = list(copy.deepcopy(a_arr.shape)) reduced_shape.pop(axis) - val_and_idx = dace.struct('_val_and_idx', val=a_arr.dtype, idx=result_type) + val_and_idx = dace.struct('_val_and_idx', idx=result_type, val=a_arr.dtype) # HACK: since identity cannot be specified for structs, we have to init the output array reduced_structs, reduced_struct_arr = sdfg.add_temp_transient(reduced_shape, val_and_idx) diff --git a/dace/properties.py b/dace/properties.py index 951a0564cc..61e569341f 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
import ast from collections import OrderedDict import copy @@ -145,11 +145,15 @@ def fs(obj, *args, **kwargs): self._from_json = lambda *args, **kwargs: dace.serialize.from_json(*args, known_type=dtype, **kwargs) else: self._from_json = from_json + if self.from_json != from_json: + self.from_json = from_json if to_json is None: self._to_json = dace.serialize.to_json else: self._to_json = to_json + if self.to_json != to_json: + self.to_json = to_json if meta_to_json is None: @@ -412,8 +416,7 @@ def initialize_properties(obj, *args, **kwargs): except AttributeError: if not prop.unmapped: raise PropertyError("Property {} is unassigned in __init__ for {}".format(name, cls.__name__)) - # Assert that there are no fields in the object not captured by - # properties, unless they are prefixed with "_" + # Assert that there are no fields in the object not captured by properties, unless they are prefixed with "_" for name, prop in obj.__dict__.items(): if (name not in properties and not name.startswith("_") and name not in dir(type(obj))): raise PropertyError("{} : Variable {} is neither a Property nor " @@ -1385,6 +1388,47 @@ def from_json(obj, context=None): raise TypeError("Cannot parse type from: {}".format(obj)) +class NestedDataClassProperty(Property): + """ Custom property type for nested data. 
""" + + def __get__(self, obj, objtype=None) -> 'Data': + return super().__get__(obj, objtype) + + @property + def dtype(self): + from dace import data as dt + return dt.Data + + @staticmethod + def from_string(s): + from dace import data as dt + dtype = getattr(dt, s, None) + if dtype is None or not isinstance(dtype, dt.Data): + raise ValueError("Not a valid data type: {}".format(s)) + return dtype + + @staticmethod + def to_string(obj): + return obj.to_string() + + def to_json(self, obj): + if obj is None: + return None + return obj.to_json() + + @staticmethod + def from_json(obj, context=None): + if obj is None: + return None + elif isinstance(obj, str): + return NestedDataClassProperty.from_string(obj) + elif isinstance(obj, dict): + # Let the deserializer handle this + return dace.serialize.from_json(obj) + else: + raise TypeError("Cannot parse type from: {}".format(obj)) + + class LibraryImplementationProperty(Property): """ Property for choosing an implementation type for a library node. On the diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index f3a37ef08c..a23d2616f9 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -48,6 +48,41 @@ from dace.codegen.compiled_sdfg import CompiledSDFG +class NestedDict(dict): + + def __init__(self, mapping=None): + mapping = mapping or {} + super(NestedDict, self).__init__(mapping) + + def __getitem__(self, key): + tokens = key.split('.') if isinstance(key, str) else [key] + token = tokens.pop(0) + result = super(NestedDict, self).__getitem__(token) + while tokens: + token = tokens.pop(0) + result = result.members[token] + return result + + def __setitem__(self, key, val): + if isinstance(key, str) and '.' 
in key: + raise KeyError('NestedDict does not support setting nested keys') + super(NestedDict, self).__setitem__(key, val) + + def __contains__(self, key): + tokens = key.split('.') if isinstance(key, str) else [key] + token = tokens.pop(0) + result = super(NestedDict, self).__contains__(token) + desc = None + while tokens and result: + if desc is None: + desc = super(NestedDict, self).__getitem__(token) + else: + desc = desc.members[token] + token = tokens.pop(0) + result = token in desc.members + return result + + def _arrays_to_json(arrays): if arrays is None: return None @@ -60,6 +95,12 @@ def _arrays_from_json(obj, context=None): return {k: dace.serialize.from_json(v, context) for k, v in obj.items()} +def _nested_arrays_from_json(obj, context=None): + if obj is None: + return NestedDict({}) + return NestedDict({k: dace.serialize.from_json(v, context) for k, v in obj.items()}) + + def _replace_dict_keys(d, old, new): if old in d: if new in d: @@ -379,10 +420,10 @@ class SDFG(OrderedDiGraph[SDFGState, InterstateEdge]): name = Property(dtype=str, desc="Name of the SDFG") arg_names = ListProperty(element_type=str, desc='Ordered argument names (used for calling conventions).') constants_prop = Property(dtype=dict, default={}, desc="Compile-time constants") - _arrays = Property(dtype=dict, + _arrays = Property(dtype=NestedDict, desc="Data descriptors for this SDFG", to_json=_arrays_to_json, - from_json=_arrays_from_json) + from_json=_nested_arrays_from_json) symbols = DictProperty(str, dtypes.typeclass, desc="Global symbols for this SDFG") instrument = EnumProperty(dtype=dtypes.InstrumentationType, @@ -460,7 +501,7 @@ def __init__(self, self._sdfg_list = [self] self._start_state: Optional[int] = None self._cached_start_state: Optional[SDFGState] = None - self._arrays = {} # type: Dict[str, dt.Array] + self._arrays = NestedDict() # type: Dict[str, dt.Array] self._labels: Set[str] = set() self.global_code = {'frame': CodeBlock("", dtypes.Language.CPP)} 
self.init_code = {'frame': CodeBlock("", dtypes.Language.CPP)} @@ -1994,10 +2035,17 @@ def add_datadesc(self, name: str, datadesc: dt.Data, find_new_name=False) -> str raise NameError(f'Array or Stream with name "{name}" already exists in SDFG') self._arrays[name] = datadesc + def _add_symbols(desc: dt.Data): + if isinstance(desc, dt.Structure): + for v in desc.members.values(): + if isinstance(v, dt.Data): + _add_symbols(v) + for sym in desc.free_symbols: + if sym.name not in self.symbols: + self.add_symbol(sym.name, sym.dtype) + # Add free symbols to the SDFG global symbol storage - for sym in datadesc.free_symbols: - if sym.name not in self.symbols: - self.add_symbol(sym.name, sym.dtype) + _add_symbols(datadesc) return name diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index d08518b10c..3396335ece 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1396,7 +1396,7 @@ def is_nonfree_sym_dependent(node: nd.AccessNode, desc: dt.Data, state: SDFGStat :param state: the state that contains the node :param fsymbols: the free symbols to check against """ - if isinstance(desc, dt.View): + if isinstance(desc, (dt.StructureView, dt.View)): # Views can be non-free symbol dependent due to the adjacent edges. 
e = get_view_edge(state, node) if e.data: diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index aa7674ca45..0bb3e9a64e 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -604,9 +604,14 @@ def validate_state(state: 'dace.sdfg.SDFGState', break # Check if memlet data matches src or dst nodes - if (e.data.data is not None and (isinstance(src_node, nd.AccessNode) or isinstance(dst_node, nd.AccessNode)) - and (not isinstance(src_node, nd.AccessNode) or e.data.data != src_node.data) - and (not isinstance(dst_node, nd.AccessNode) or e.data.data != dst_node.data)): + name = e.data.data + if isinstance(src_node, nd.AccessNode) and isinstance(sdfg.arrays[src_node.data], dt.Structure): + name = None + if isinstance(dst_node, nd.AccessNode) and isinstance(sdfg.arrays[dst_node.data], dt.Structure): + name = None + if (name is not None and (isinstance(src_node, nd.AccessNode) or isinstance(dst_node, nd.AccessNode)) + and (not isinstance(src_node, nd.AccessNode) or (name != src_node.data and name != e.src_conn)) + and (not isinstance(dst_node, nd.AccessNode) or (name != dst_node.data and name != e.dst_conn))): raise InvalidSDFGEdgeError( "Memlet data does not match source or destination " "data nodes)", diff --git a/tests/sdfg/data/struct_array_test.py b/tests/sdfg/data/struct_array_test.py new file mode 100644 index 0000000000..8e0f2f4739 --- /dev/null +++ b/tests/sdfg/data/struct_array_test.py @@ -0,0 +1,183 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
import ctypes
import dace
import numpy as np

from scipy import sparse


def _add_member_views(sdfg, csr_desc):
    """Register the views ``vindptr``, ``vindices`` and ``vdata`` on ``sdfg``.

    Each view takes its shape and dtype from the member of the same name
    (without the leading ``v``) in ``csr_desc``.
    """
    for member in ('indptr', 'indices', 'data'):
        member_desc = csr_desc.members[member]
        sdfg.add_view('v' + member, member_desc.shape, member_desc.dtype)


def _csr_to_ctypes(csr_desc, matrix):
    """Build the ctypes struct of ``csr_desc`` pointing at the buffers of ``matrix``.

    The struct stores raw addresses only, so ``matrix`` (and the returned object)
    must be kept alive by the caller for as long as the struct is in use.
    """
    struct_cls = csr_desc.dtype._typeclass.as_ctypes()
    return struct_cls(indptr=matrix.indptr.__array_interface__['data'][0],
                      indices=matrix.indices.__array_interface__['data'][0],
                      data=matrix.data.__array_interface__['data'][0])


def test_read_struct_array():
    """Expand an array of CSR structures into a dense 3D array and compare against SciPy."""

    sym_l, sym_m, sym_n, sym_nnz = (dace.symbol(name) for name in ('L', 'M', 'N', 'nnz'))
    # The structure is built from a dict and the view from a list of pairs.
    csr_desc = dace.data.Structure(
        {
            'indptr': dace.int32[sym_m + 1],
            'indices': dace.int32[sym_nnz],
            'data': dace.float32[sym_nnz],
        },
        name='CSRMatrix')
    csr_view_desc = dace.data.StructureView([('indptr', dace.int32[sym_m + 1]), ('indices', dace.int32[sym_nnz]),
                                             ('data', dace.float32[sym_nnz])],
                                            name='CSRMatrix',
                                            transient=True)

    sdfg = dace.SDFG('array_of_csr_to_dense')
    sdfg.add_datadesc('A', csr_desc[sym_l])
    sdfg.add_array('B', [sym_l, sym_m, sym_n], dace.float32)
    sdfg.add_datadesc('vcsr', csr_view_desc)
    _add_member_views(sdfg, csr_desc)

    state = sdfg.add_state()
    struct_array_node = state.add_access('A')
    dense_node = state.add_access('B')

    # Outer (sequential) map over the structures in the array.
    batch_entry, batch_exit = state.add_map('b', {'b': '0:L'})
    batch_entry.map.schedule = dace.ScheduleType.Sequential

    struct_view = state.add_access('vcsr')
    member_views = {member: state.add_access('v' + member) for member in ('indptr', 'indices', 'data')}

    # A[b] -> structure view -> one view per member.
    state.add_memlet_path(struct_array_node,
                          batch_entry,
                          struct_view,
                          dst_conn='views',
                          memlet=dace.Memlet(data='A', subset='b'))
    for member, view_node in member_views.items():
        state.add_edge(struct_view,
                       None,
                       view_node,
                       'views',
                       memlet=dace.Memlet.from_array('vcsr.' + member, csr_desc.members[member]))

    row_entry, row_exit = state.add_map('i', {'i': '0:M'})
    nz_entry, nz_exit = state.add_map('idx', {'idx': 'start:stop'})
    for connector in ('start', 'stop'):
        nz_entry.add_in_connector(connector)
    scatter = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val')

    state.add_memlet_path(member_views['indptr'],
                          row_entry,
                          nz_entry,
                          memlet=dace.Memlet(data='vindptr', subset='i'),
                          dst_conn='start')
    state.add_memlet_path(member_views['indptr'],
                          row_entry,
                          nz_entry,
                          memlet=dace.Memlet(data='vindptr', subset='i+1'),
                          dst_conn='stop')
    state.add_memlet_path(member_views['indices'],
                          row_entry,
                          nz_entry,
                          scatter,
                          memlet=dace.Memlet(data='vindices', subset='idx'),
                          dst_conn='j')
    state.add_memlet_path(member_views['data'],
                          row_entry,
                          nz_entry,
                          scatter,
                          memlet=dace.Memlet(data='vdata', subset='idx'),
                          dst_conn='__val')
    state.add_memlet_path(scatter,
                          nz_exit,
                          row_exit,
                          batch_exit,
                          dense_node,
                          memlet=dace.Memlet(data='B', subset='b, 0:M, 0:N', volume=1),
                          src_conn='__out')

    func = sdfg.compile()

    rng = np.random.default_rng(42)
    matrices = np.ndarray((10,), dtype=sparse.csr_matrix)
    struct_addresses = np.ndarray((10,), dtype=ctypes.c_void_p)
    dense = np.zeros((10, 20, 20), dtype=np.float32)

    keep_alive = []
    for b in range(10):
        matrices[b] = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng)
        struct_obj = _csr_to_ctypes(csr_desc, matrices[b])
        keep_alive.append(struct_obj)  # This is needed to keep the object alive ...
        struct_addresses[b] = ctypes.addressof(struct_obj)

    # All symbol values are taken from the first matrix.
    func(A=struct_addresses,
         B=dense,
         L=matrices.shape[0],
         M=matrices[0].shape[0],
         N=matrices[0].shape[1],
         nnz=matrices[0].nnz)

    expected = np.ndarray((10, 20, 20), dtype=np.float32)
    for b in range(10):
        expected[b] = matrices[b].toarray()

    assert np.allclose(dense, expected)


def test_write_struct_array():
    """Compress a dense 3D array into an array of CSR structures and compare against SciPy."""

    sym_l, sym_m, sym_n, sym_nnz = (dace.symbol(name) for name in ('L', 'M', 'N', 'nnz'))
    # Mirror image of the read test: structure from a list of pairs, view from a dict.
    csr_desc = dace.data.Structure([('indptr', dace.int32[sym_m + 1]), ('indices', dace.int32[sym_nnz]),
                                    ('data', dace.float32[sym_nnz])],
                                   name='CSRMatrix')
    csr_view_desc = dace.data.StructureView(
        {
            'indptr': dace.int32[sym_m + 1],
            'indices': dace.int32[sym_nnz],
            'data': dace.float32[sym_nnz],
        },
        name='CSRMatrix',
        transient=True)

    sdfg = dace.SDFG('array_dense_to_csr')
    sdfg.add_array('A', [sym_l, sym_m, sym_n], dace.float32)
    sdfg.add_datadesc('B', csr_desc[sym_l])
    sdfg.add_datadesc('vcsr', csr_view_desc)
    _add_member_views(sdfg, csr_desc)

    # Branch: only non-zero entries of A are stored, and `idx` advances after each store.
    if_before, if_guard, if_body, if_after = (sdfg.add_state(label)
                                              for label in ('if_before', 'if_guard', 'if_body', 'if_after'))
    sdfg.add_edge(if_before, if_guard, dace.InterstateEdge())
    sdfg.add_edge(if_guard, if_body, dace.InterstateEdge(condition='A[k, i, j] != 0'))
    sdfg.add_edge(if_body, if_after, dace.InterstateEdge(assignments={'idx': 'idx + 1'}))
    sdfg.add_edge(if_guard, if_after, dace.InterstateEdge(condition='A[k, i, j] == 0'))

    dense_node = if_body.add_access('A')
    struct_view = if_body.add_access('vcsr')
    struct_array_node = if_body.add_access('B')
    indices_view = if_body.add_access('vindices')
    data_view = if_body.add_access('vdata')
    if_body.add_edge(dense_node, None, data_view, None,
                     dace.Memlet(data='A', subset='k, i, j', other_subset='idx'))
    if_body.add_edge(data_view, 'views', struct_view, None, dace.Memlet(data='vcsr.data', subset='0:nnz'))
    column_writer = if_body.add_tasklet('set_indices', {}, {'__out'}, '__out = j')
    if_body.add_edge(column_writer, '__out', indices_view, None, dace.Memlet(data='vindices', subset='idx'))
    if_body.add_edge(indices_view, 'views', struct_view, None, dace.Memlet(data='vcsr.indices', subset='0:nnz'))
    if_body.add_edge(struct_view, 'views', struct_array_node, None, dace.Memlet(data='B', subset='k'))

    # Column loop (j) wraps the branch; row loop (i) wraps the column loop.
    j_before, j_guard, j_after = sdfg.add_loop(None,
                                               if_before,
                                               None,
                                               'j',
                                               '0',
                                               'j < N',
                                               'j + 1',
                                               loop_end_state=if_after)
    i_before, i_guard, i_after = sdfg.add_loop(None,
                                               j_before,
                                               None,
                                               'i',
                                               '0',
                                               'i < M',
                                               'i + 1',
                                               loop_end_state=j_after)
    sdfg.start_state = sdfg.node_id(i_before)
    init_edge = sdfg.edges_between(i_before, i_guard)[0]
    init_edge.data.assignments['idx'] = '0'

    # indptr[i] = idx is written in the row-loop guard, indptr[M] = nnz after the loop.
    for loop_state, indptr_value, indptr_index in ((i_guard, 'idx', 'i'), (i_after, 'nnz', 'M')):
        struct_view = loop_state.add_access('vcsr')
        struct_array_node = loop_state.add_access('B')
        indptr_view = loop_state.add_access('vindptr')
        indptr_writer = loop_state.add_tasklet('set_indptr', {}, {'__out'}, '__out = ' + indptr_value)
        loop_state.add_edge(indptr_writer, '__out', indptr_view, None,
                            dace.Memlet(data='vindptr', subset=indptr_index))
        loop_state.add_edge(indptr_view, 'views', struct_view, None,
                            dace.Memlet(data='vcsr.indptr', subset='0:M+1'))
        loop_state.add_edge(struct_view, 'views', struct_array_node, None, dace.Memlet(data='B', subset='k'))

    # Outermost loop (k) over the structures in the array.
    k_before, k_guard, k_after = sdfg.add_loop(None,
                                               i_before,
                                               None,
                                               'k',
                                               '0',
                                               'k < L',
                                               'k + 1',
                                               loop_end_state=i_after)

    func = sdfg.compile()

    rng = np.random.default_rng(42)
    matrices = np.ndarray((10,), dtype=sparse.csr_matrix)
    struct_addresses = np.ndarray((10,), dtype=ctypes.c_void_p)
    dense = np.empty((10, 20, 20), dtype=np.float32)

    keep_alive = []
    for b in range(10):
        matrices[b] = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng)
        dense[b] = matrices[b].toarray()
        stored_entries = matrices[b].nnz
        # Wipe the CSR buffers so that the check below only passes if the SDFG refilled them.
        matrices[b].indptr[:] = -1
        matrices[b].indices[:] = -1
        matrices[b].data[:] = -1
        struct_obj = _csr_to_ctypes(csr_desc, matrices[b])
        keep_alive.append(struct_obj)  # This is needed to keep the object alive ...
        struct_addresses[b] = ctypes.addressof(struct_obj)

    # `nnz` is the count of the last generated matrix, as in the loop above.
    func(A=dense,
         B=struct_addresses,
         L=matrices.shape[0],
         M=matrices[0].shape[0],
         N=matrices[0].shape[1],
         nnz=stored_entries)
    for b in range(10):
        assert np.allclose(dense[b], matrices[b].toarray())


if __name__ == '__main__':
    test_read_struct_array()
    test_write_struct_array()

# diff --git a/tests/sdfg/data/structure_test.py b/tests/sdfg/data/structure_test.py
# new file mode 100644
# index 0000000000..02b8f0c174
# --- /dev/null
# +++ b/tests/sdfg/data/structure_test.py
# @@ -0,0 +1,507 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import ctypes

import dace
import numpy as np
import pytest

# NOTE(review): `pytest`, `serialize` and `make_properties` are not referenced by the tests below;
# they are kept because removing imports is outside the scope of this change.
from dace import serialize
from dace.properties import make_properties
from scipy import sparse

# Members of the CSR structure, in the order in which views and access nodes are created.
_CSR_MEMBERS = ('indptr', 'indices', 'data')


def _make_csr_descriptor():
    """Create the symbols ``M``, ``N``, ``nnz`` and the ``CSRMatrix`` structure descriptor.

    :return: The tuple ``(M, N, nnz, csr_obj)``.
    """
    M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
    csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
                                  name='CSRMatrix')
    return M, N, nnz, csr_obj


def _add_member_views(sdfg, csr_desc, prefix=''):
    """Register the views ``<prefix>vindptr``, ``<prefix>vindices`` and ``<prefix>vdata`` on ``sdfg``.

    :param sdfg: The SDFG to add the views to.
    :param csr_desc: Structure descriptor whose members provide shape and dtype of the views.
    :param prefix: String prepended to the view names.
    """
    for member in _CSR_MEMBERS:
        member_desc = csr_desc.members[member]
        sdfg.add_view(f'{prefix}v{member}', member_desc.shape, member_desc.dtype)


def _add_csr_to_dense_dataflow(state, indptr, indices, data, dense):
    """Add the maps and tasklet that scatter CSR entries into a dense matrix.

    For every row ``i`` the inner map iterates over ``indptr[i]:indptr[i+1]`` and the tasklet
    writes ``data[idx]`` to position ``[i, indices[idx]]`` of ``dense``.

    :param state: The state to add the dataflow to.
    :param indptr: Access node providing the row pointers.
    :param indices: Access node providing the column indices.
    :param data: Access node providing the values.
    :param dense: Access node of the dense output matrix.
    """
    ime, imx = state.add_map('i', dict(i='0:M'))
    jme, jmx = state.add_map('idx', dict(idx='start:stop'))
    jme.add_in_connector('start')
    jme.add_in_connector('stop')
    t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val')

    # The memlets are named after the data of the access node they originate from.
    state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data=indptr.data, subset='i'), dst_conn='start')
    state.add_memlet_path(indptr, ime, jme, memlet=dace.Memlet(data=indptr.data, subset='i+1'), dst_conn='stop')
    state.add_memlet_path(indices, ime, jme, t, memlet=dace.Memlet(data=indices.data, subset='idx'), dst_conn='j')
    state.add_memlet_path(data, ime, jme, t, memlet=dace.Memlet(data=data.data, subset='idx'), dst_conn='__val')
    state.add_memlet_path(t,
                          jmx,
                          imx,
                          dense,
                          memlet=dace.Memlet(data=dense.data, subset='0:M, 0:N', volume=1),
                          src_conn='__out')


def _add_dense_to_csr_states(sdfg, target, member_path, view_prefix=''):
    """Add the state machine that compresses the dense array ``A`` into a CSR structure.

    The generated control flow is a loop over ``i`` around a loop over ``j`` around a branch
    on ``A[i, j] != 0``. The counter ``idx`` starts at 0 and is incremented after every stored
    entry. ``indptr[i]`` is written in the guard of the row loop and ``indptr[M]`` after it.

    :param sdfg: The SDFG to add the states to. It must already contain ``A``, ``target`` and
                 the member views ``<view_prefix>vindptr/vindices/vdata``.
    :param target: Name of the structure container that the views are connected to.
    :param member_path: Dotted path prefix of the CSR members, e.g. ``'B'`` or ``'B.csr'``.
    :param view_prefix: Prefix of the member view names.
    :return: The state following the row loop (its third `add_loop` return value).
    """
    vindptr, vindices, vdata = (f'{view_prefix}v{member}' for member in _CSR_MEMBERS)

    # Make If
    if_before = sdfg.add_state('if_before')
    if_guard = sdfg.add_state('if_guard')
    if_body = sdfg.add_state('if_body')
    if_after = sdfg.add_state('if_after')
    sdfg.add_edge(if_before, if_guard, dace.InterstateEdge())
    sdfg.add_edge(if_guard, if_body, dace.InterstateEdge(condition='A[i, j] != 0'))
    sdfg.add_edge(if_body, if_after, dace.InterstateEdge(assignments={'idx': 'idx + 1'}))
    sdfg.add_edge(if_guard, if_after, dace.InterstateEdge(condition='A[i, j] == 0'))
    A = if_body.add_access('A')
    struct = if_body.add_access(target)
    indices = if_body.add_access(vindices)
    data = if_body.add_access(vdata)
    if_body.add_edge(A, None, data, None, dace.Memlet(data='A', subset='i, j', other_subset='idx'))
    if_body.add_edge(data, 'views', struct, None, dace.Memlet(data=f'{member_path}.data', subset='0:nnz'))
    t = if_body.add_tasklet('set_indices', {}, {'__out'}, '__out = j')
    if_body.add_edge(t, '__out', indices, None, dace.Memlet(data=vindices, subset='idx'))
    if_body.add_edge(indices, 'views', struct, None, dace.Memlet(data=f'{member_path}.indices', subset='0:nnz'))

    # Make For Loop for j
    j_before, j_guard, j_after = sdfg.add_loop(None,
                                               if_before,
                                               None,
                                               'j',
                                               '0',
                                               'j < N',
                                               'j + 1',
                                               loop_end_state=if_after)
    # Make For Loop for i
    i_before, i_guard, i_after = sdfg.add_loop(None, j_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=j_after)
    sdfg.start_state = sdfg.node_id(i_before)
    i_before_guard = sdfg.edges_between(i_before, i_guard)[0]
    i_before_guard.data.assignments['idx'] = '0'

    # indptr[i] = idx (loop guard) and indptr[M] = nnz (after the loop).
    for loop_state, value, index in ((i_guard, 'idx', 'i'), (i_after, 'nnz', 'M')):
        struct = loop_state.add_access(target)
        indptr = loop_state.add_access(vindptr)
        t = loop_state.add_tasklet('set_indptr', {}, {'__out'}, f'__out = {value}')
        loop_state.add_edge(t, '__out', indptr, None, dace.Memlet(data=vindptr, subset=index))
        loop_state.add_edge(indptr, 'views', struct, None, dace.Memlet(data=f'{member_path}.indptr', subset='0:M+1'))

    return i_after


def _random_csr():
    """Return a reproducible (seed 42) random 20x20 float32 CSR matrix with density 0.1."""
    rng = np.random.default_rng(42)
    return sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng)


def _blanked_copy(matrix):
    """Return a CSR copy of ``matrix`` whose three buffers are overwritten with -1.

    The tests use it as an output argument: comparing it with the reference afterwards only
    succeeds if the SDFG has rewritten the buffers.
    """
    blank = matrix.tocsr(copy=True)
    blank.indptr[:] = -1
    blank.indices[:] = -1
    blank.data[:] = -1
    return blank


def _csr_to_ctypes(csr_desc, matrix, **extra_fields):
    """Build the ctypes struct of ``csr_desc`` pointing at the buffers of ``matrix``.

    The struct stores raw addresses only, so the caller must keep ``matrix`` alive while the
    struct is in use.

    :param csr_desc: Structure descriptor of the CSR matrix.
    :param matrix: SciPy CSR matrix providing the ``indptr``, ``indices`` and ``data`` buffers.
    :param extra_fields: Additional keyword arguments forwarded to the struct constructor.
    :return: The ctypes structure instance.
    """
    struct_cls = csr_desc.dtype._typeclass.as_ctypes()
    return struct_cls(indptr=matrix.indptr.__array_interface__['data'][0],
                      indices=matrix.indices.__array_interface__['data'][0],
                      data=matrix.data.__array_interface__['data'][0],
                      **extra_fields)


def _wrap_csr(wrapper_desc, csr_struct):
    """Build the ctypes struct of ``wrapper_desc`` whose ``csr`` field points to ``csr_struct``."""
    return wrapper_desc.dtype._typeclass.as_ctypes()(csr=ctypes.pointer(csr_struct))


def test_read_structure():
    """Read a CSR structure through member views and expand it into a dense matrix."""

    M, N, _, csr_obj = _make_csr_descriptor()

    sdfg = dace.SDFG('csr_to_dense')

    sdfg.add_datadesc('A', csr_obj)
    sdfg.add_array('B', [M, N], dace.float32)
    _add_member_views(sdfg, csr_obj)

    state = sdfg.add_state()

    A = state.add_access('A')
    B = state.add_access('B')

    indptr = state.add_access('vindptr')
    indices = state.add_access('vindices')
    data = state.add_access('vdata')

    for member, view in zip(_CSR_MEMBERS, (indptr, indices, data)):
        state.add_edge(A, None, view, 'views', dace.Memlet.from_array(f'A.{member}', csr_obj.members[member]))

    _add_csr_to_dense_dataflow(state, indptr, indices, data, B)

    func = sdfg.compile()

    A = _random_csr()
    B = np.zeros((20, 20), dtype=np.float32)

    inpA = _csr_to_ctypes(csr_obj, A)

    func(A=inpA, B=B, M=A.shape[0], N=A.shape[1], nnz=A.nnz)
    ref = A.toarray()

    assert np.allclose(B, ref)


def test_write_structure():
    """Compress a dense matrix into a CSR structure through member views."""

    M, N, _, csr_obj = _make_csr_descriptor()

    sdfg = dace.SDFG('dense_to_csr')

    sdfg.add_array('A', [M, N], dace.float32)
    sdfg.add_datadesc('B', csr_obj)
    _add_member_views(sdfg, csr_obj)

    _add_dense_to_csr_states(sdfg, target='B', member_path='B')

    func = sdfg.compile()

    tmp = _random_csr()
    A = tmp.toarray()
    B = _blanked_copy(tmp)

    outB = _csr_to_ctypes(csr_obj, B)

    func(A=A, B=outB, M=tmp.shape[0], N=tmp.shape[1], nnz=tmp.nnz)

    assert np.allclose(A, B.toarray())


def test_local_structure():
    """Compress into a transient CSR structure, then copy it to the output with doubled values."""

    M, N, nnz, csr_obj = _make_csr_descriptor()
    tmp_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
                                  name='CSRMatrix',
                                  transient=True)

    sdfg = dace.SDFG('dense_to_csr_local')

    sdfg.add_array('A', [M, N], dace.float32)
    sdfg.add_datadesc('B', csr_obj)
    sdfg.add_datadesc('tmp', tmp_obj)

    _add_member_views(sdfg, csr_obj)
    _add_member_views(sdfg, tmp_obj, prefix='tmp_')

    i_after = _add_dense_to_csr_states(sdfg, target='tmp', member_path='tmp', view_prefix='tmp_')

    # Copy `tmp` to `B`: indptr and indices verbatim, data multiplied by two.
    set_B = sdfg.add_state('set_B')
    sdfg.add_edge(i_after, set_B, dace.InterstateEdge())
    tmp = set_B.add_access('tmp')
    tmp_indptr = set_B.add_access('tmp_vindptr')
    tmp_indices = set_B.add_access('tmp_vindices')
    tmp_data = set_B.add_access('tmp_vdata')
    set_B.add_edge(tmp, None, tmp_indptr, 'views', dace.Memlet(data='tmp.indptr', subset='0:M+1'))
    set_B.add_edge(tmp, None, tmp_indices, 'views', dace.Memlet(data='tmp.indices', subset='0:nnz'))
    set_B.add_edge(tmp, None, tmp_data, 'views', dace.Memlet(data='tmp.data', subset='0:nnz'))
    B = set_B.add_access('B')
    B_indptr = set_B.add_access('vindptr')
    B_indices = set_B.add_access('vindices')
    B_data = set_B.add_access('vdata')
    set_B.add_edge(B_indptr, 'views', B, None, dace.Memlet(data='B.indptr', subset='0:M+1'))
    set_B.add_edge(B_indices, 'views', B, None, dace.Memlet(data='B.indices', subset='0:nnz'))
    set_B.add_edge(B_data, 'views', B, None, dace.Memlet(data='B.data', subset='0:nnz'))
    set_B.add_edge(tmp_indptr, None, B_indptr, None, dace.Memlet(data='tmp_vindptr', subset='0:M+1'))
    set_B.add_edge(tmp_indices, None, B_indices, None, dace.Memlet(data='tmp_vindices', subset='0:nnz'))
    t, me, mx = set_B.add_mapped_tasklet('set_data', {'idx': '0:nnz'},
                                         {'__inp': dace.Memlet(data='tmp_vdata', subset='idx')},
                                         '__out = 2 * __inp', {'__out': dace.Memlet(data='vdata', subset='idx')},
                                         external_edges=True,
                                         input_nodes={'tmp_vdata': tmp_data},
                                         output_nodes={'vdata': B_data})

    func = sdfg.compile()

    tmp = _random_csr()
    A = tmp.toarray()
    B = _blanked_copy(tmp)

    outB = _csr_to_ctypes(csr_obj, B)

    func(A=A, B=outB, M=tmp.shape[0], N=tmp.shape[1], nnz=tmp.nnz)

    assert np.allclose(A * 2, B.toarray())


def test_read_nested_structure():
    """Read a CSR structure nested inside a wrapper structure through member views."""

    M, N, _, csr_obj = _make_csr_descriptor()
    wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper')

    sdfg = dace.SDFG('nested_csr_to_dense')

    sdfg.add_datadesc('A', wrapper_obj)
    sdfg.add_array('B', [M, N], dace.float32)

    spmat = wrapper_obj.members['csr']
    _add_member_views(sdfg, spmat)

    state = sdfg.add_state()

    A = state.add_access('A')
    B = state.add_access('B')

    indptr = state.add_access('vindptr')
    indices = state.add_access('vindices')
    data = state.add_access('vdata')

    for member, view in zip(_CSR_MEMBERS, (indptr, indices, data)):
        state.add_edge(A, None, view, 'views', dace.Memlet.from_array(f'A.csr.{member}', spmat.members[member]))

    _add_csr_to_dense_dataflow(state, indptr, indices, data, B)

    func = sdfg.compile()

    A = _random_csr()
    B = np.zeros((20, 20), dtype=np.float32)

    inpCSR = _csr_to_ctypes(csr_obj, A)
    inpW = _wrap_csr(wrapper_obj, inpCSR)

    func(A=inpW, B=B, M=A.shape[0], N=A.shape[1], nnz=A.nnz)
    ref = A.toarray()

    assert np.allclose(B, ref)


def test_write_nested_structure():
    """Compress a dense matrix into a CSR structure nested inside a wrapper structure."""

    M, N, _, csr_obj = _make_csr_descriptor()
    wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper')

    # The name differs from the one used in `test_write_structure` ('dense_to_csr') so that the
    # two tests do not produce SDFGs (and build artifacts) with the same name.
    sdfg = dace.SDFG('nested_dense_to_csr')

    sdfg.add_array('A', [M, N], dace.float32)
    sdfg.add_datadesc('B', wrapper_obj)

    spmat = wrapper_obj.members['csr']
    _add_member_views(sdfg, spmat)

    _add_dense_to_csr_states(sdfg, target='B', member_path='B.csr')

    func = sdfg.compile()

    tmp = _random_csr()
    A = tmp.toarray()
    B = _blanked_copy(tmp)

    outCSR = _csr_to_ctypes(csr_obj, B)
    outW = _wrap_csr(wrapper_obj, outCSR)

    func(A=A, B=outW, M=tmp.shape[0], N=tmp.shape[1], nnz=tmp.nnz)

    assert np.allclose(A, B.toarray())


def test_direct_read_structure():
    """Read the members of a CSR structure directly (``A.indptr`` etc.) without views."""

    M, N, _, csr_obj = _make_csr_descriptor()

    sdfg = dace.SDFG('csr_to_dense_direct')

    sdfg.add_datadesc('A', csr_obj)
    sdfg.add_array('B', [M, N], dace.float32)

    state = sdfg.add_state()

    indptr = state.add_access('A.indptr')
    indices = state.add_access('A.indices')
    data = state.add_access('A.data')
    B = state.add_access('B')

    _add_csr_to_dense_dataflow(state, indptr, indices, data, B)

    func = sdfg.compile()

    A = _random_csr()
    B = np.zeros((20, 20), dtype=np.float32)

    # NOTE(review): the structure is declared with only `indptr`, `indices` and `data`; whether
    # the generated ctypes struct also has the fields below is not visible here. They are passed
    # exactly as before - confirm against `Structure`/`as_ctypes` before removing them.
    inpA = _csr_to_ctypes(csr_obj,
                          A,
                          rows=A.shape[0],
                          cols=A.shape[1],
                          M=A.shape[0],
                          N=A.shape[1],
                          nnz=A.nnz)

    func(A=inpA, B=B, M=20, N=20, nnz=A.nnz)
    ref = A.toarray()

    assert np.allclose(B, ref)


def test_direct_read_nested_structure():
    """Read the members of a nested CSR structure directly (``A.csr.indptr`` etc.)."""

    M, N, _, csr_obj = _make_csr_descriptor()
    wrapper_obj = dace.data.Structure(dict(csr=csr_obj), name='Wrapper')

    sdfg = dace.SDFG('nested_csr_to_dense_direct')

    sdfg.add_datadesc('A', wrapper_obj)
    sdfg.add_array('B', [M, N], dace.float32)

    spmat = wrapper_obj.members['csr']
    # NOTE(review): these views are registered but no access node in this test refers to them;
    # presumably a leftover from the view-based variant - confirm before removing.
    _add_member_views(sdfg, spmat)

    state = sdfg.add_state()

    indptr = state.add_access('A.csr.indptr')
    indices = state.add_access('A.csr.indices')
    data = state.add_access('A.csr.data')
    B = state.add_access('B')

    _add_csr_to_dense_dataflow(state, indptr, indices, data, B)

    func = sdfg.compile()

    A = _random_csr()
    B = np.zeros((20, 20), dtype=np.float32)

    inpCSR = _csr_to_ctypes(csr_obj, A)
    inpW = _wrap_csr(wrapper_obj, inpCSR)

    func(A=inpW, B=B, M=A.shape[0], N=A.shape[1], nnz=A.nnz)
    ref = A.toarray()

    assert np.allclose(B, ref)


if __name__ == "__main__":
    test_read_structure()
    test_write_structure()
    test_local_structure()
    test_read_nested_structure()
    test_write_nested_structure()
    test_direct_read_structure()
    test_direct_read_nested_structure()