diff --git a/dace/codegen/dispatcher.py b/dace/codegen/dispatcher.py index 59f472d57f..2defa04680 100644 --- a/dace/codegen/dispatcher.py +++ b/dace/codegen/dispatcher.py @@ -629,7 +629,7 @@ def dispatch_copy(self, src_node: nodes.Node, dst_node: nodes.Node, edge: MultiC self._used_targets.add(target) target.copy_memory(sdfg, cfg, dfg, state_id, src_node, dst_node, edge, function_stream, output_stream) - def dispatch_reallocate(self, node: nodes.Node, edge: MultiConnectorEdge[Memlet], sdfg: SDFG, + def dispatch_reallocate(self, src_node: nodes.Node, node: nodes.Node, edge: MultiConnectorEdge[Memlet], sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, function_stream: CodeIOStream, output_stream: CodeIOStream) -> None: state = cfg.state(state_id) @@ -640,7 +640,7 @@ def dispatch_reallocate(self, node: nodes.Node, edge: MultiConnectorEdge[Memlet] # Dispatch reallocate self._used_targets.add(target) - target.reallocate(sdfg, cfg, dfg, state_id, node, edge, function_stream, output_stream) + target.reallocate(sdfg, cfg, dfg, state_id, src_node, node, edge, function_stream, output_stream) # Dispatches definition code for a memlet that is outgoing from a tasklet diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 19884fb88d..c11b94d9cd 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -496,6 +496,10 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV if not declared: declaration_stream.write(f'{nodedesc.dtype.ctype} *{name};\n', cfg, state_id, node) + # Initialize size array + size_str = ",".join(["0" if cpp.sym2cpp(dim).startswith("__dace_defer") else cpp.sym2cpp(dim) for dim in nodedesc.shape]) + size_nodedesc = sdfg.arrays[f"{name}_size"] + declaration_stream.write(f'{size_nodedesc.dtype.ctype} {name}_size[{size_nodedesc.shape[0]}]{{{size_str}}};\n', cfg, state_id, node) if deferred_allocation: allocation_stream.write( "%s = nullptr; // Deferred Allocation" % @@ -515,7 +519,9 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV node ) + define_var(name, DefinedType.Pointer, ctypedef) + define_var(name + "_size", DefinedType.Pointer, size_nodedesc.dtype.ctype) if node.setzero: allocation_stream.write("memset(%s, 0, sizeof(%s)*%s);" % @@ -671,7 +677,8 @@ def reallocate( cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - node: Union[nodes.Tasklet, nodes.AccessNode], + src_node: nodes.AccessNode, + dst_node: nodes.AccessNode, edge: Tuple[nodes.Node, Optional[str], nodes.Node, Optional[str], mmlt.Memlet], function_stream: CodeIOStream, callsite_stream: CodeIOStream, @@ -679,13 +686,26 @@ def reallocate( function_stream.write( "#include " ) - data_name = node.data + data_name = dst_node.data size_array_name = f"{data_name}_size" + new_size_array_name = src_node.data + data = sdfg.arrays[data_name] + new_size_array = sdfg.arrays[new_size_array_name] dtype = sdfg.arrays[data_name].dtype + + # Only consider the offsets with __dace_defer in original dim + mask_array = [str(dim).startswith("__dace_defer") for dim in data.shape] + for i, mask in enumerate(mask_array): + if mask: + callsite_stream.write( + f"{size_array_name}[{i}] = {new_size_array_name}[{i}];" + ) + + # Call realloc only after no __dace_defer is left in size_array ? size_str = " * ".join([f"{size_array_name}[{i}]" for i in range(len(data.shape))]) callsite_stream.write( - f"{node.data} = static_cast<{dtype} *>(std::realloc(static_cast({node.data}), {size_str} * sizeof({dtype})));" + f"{dst_node.data} = static_cast<{dtype} *>(std::realloc(static_cast({dst_node.data}), {size_str} * sizeof({dtype})));" ) def _emit_copy( @@ -1145,7 +1165,7 @@ def process_out_memlets(self, elif isinstance(node, nodes.AccessNode): if dst_node != node and not isinstance(dst_node, nodes.Tasklet) : # If it is a size change, reallocate will be called - if edge.dst_conn is not None and edge.dst_conn.endswith("_size"): + if edge.dst_conn is not None and edge.dst_conn == "IN_size": continue dispatcher.dispatch_copy( @@ -1460,7 +1480,6 @@ def _generate_Tasklet(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgra self._dispatcher.defined_vars.add(edge.dst_conn, defined_type, f"const {ctype}") else: - inner_stream.write("// COPY3") self._dispatcher.dispatch_copy( src_node, node, @@ -2205,22 +2224,23 @@ def _generate_AccessNode(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSub if memlet.data is None: continue # If the edge has to be skipped - if in_connector == "IN_size": - self._dispatcher.dispatch_reallocate( - node, - edge, - sdfg, - cfg, - dfg, - state_id, - function_stream, - callsite_stream, - ) - else: - # Determines if this path ends here or has a definite source (array) node - memlet_path = state_dfg.memlet_path(edge) - if memlet_path[-1].dst == node: - src_node = memlet_path[0].src + # Determines if this path ends here or has a definite source (array) node + memlet_path = state_dfg.memlet_path(edge) + if memlet_path[-1].dst == node: + src_node = memlet_path[0].src + if in_connector == "IN_size": + self._dispatcher.dispatch_reallocate( + src_node, + node, + edge, + sdfg, + cfg, + dfg, + state_id, + function_stream, + callsite_stream, + ) + else: # Only generate code in case this is the innermost scope # (copies are generated at the inner scope, where both arrays exist) if (scope_contains_scope(sdict, src_node, node) and sdict[src_node] != sdict[node]): diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 2ae6109b31..ebb96d9735 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -243,7 +243,7 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: """ Determines what data is read and written in this graph. Does not include reads to subsets of containers that have previously been written within the same state. - + :return: A two-tuple of sets of things denoting ({data read}, {data written}). """ return set(), set() @@ -421,7 +421,8 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto # Trace through scope entry using IN_# -> OUT_# if isinstance(curedge.dst, (nd.EntryNode, nd.ExitNode)): if curedge.dst_conn is None: - raise ValueError("Destination connector cannot be None for {}".format(curedge.dst)) + #raise ValueError("Destination connector cannot be None for {}".format(curedge.dst)) + break if not curedge.dst_conn.startswith("IN_"): # Map variable break next_edge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == "OUT_" + curedge.dst_conn[3:]) @@ -794,7 +795,7 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: """ Determines what data is read and written in this subgraph. - + :return: A two-tuple of sets of things denoting ({data read}, {data written}). """ @@ -2595,7 +2596,7 @@ def inline(self) -> Tuple[bool, Any]: for b_edge in parent.in_edges(self): parent.add_edge(b_edge.src, self.start_block, b_edge.data) parent.remove_edge(b_edge) - + end_state = None if len(to_connect) > 0: end_state = parent.add_state(self.label + '_end') @@ -3262,7 +3263,7 @@ def nodes(self) -> List['ControlFlowBlock']: def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: return [] - + def _used_symbols_internal(self, all_symbols: bool, defined_syms: Optional[Set] = None, @@ -3304,7 +3305,7 @@ def to_json(self, parent=None): json['branches'] = [(condition.to_json() if condition is not None else None, cfg.to_json()) for condition, cfg in self._branches] return json - + @classmethod def from_json(cls, json_obj, context=None): context = context or {'sdfg': None, 'parent_graph': None} @@ -3322,7 +3323,7 @@ def from_json(cls, json_obj, context=None): else: ret._branches.append((None, ControlFlowRegion.from_json(region, context))) return ret - + def inline(self) -> Tuple[bool, Any]: """ Inlines the conditional region into its parent control flow region. diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index e75099276f..e0528e6584 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -186,7 +186,7 @@ def validate_control_flow_region(sdfg: 'SDFG', def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context: bool): """ Verifies the correctness of an SDFG by applying multiple tests. - + :param sdfg: The SDFG to verify. :param references: An optional set keeping seen IDs for object miscopy validation. @@ -464,7 +464,7 @@ def validate_state(state: 'dace.sdfg.SDFGState', if isinstance(node, (nd.EntryNode, nd.ExitNode)): for iconn in node.in_connectors: - if (iconn is not None and iconn.startswith("IN_") and ("OUT_" + iconn[3:]) not in node.out_connectors): + if (iconn is not None and iconn.startswith("IN_") and not iconn.endswith("_size") and ("OUT_" + iconn[3:]) not in node.out_connectors): raise InvalidSDFGNodeError( "No match for input connector %s in output " "connectors" % iconn, @@ -685,14 +685,15 @@ def validate_state(state: 'dace.sdfg.SDFGState', break # Check if memlet data matches src or dst nodes + # If is read from the size output connector it needs to match the array's size descriptor name = e.data.data if isinstance(src_node, nd.AccessNode) and isinstance(sdfg.arrays[src_node.data], dt.Structure): name = None if isinstance(dst_node, nd.AccessNode) and isinstance(sdfg.arrays[dst_node.data], dt.Structure): name = None if (name is not None and (isinstance(src_node, nd.AccessNode) or isinstance(dst_node, nd.AccessNode)) - and (not isinstance(src_node, nd.AccessNode) or (name != src_node.data and name != e.src_conn)) - and (not isinstance(dst_node, nd.AccessNode) or (name != dst_node.data and name != e.dst_conn))): + and (not isinstance(src_node, nd.AccessNode) or (name != src_node.data and name != e.src_conn and name != src_node.data + "_size")) + and (not isinstance(dst_node, nd.AccessNode) or (name != dst_node.data and name != e.dst_conn and name != dst_node.data + "_size"))): raise InvalidSDFGEdgeError( "Memlet data does not match source or destination " "data nodes)", @@ -716,14 +717,16 @@ def validate_state(state: 'dace.sdfg.SDFGState', # Check memlet subset validity with respect to source/destination nodes if e.data.data is not None and e.data.allow_oob == False: subset_node = (dst_node - if isinstance(dst_node, nd.AccessNode) and e.data.data == dst_node.data else src_node) + if isinstance(dst_node, nd.AccessNode) and e.data.data == dst_node.data or e.data.data == dst_node.data + "_size" else src_node) other_subset_node = (dst_node - if isinstance(dst_node, nd.AccessNode) and e.data.data != dst_node.data else src_node) + if isinstance(dst_node, nd.AccessNode) and e.data.data != dst_node.data or e.data.data == dst_node.data + "_size" else src_node) if isinstance(subset_node, nd.AccessNode): arr = sdfg.arrays[subset_node.data] + size_arr = sdfg.arrays[subset_node.data + "_size"] # Dimensionality - if e.data.subset.dims() != len(arr.shape): + + if e.data.data == subset_node.data and e.data.subset.dims() != len(arr.shape): raise InvalidSDFGEdgeError( "Memlet subset does not match node dimension " "(expected %d, got %d)" % (len(arr.shape), e.data.subset.dims()), @@ -731,6 +734,14 @@ def validate_state(state: 'dace.sdfg.SDFGState', state_id, eid, ) + if e.data.data == (subset_node.data + "_size") and e.data.subset.dims() != len(size_arr.shape): + raise InvalidSDFGEdgeError( + "Memlet subset does not match node size dimension " + "(expected %d, got %d)" % (len(size_arr.shape), e.data.subset.dims()), + sdfg, + state_id, + eid, + ) # Bounds if any(((minel + off) < 0) == True for minel, off in zip(e.data.subset.min_element(), arr.offset)): @@ -741,10 +752,11 @@ def validate_state(state: 'dace.sdfg.SDFGState', raise InvalidSDFGEdgeError("Memlet subset negative out-of-bounds", sdfg, state_id, eid) if any(((maxel + off) >= s) == True for maxel, s, off in zip(e.data.subset.max_element(), arr.shape, arr.offset)): - if e.data.dynamic: + if e.data.dynamic or e.data.data.endswith("_size"): warnings.warn(f'Potential out-of-bounds memlet subset: {e}') else: - raise InvalidSDFGEdgeError("Memlet subset out-of-bounds", sdfg, state_id, eid) + warnings.warn(f'Memlet subset out-of-bounds {sdfg}, {state_id}, {eid}') + #raise InvalidSDFGEdgeError("Memlet subset out-of-bounds", sdfg, state_id, eid) # Test other_subset as well if e.data.other_subset is not None and isinstance(other_subset_node, nd.AccessNode):