Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Container Groups #1719

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
[submodule "dace/external/hlslib"]
path = dace/external/hlslib
url = https://github.com/definelicht/hlslib.git
[submodule "dace/viewer/webclient"]
path = dace/viewer/webclient
url = https://github.com/spcl/dace-webclient.git
[submodule "dace/external/rtllib"]
path = dace/external/rtllib
url = https://github.com/carljohnsen/rtllib.git
142 changes: 142 additions & 0 deletions dace/sdfg/container_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from collections import OrderedDict
from typing import Set, Union
from dace import data
from dace.data import Data
from dace import serialize, symbolic
from dace.properties import OrderedDictProperty, Property, make_properties
from enum import Enum

import numpy
import sympy


class ContainerGroupFlatteningMode(Enum):
    # Strategy used when flattening a hierarchy of containers into plain arrays:
    # ArrayOfStructs interleaves all members into a single array;
    # StructsOfArrays produces one separate array per leaf member.
    ArrayOfStructs = 1
    StructsOfArrays = 2


def _members_to_json(members):
if members is None:
return None
return [(k, serialize.to_json(v)) for k, v in members.items()]


def _members_from_json(obj, context=None):
if obj is None:
return {}
return OrderedDict((k, serialize.from_json(v, context)) for k, v in obj)


@make_properties
class ContainerGroup:
    """A named, ordered collection of data containers and nested groups.

    Mirrors the hierarchy of a ``data.Structure`` descriptor so that the
    hierarchy can later be flattened (see ``ContainerGroupFlatteningMode``).
    """

    name = Property(dtype=str, default="", allow_none=False)
    members = OrderedDictProperty(
        default=OrderedDict(),
        desc="Dictionary of structure members",
        from_json=_members_from_json,
        to_json=_members_to_json,
    )

    def __init__(self, name):
        self.name = name
        self.members = OrderedDict()
        self._validate()

    def add_member(self, name: str, member: Union[Data, "ContainerGroup"]):
        """Add a member under ``name``.

        :param name: Member name. If ``None`` or empty, an automatic name is
            generated from the current member count.
        :param member: A data descriptor or a nested ``ContainerGroup``.
        """
        if name is None or name == "":
            # Fix: use a string fallback name so all keys are uniformly typed
            # (int keys would break serialization and dotted-name lookup).
            name = str(len(self.members))
        self.members[name] = member

    @property
    def free_symbols(self) -> Set[symbolic.SymbolicType]:
        """Returns a set of undefined symbols in this data descriptor."""
        result = set()
        # Nested ContainerGroups expose the same property, so this recurses
        # through the whole hierarchy.
        for v in self.members.values():
            result |= v.free_symbols
        return result

    def __call__(self):
        # Descriptors are callable for interface-compatibility with dace.data
        # descriptors; a group simply yields itself.
        return self

    def validate(self):
        self._validate()

    def _validate(self):
        # No structural invariants yet; placeholder for future checks.
        return True

    def to_json(self):
        attrs = serialize.all_properties_to_json(self)
        retdict = {"type": type(self).__name__, "attributes": attrs}
        return retdict

    def is_equivalent(self, other):
        raise NotImplementedError

    def __eq__(self, other):
        # Structural equality via the serialized form; consistent with
        # __hash__ below.
        return serialize.dumps(self) == serialize.dumps(other)

    def __hash__(self):
        return hash(serialize.dumps(self))

    def __repr__(self):
        members_repr = ", ".join(
            f"{k}: {v.__class__.__name__}" for k, v in self.members.items()
        )
        return f"DataGroup(name='{self.name}', members={{ {members_repr} }})"

    def __str__(self):
        return self.__repr__()

    def _add_members(self, name, structure, acc_shape=None):
        """Recursively add the members of ``structure`` to this group.

        Structures become nested ``ContainerGroup`` members; arrays and
        scalars are added directly; symbolic/integer members are wrapped in a
        ``data.Scalar``. ``acc_shape`` is reserved for accumulating outer
        shapes during flattening (currently unused -- TODO confirm intended
        use before relying on it).
        """
        for member_name, member in structure.members.items():
            if isinstance(member, data.Structure):
                # Recursively convert nested Structures. from_struct already
                # walks the nested members, so a single call suffices (the
                # previous duplicated recursive call passed invalid keyword
                # arguments and would have raised a TypeError).
                self.add_member(
                    name=f"{member_name}",
                    member=ContainerGroup.from_struct(name=f"{member_name}", structure=member),
                )
            elif isinstance(member, (data.Array, data.Scalar)):
                # add_member takes no 'shape' argument; shapes stay on the
                # descriptor itself.
                self.add_member(member_name, member)
            elif isinstance(
                member, (sympy.Basic, symbolic.SymExpr, int, numpy.integer)
            ):
                # Convert other types to Scalar
                self.add_member(member_name, data.Scalar(symbolic.symtype(member)))
            else:
                raise TypeError(f"Unsupported member type in Structure: {type(member)}")

    def _soa_from_struct(self, name, structure, acc_shape):
        # Structs-of-arrays conversion entry point; no accumulated shape yet.
        self._add_members(name, structure, acc_shape=None)

    @classmethod
    def from_struct(
        cls,
        name: str,
        structure: data.Structure
    ) -> "ContainerGroup":
        """Build a ContainerGroup mirroring the hierarchy of ``structure``.

        :param name: Name of the new group.
        :param structure: The structure descriptor to convert.
        :return: A new ``ContainerGroup`` with one member per structure member.
        :raises TypeError: On unsupported member types.
        """
        dg = cls(name)

        for member_name, member in structure.members.items():
            if isinstance(member, data.Structure):
                # Recursively convert nested Structures
                dg.add_member(
                    name=f"{member_name}",
                    member=cls.from_struct(name=f"{member_name}", structure=member),
                )
            elif isinstance(member, (data.Array, data.Scalar)):
                # Directly add Arrays and Scalars
                dg.add_member(member_name, member)
            elif isinstance(
                member, (sympy.Basic, symbolic.SymExpr, int, numpy.integer)
            ):
                # Convert other types to Scalar
                dg.add_member(member_name, data.Scalar(symbolic.symtype(member)))
            else:
                raise TypeError(f"Unsupported member type in Structure: {type(member)}")

        return dg
131 changes: 125 additions & 6 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from dace.codegen.instrumentation.data.data_report import InstrumentedDataReport
from dace.codegen.compiled_sdfg import CompiledSDFG

from dace.sdfg.container_group import ContainerGroup, ContainerGroupFlatteningMode

class NestedDict(dict):

Expand All @@ -50,8 +51,11 @@ def __getitem__(self, key):
tokens = key.split('.') if isinstance(key, str) else [key]
token = tokens.pop(0)
result = super(NestedDict, self).__getitem__(token)

while tokens:
token = tokens.pop(0)
if isinstance(result, dt.ContainerArray):
result = result.stype
result = result.members[token]
return result

Expand Down Expand Up @@ -407,6 +411,10 @@ class SDFG(ControlFlowRegion):
desc="Data descriptors for this SDFG",
to_json=_arrays_to_json,
from_json=_nested_arrays_from_json)
container_groups = Property(dtype=NestedDict,
desc="Data group descriptors for this SDFG",
to_json=_arrays_to_json,
from_json=_nested_arrays_from_json)
symbols = DictProperty(str, dtypes.typeclass, desc="Global symbols for this SDFG")

instrument = EnumProperty(dtype=dtypes.InstrumentationType,
Expand All @@ -428,6 +436,7 @@ class SDFG(ControlFlowRegion):

debuginfo = DebugInfoProperty(allow_none=True)


_pgrids = DictProperty(str,
ProcessGrid,
desc="Process-grid descriptors for this SDFG",
Expand Down Expand Up @@ -485,6 +494,7 @@ def __init__(self,
self._parent_sdfg = None
self._parent_nsdfg_node = None
self._arrays = NestedDict() # type: Dict[str, dt.Array]
self.container_groups = NestedDict()
self.arg_names = []
self._labels: Set[str] = set()
self.global_code = {'frame': CodeBlock("", dtypes.Language.CPP)}
Expand Down Expand Up @@ -1032,7 +1042,7 @@ def clear_data_reports(self):

def call_with_instrumented_data(self, dreport: 'InstrumentedDataReport', *args, **kwargs):
"""
Invokes an SDFG with an instrumented data report, generating and compiling code if necessary.
Invokes an SDFG with an instrumented data report, generating and compiling code if necessary.
Arguments given as ``args`` and ``kwargs`` will be overriden by the data containers defined in the report.

:param dreport: The instrumented data report to use upon calling.
Expand Down Expand Up @@ -1280,11 +1290,16 @@ def _used_symbols_internal(self,

defined_syms |= set(self.constants_prop.keys())

init_code_symbols=set()
exit_code_symbols=set()
# Add used symbols from init and exit code
for code in self.init_code.values():
free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys())
init_code_symbols |= symbolic.symbols_in_code(code.as_string, self.symbols.keys())
for code in self.exit_code.values():
free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys())
exit_code_symbols |= symbolic.symbols_in_code(code.as_string, self.symbols.keys())

#free_syms|=set(filter(lambda x: not str(x).startswith('__f2dace_ARRAY'),init_code_symbols))
#free_syms|=set(filter(lambda x: not str(x).startswith('__f2dace_ARRAY'),exit_code_symbols))

return super()._used_symbols_internal(all_symbols=all_symbols,
keep_defined_in_mapping=keep_defined_in_mapping,
Expand Down Expand Up @@ -1364,7 +1379,9 @@ def arglist(self, scalars_only=False, free_symbols=None) -> Dict[str, dt.Data]:
}

# Add global free symbols used in the generated code to scalar arguments
#TODO LATER investigate why all_symbols=False leads to bug
free_symbols = free_symbols if free_symbols is not None else self.used_symbols(all_symbols=False)
free_symbols = set(filter(lambda x: not str(x).startswith('__f2dace_STRUCTARRAY'), free_symbols))
scalar_args.update({k: dt.Scalar(self.symbols[k]) for k in free_symbols if not k.startswith('__dace')})

# Fill up ordered dictionary
Expand Down Expand Up @@ -1725,6 +1742,12 @@ def add_array(self,

return self.add_datadesc(name, desc, find_new_name=find_new_name), desc

def add_container_group(self,
                        name: str,
                        find_new_name: bool = False) -> Tuple[str, ContainerGroup]:
    """Creates an empty ContainerGroup and registers it on this SDFG.

    :param name: Desired name of the new container group.
    :param find_new_name: If True, resolves name clashes automatically.
    :return: A tuple of (registered name, the new descriptor).
    """
    descriptor = ContainerGroup(name)
    registered_name = self.add_container_group_desc(name, descriptor, find_new_name=find_new_name)
    return registered_name, descriptor

def add_view(self,
name: str,
shape,
Expand Down Expand Up @@ -1865,7 +1888,7 @@ def add_scalar(self,
storage=storage,
transient=transient,
lifetime=lifetime,
debuginfo=debuginfo,
debuginfo=debuginfo
)

return self.add_datadesc(name, desc, find_new_name=find_new_name), desc
Expand Down Expand Up @@ -2011,6 +2034,42 @@ def _add_symbols(sdfg: SDFG, desc: dt.Data):

return name

def add_container_group_desc(self, name: str, container_group_desc: ContainerGroup, find_new_name=False) -> str:
    """Registers an existing ContainerGroup descriptor on this SDFG.

    :param name: Name to register the descriptor under (dots are replaced by
        underscores when ``find_new_name`` is set).
    :param container_group_desc: The descriptor to register.
    :param find_new_name: If True, resolves name clashes automatically;
        otherwise any clash raises.
    :return: The name under which the descriptor was registered.
    :raises TypeError: If ``name`` is not a string.
    :raises FileExistsError: If the name clashes with an existing object.
    """
    if not isinstance(name, str):
        raise TypeError("Data descriptor name must be a string. Got %s" % type(name).__name__)

    if find_new_name:
        name = self._find_new_name(name)
        name = name.replace('.', '_')
        if self.is_name_used(name):
            name = self._find_new_name(name)
    else:
        # Check all SDFG namespaces that could shadow or be shadowed by
        # this name.
        if name in self.container_groups:
            raise FileExistsError(f'Data group descriptor "{name}" already exists in SDFG')
        if name in self.arrays:
            raise FileExistsError(f'Can not create data group descriptor "{name}", the name is used by a data descriptor.')
        if name in self.symbols:
            raise FileExistsError(f'Can not create data group descriptor "{name}", the name is used by a symbol.')
        if name in self._subarrays:
            raise FileExistsError(f'Can not create data group descriptor "{name}", the name is used by a subarray.')
        if name in self._rdistrarrays:
            raise FileExistsError(f'Can not create data group descriptor "{name}", the name is used by a RedistrArray.')
        if name in self._pgrids:
            raise FileExistsError(f'Can not create data group descriptor "{name}", the name is used by a ProcessGrid.')

    def _add_symbols(sdfg: SDFG, desc: dt.Data):
        # Recursively register any free symbols of the descriptor (and, for
        # Structures, of its members) as SDFG symbols.
        if isinstance(desc, dt.Structure):
            for v in desc.members.values():
                if isinstance(v, dt.Data):
                    _add_symbols(sdfg, v)
        for sym in desc.free_symbols:
            if sym.name not in sdfg.symbols:
                sdfg.add_symbol(sym.name, sym.dtype)

    # Add the data descriptor to the SDFG and all symbols that are not yet known.
    self.container_groups[name] = container_group_desc
    _add_symbols(self, container_group_desc)

    return name

def add_datadesc_view(self, name: str, datadesc: dt.Data, find_new_name=False) -> str:
""" Adds a view of a given data descriptor to the SDFG array store.

Expand Down Expand Up @@ -2604,7 +2663,7 @@ def apply_transformations_once_everywhere(self,
print_report: Optional[bool] = None,
order_by_transformation: bool = True,
progress: Optional[bool] = None) -> int:
"""
"""
This function applies a transformation or a set of (unique) transformations
until throughout the entire SDFG once. Operates in-place.

Expand Down Expand Up @@ -2720,7 +2779,7 @@ def expand_library_nodes(self, recursive=True):

def generate_code(self):
""" Generates code from this SDFG and returns it.

:return: A list of `CodeObject` objects containing the generated
code of different files and languages.
"""
Expand Down Expand Up @@ -2759,3 +2818,63 @@ def recheck_using_experimental_blocks(self) -> bool:
break
self.root_sdfg.using_experimental_blocks = found_experimental_block
return found_experimental_block

def register_container_group_members(self, flattening_mode):
    """Registers a data descriptor on this SDFG for every leaf member of
    every container group, using the given flattening mode.

    :param flattening_mode: A ``ContainerGroupFlatteningMode`` value.
    """
    for dg in self.container_groups.values():
        # Fix: _register_container_group_members requires acc_shape; the
        # previous call omitted it and raised a TypeError. It starts empty
        # and is only relevant for ArrayOfStructs flattening.
        self._register_container_group_members(flattening_mode=flattening_mode,
                                               container_group=dg,
                                               prefix_name='',
                                               acc_shape=None)

def _register_container_group_members(self, flattening_mode, container_group: ContainerGroup, prefix_name: str, acc_shape: tuple):
    """Recursively registers data descriptors for the leaf members of
    ``container_group``, mangling names with the group hierarchy.

    Example: a struct of CSR containers of length L1, where CSR is a struct
    of 3 arrays (L2.1, L2.2, L2.3):

    * ArrayOfStructs flattening yields a single array of the form
      ``[L1 * (L2.1 + L2.2 + L2.3)]``; with a third nesting level where
      L2.1 is itself a struct holding L3.1 and L3.2:
      ``[L1 * (L2.1 * (L3.1 + L3.2) + L2.2 + L2.3)]``.
    * StructsOfArrays flattening yields one array per leaf:
      ``[L1][L2.1], [L1][L2.2], [L1][L2.3]``; with a third level:
      ``[L1][L2.1][L3.1], [L1][L2.1][L3.2], [L1][L2.2], [L1][L2.3]``.

    :raises NotImplementedError: For ArrayOfStructs (not implemented yet).
    :raises ValueError: For an unknown flattening mode.
    """
    if flattening_mode == ContainerGroupFlatteningMode.StructsOfArrays:
        for name, member in container_group.members.items():
            dg_prefix = prefix_name + f'__ContainerGroup_{container_group.name}'
            if isinstance(member, ContainerGroup):
                # Fix: the recursive call previously omitted the required
                # flattening_mode argument, raising a TypeError.
                self._register_container_group_members(flattening_mode=flattening_mode,
                                                       container_group=member,
                                                       prefix_name=dg_prefix,
                                                       acc_shape=acc_shape)
            else:
                # Leaf member: register it under the demangled flat name.
                member_demangled_name = dg_prefix + f'__member_{name}'
                self.add_datadesc(name=member_demangled_name, datadesc=member, find_new_name=False)
    elif flattening_mode == ContainerGroupFlatteningMode.ArrayOfStructs:
        raise NotImplementedError('ArrayOfStructs flattening is not implemented yet')
    else:
        raise ValueError(f'Unsupported flattening mode: {flattening_mode}')

def get_demangled_container_group_member_name(self, name_hierarchy: List[str]) -> str:
    """Translates a hierarchy of names (root group, nested groups, leaf
    member) into the demangled flat data descriptor name.

    :param name_hierarchy: e.g. ``['A', 'B', 'v']`` for leaf ``v`` in group
        ``B`` nested in group ``A``; result is
        ``__ContainerGroup_A__ContainerGroup_B__member_v``.
    :return: The demangled name.
    :raises Exception: If the hierarchy does not resolve within the
        registered container groups.
    """
    current_dg = None
    demangled_name = ''
    for i, name in enumerate(name_hierarchy):
        if current_dg is None:
            # First element selects the root container group.
            current_dg = self.container_groups[name]
            demangled_name += f"__ContainerGroup_{current_dg.name}"
        elif name in current_dg.members:
            member = current_dg.members[name]
            if isinstance(member, ContainerGroup):
                current_dg = member
                demangled_name += f"__ContainerGroup_{current_dg.name}"
            else:
                # Fix: use the file's 'dt' alias (dace.data was not a bound
                # name here). A data member must terminate the hierarchy.
                assert isinstance(member, dt.Data)
                assert i == len(name_hierarchy) - 1
                return demangled_name + f"__member_{name}"
        else:
            raise Exception(f'Name Hierarchy {name_hierarchy} Not in ContainerGroups')
    # Hierarchy exhausted without reaching a leaf data member.
    raise Exception(f'Name Hierarchy {name_hierarchy} Not in ContainerGroups')

def generate_container_groups_from_structs(self, flattening_mode: ContainerGroupFlatteningMode):
    """Creates a ContainerGroup for every Structure descriptor registered on
    this SDFG, keyed by the descriptor's array name.

    :param flattening_mode: Flattening strategy (currently unused here;
        applied later when members are registered).
    """
    for array_name, descriptor in self._arrays.items():
        if not isinstance(descriptor, dt.Structure):
            continue
        self.container_groups[array_name] = ContainerGroup.from_struct(name=array_name, structure=descriptor)
Loading