Skip to content

Commit

Permalink
Add reading of an array's size; add the size input as a special "IN" connector
Browse files Browse the repository at this point in the history
  • Loading branch information
ThrudPrimrose committed Oct 29, 2024
1 parent 023c86c commit 4aca5ee
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 39 deletions.
4 changes: 2 additions & 2 deletions dace/codegen/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def dispatch_copy(self, src_node: nodes.Node, dst_node: nodes.Node, edge: MultiC
self._used_targets.add(target)
target.copy_memory(sdfg, cfg, dfg, state_id, src_node, dst_node, edge, function_stream, output_stream)

def dispatch_reallocate(self, node: nodes.Node, edge: MultiConnectorEdge[Memlet], sdfg: SDFG,
def dispatch_reallocate(self, src_node: nodes.Node, node: nodes.Node, edge: MultiConnectorEdge[Memlet], sdfg: SDFG,
cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, function_stream: CodeIOStream,
output_stream: CodeIOStream) -> None:
state = cfg.state(state_id)
Expand All @@ -640,7 +640,7 @@ def dispatch_reallocate(self, node: nodes.Node, edge: MultiConnectorEdge[Memlet]

# Dispatch reallocate
self._used_targets.add(target)
target.reallocate(sdfg, cfg, dfg, state_id, node, edge, function_stream, output_stream)
target.reallocate(sdfg, cfg, dfg, state_id, src_node, node, edge, function_stream, output_stream)


# Dispatches definition code for a memlet that is outgoing from a tasklet
Expand Down
62 changes: 41 additions & 21 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,10 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV

if not declared:
declaration_stream.write(f'{nodedesc.dtype.ctype} *{name};\n', cfg, state_id, node)
# Initialize size array
size_str = ",".join(["0" if cpp.sym2cpp(dim).startswith("__dace_defer") else cpp.sym2cpp(dim) for dim in nodedesc.shape])
size_nodedesc = sdfg.arrays[f"{name}_size"]
declaration_stream.write(f'{size_nodedesc.dtype.ctype} {name}_size[{size_nodedesc.shape[0]}]{{{size_str}}};\n', cfg, state_id, node)
if deferred_allocation:
allocation_stream.write(
"%s = nullptr; // Deferred Allocation" %
Expand All @@ -515,7 +519,9 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV
node
)


define_var(name, DefinedType.Pointer, ctypedef)
define_var(name + "_size", DefinedType.Pointer, size_nodedesc.dtype.ctype)

if node.setzero:
allocation_stream.write("memset(%s, 0, sizeof(%s)*%s);" %
Expand Down Expand Up @@ -671,21 +677,35 @@ def reallocate(
cfg: ControlFlowRegion,
dfg: StateSubgraphView,
state_id: int,
node: Union[nodes.Tasklet, nodes.AccessNode],
src_node: nodes.AccessNode,
dst_node: nodes.AccessNode,
edge: Tuple[nodes.Node, Optional[str], nodes.Node, Optional[str], mmlt.Memlet],
function_stream: CodeIOStream,
callsite_stream: CodeIOStream,
):
function_stream.write(
"#include <cstdlib>"
)
data_name = node.data
data_name = dst_node.data
size_array_name = f"{data_name}_size"
new_size_array_name = src_node.data

data = sdfg.arrays[data_name]
new_size_array = sdfg.arrays[new_size_array_name]
dtype = sdfg.arrays[data_name].dtype

# Only consider the offsets with __dace_defer in original dim
mask_array = [str(dim).startswith("__dace_defer") for dim in data.shape]
for i, mask in enumerate(mask_array):
if mask:
callsite_stream.write(
f"{size_array_name}[{i}] = {new_size_array_name}[{i}];"
)

# TODO(review): should realloc be deferred until no __dace_defer symbol remains in size_array?
size_str = " * ".join([f"{size_array_name}[{i}]" for i in range(len(data.shape))])
callsite_stream.write(
f"{node.data} = static_cast<{dtype} *>(std::realloc(static_cast<void *>({node.data}), {size_str} * sizeof({dtype})));"
f"{dst_node.data} = static_cast<{dtype} *>(std::realloc(static_cast<void *>({dst_node.data}), {size_str} * sizeof({dtype})));"
)

def _emit_copy(
Expand Down Expand Up @@ -1145,7 +1165,7 @@ def process_out_memlets(self,
elif isinstance(node, nodes.AccessNode):
if dst_node != node and not isinstance(dst_node, nodes.Tasklet) :
# If it is a size change, reallocate will be called
if edge.dst_conn is not None and edge.dst_conn.endswith("_size"):
if edge.dst_conn is not None and edge.dst_conn == "IN_size":
continue

dispatcher.dispatch_copy(
Expand Down Expand Up @@ -1460,7 +1480,6 @@ def _generate_Tasklet(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgra
self._dispatcher.defined_vars.add(edge.dst_conn, defined_type, f"const {ctype}")

else:
inner_stream.write("// COPY3")
self._dispatcher.dispatch_copy(
src_node,
node,
Expand Down Expand Up @@ -2205,22 +2224,23 @@ def _generate_AccessNode(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSub
if memlet.data is None:
continue # If the edge has to be skipped

if in_connector == "IN_size":
self._dispatcher.dispatch_reallocate(
node,
edge,
sdfg,
cfg,
dfg,
state_id,
function_stream,
callsite_stream,
)
else:
# Determines if this path ends here or has a definite source (array) node
memlet_path = state_dfg.memlet_path(edge)
if memlet_path[-1].dst == node:
src_node = memlet_path[0].src
# Determines if this path ends here or has a definite source (array) node
memlet_path = state_dfg.memlet_path(edge)
if memlet_path[-1].dst == node:
src_node = memlet_path[0].src
if in_connector == "IN_size":
self._dispatcher.dispatch_reallocate(
src_node,
node,
edge,
sdfg,
cfg,
dfg,
state_id,
function_stream,
callsite_stream,
)
else:
# Only generate code in case this is the innermost scope
# (copies are generated at the inner scope, where both arrays exist)
if (scope_contains_scope(sdict, src_node, node) and sdict[src_node] != sdict[node]):
Expand Down
15 changes: 8 additions & 7 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]:
"""
Determines what data is read and written in this graph.
Does not include reads to subsets of containers that have previously been written within the same state.
:return: A two-tuple of sets of things denoting ({data read}, {data written}).
"""
return set(), set()
Expand Down Expand Up @@ -421,7 +421,8 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto
# Trace through scope entry using IN_# -> OUT_#
if isinstance(curedge.dst, (nd.EntryNode, nd.ExitNode)):
if curedge.dst_conn is None:
raise ValueError("Destination connector cannot be None for {}".format(curedge.dst))
#raise ValueError("Destination connector cannot be None for {}".format(curedge.dst))
break
if not curedge.dst_conn.startswith("IN_"): # Map variable
break
next_edge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == "OUT_" + curedge.dst_conn[3:])
Expand Down Expand Up @@ -794,7 +795,7 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr,
def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]:
"""
Determines what data is read and written in this subgraph.
:return: A two-tuple of sets of things denoting
({data read}, {data written}).
"""
Expand Down Expand Up @@ -2595,7 +2596,7 @@ def inline(self) -> Tuple[bool, Any]:
for b_edge in parent.in_edges(self):
parent.add_edge(b_edge.src, self.start_block, b_edge.data)
parent.remove_edge(b_edge)

end_state = None
if len(to_connect) > 0:
end_state = parent.add_state(self.label + '_end')
Expand Down Expand Up @@ -3262,7 +3263,7 @@ def nodes(self) -> List['ControlFlowBlock']:

def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]:
return []

def _used_symbols_internal(self,
all_symbols: bool,
defined_syms: Optional[Set] = None,
Expand Down Expand Up @@ -3304,7 +3305,7 @@ def to_json(self, parent=None):
json['branches'] = [(condition.to_json() if condition is not None else None, cfg.to_json())
for condition, cfg in self._branches]
return json

@classmethod
def from_json(cls, json_obj, context=None):
context = context or {'sdfg': None, 'parent_graph': None}
Expand All @@ -3322,7 +3323,7 @@ def from_json(cls, json_obj, context=None):
else:
ret._branches.append((None, ControlFlowRegion.from_json(region, context)))
return ret

def inline(self) -> Tuple[bool, Any]:
"""
Inlines the conditional region into its parent control flow region.
Expand Down
30 changes: 21 additions & 9 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def validate_control_flow_region(sdfg: 'SDFG',

def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context: bool):
""" Verifies the correctness of an SDFG by applying multiple tests.
:param sdfg: The SDFG to verify.
:param references: An optional set keeping seen IDs for object
miscopy validation.
Expand Down Expand Up @@ -464,7 +464,7 @@ def validate_state(state: 'dace.sdfg.SDFGState',

if isinstance(node, (nd.EntryNode, nd.ExitNode)):
for iconn in node.in_connectors:
if (iconn is not None and iconn.startswith("IN_") and ("OUT_" + iconn[3:]) not in node.out_connectors):
if (iconn is not None and iconn.startswith("IN_") and not iconn.endswith("_size") and ("OUT_" + iconn[3:]) not in node.out_connectors):
raise InvalidSDFGNodeError(
"No match for input connector %s in output "
"connectors" % iconn,
Expand Down Expand Up @@ -685,14 +685,15 @@ def validate_state(state: 'dace.sdfg.SDFGState',
break

# Check if memlet data matches src or dst nodes
# If it is read from the size output connector, it needs to match the array's size descriptor
name = e.data.data
if isinstance(src_node, nd.AccessNode) and isinstance(sdfg.arrays[src_node.data], dt.Structure):
name = None
if isinstance(dst_node, nd.AccessNode) and isinstance(sdfg.arrays[dst_node.data], dt.Structure):
name = None
if (name is not None and (isinstance(src_node, nd.AccessNode) or isinstance(dst_node, nd.AccessNode))
and (not isinstance(src_node, nd.AccessNode) or (name != src_node.data and name != e.src_conn))
and (not isinstance(dst_node, nd.AccessNode) or (name != dst_node.data and name != e.dst_conn))):
and (not isinstance(src_node, nd.AccessNode) or (name != src_node.data and name != e.src_conn and name != src_node.data + "_size"))
and (not isinstance(dst_node, nd.AccessNode) or (name != dst_node.data and name != e.dst_conn and name != dst_node.data + "_size"))):
raise InvalidSDFGEdgeError(
"Memlet data does not match source or destination "
"data nodes)",
Expand All @@ -716,21 +717,31 @@ def validate_state(state: 'dace.sdfg.SDFGState',
# Check memlet subset validity with respect to source/destination nodes
if e.data.data is not None and e.data.allow_oob == False:
subset_node = (dst_node
if isinstance(dst_node, nd.AccessNode) and e.data.data == dst_node.data else src_node)
if isinstance(dst_node, nd.AccessNode) and e.data.data == dst_node.data or e.data.data == dst_node.data + "_size" else src_node)
other_subset_node = (dst_node
if isinstance(dst_node, nd.AccessNode) and e.data.data != dst_node.data else src_node)
if isinstance(dst_node, nd.AccessNode) and e.data.data != dst_node.data or e.data.data == dst_node.data + "_size" else src_node)

if isinstance(subset_node, nd.AccessNode):
arr = sdfg.arrays[subset_node.data]
size_arr = sdfg.arrays[subset_node.data + "_size"]
# Dimensionality
if e.data.subset.dims() != len(arr.shape):

if e.data.data == subset_node.data and e.data.subset.dims() != len(arr.shape):
raise InvalidSDFGEdgeError(
"Memlet subset does not match node dimension "
"(expected %d, got %d)" % (len(arr.shape), e.data.subset.dims()),
sdfg,
state_id,
eid,
)
if e.data.data == (subset_node.data + "_size") and e.data.subset.dims() != len(size_arr.shape):
raise InvalidSDFGEdgeError(
"Memlet subset does not match node size dimension "
"(expected %d, got %d)" % (len(size_arr.shape), e.data.subset.dims()),
sdfg,
state_id,
eid,
)

# Bounds
if any(((minel + off) < 0) == True for minel, off in zip(e.data.subset.min_element(), arr.offset)):
Expand All @@ -741,10 +752,11 @@ def validate_state(state: 'dace.sdfg.SDFGState',
raise InvalidSDFGEdgeError("Memlet subset negative out-of-bounds", sdfg, state_id, eid)
if any(((maxel + off) >= s) == True
for maxel, s, off in zip(e.data.subset.max_element(), arr.shape, arr.offset)):
if e.data.dynamic:
if e.data.dynamic or e.data.data.endswith("_size"):
warnings.warn(f'Potential out-of-bounds memlet subset: {e}')
else:
raise InvalidSDFGEdgeError("Memlet subset out-of-bounds", sdfg, state_id, eid)
warnings.warn(f'Memlet subset out-of-bounds {sdfg}, {state_id}, {eid}')
#raise InvalidSDFGEdgeError("Memlet subset out-of-bounds", sdfg, state_id, eid)

# Test other_subset as well
if e.data.other_subset is not None and isinstance(other_subset_node, nd.AccessNode):
Expand Down

0 comments on commit 4aca5ee

Please sign in to comment.