diff --git a/AUTHORS b/AUTHORS index 573f142cf9..48cb4c05ec 100644 --- a/AUTHORS +++ b/AUTHORS @@ -36,5 +36,6 @@ Reid Wahl Yihang Luo Alexandru Calotoiu Phillip Lane +Samuel Martin and other contributors listed in https://github.com/spcl/dace/graphs/contributors diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 59f04b7c36..2c41aa1d99 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3655,6 +3655,11 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no # If the symbol is a callback, but is not used in the nested SDFG, skip it continue + # NOTE: Is it possible that an array in the SDFG's closure is not in the SDFG? + # NOTE: Perhaps its use was simplified/optimized away? + if aname not in sdfg.arrays: + continue + # First, we do an inverse lookup on the already added closure arrays for `arr`. is_new_arr = True for k, v in self.nested_closure_arrays.items(): @@ -3831,6 +3836,12 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no for k, v in argdict.items() if self._is_outputnode(sdfg, k) } + # If an argument does not register as input nor as output, put it in the inputs. + # This may happen with input arguments that are used to set a promoted scalar. + for k, v in argdict.items(): + if k not in inputs.keys() and k not in outputs.keys(): + inputs[k] = v + # Add closure to global inputs/outputs (e.g., if processed as part of a map) for arrname in closure_arrays.keys(): if arrname not in names_to_replace: @@ -3842,13 +3853,6 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no if narrname in outputs: self.outputs[arrname] = (state, outputs[narrname], []) - # If an argument does not register as input nor as output, - # put it in the inputs. - # This may happen with input argument that are used to set - # a promoted scalar. 
- for k, v in argdict.items(): - if k not in inputs.keys() and k not in outputs.keys(): - inputs[k] = v # Unset parent inputs/read accesses that # turn out to be outputs/write accesses. for memlet in outputs.values(): diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 0796bf00d0..8059609c36 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -502,13 +502,22 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # is read is not counted in the read set for n in utils.dfs_topological_sort(sg, sources=sg.source_nodes()): if isinstance(n, nd.AccessNode): - for e in sg.in_edges(n): + in_edges = sg.in_edges(n) + out_edges = sg.out_edges(n) + # Filter out memlets which go out but the same data is written to the AccessNode by another memlet + for out_edge in list(out_edges): + for in_edge in list(in_edges): + if (in_edge.data.data == out_edge.data.data and + in_edge.data.dst_subset.covers(out_edge.data.src_subset)): + out_edges.remove(out_edge) + + for e in in_edges: # skip empty memlets if e.data.is_empty(): continue # Store all subsets that have been written ws[n.data].append(e.data.subset) - for e in sg.out_edges(n): + for e in out_edges: # skip empty memlets if e.data.is_empty(): continue diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 3bac646479..aa7674ca45 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -42,7 +42,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context """ # Avoid import loop from dace.codegen.targets import fpga - from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga + from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga, is_in_scope references = references or set() @@ -111,6 +111,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context # Check if SDFG is located within a GPU kernel context['in_gpu'] = is_devicelevel_gpu(sdfg, None, None) context['in_fpga'] = 
is_devicelevel_fpga(sdfg, None, None) + in_default_scope = None # Check every state separately start_state = sdfg.start_state @@ -171,10 +172,18 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context for memlet in ise_memlets: container = memlet.data if not _accessible(sdfg, container, context): - eid = sdfg.edge_id(edge) - raise InvalidSDFGInterstateEdgeError( - f'Trying to read an inaccessible data container "{container}" ' - f'(Storage: {sdfg.arrays[container].storage}) in host code interstate edge', sdfg, eid) + # Check context w.r.t. maps + if in_default_scope is None: # Lazy-evaluate in_default_scope + in_default_scope = False + if sdfg.parent_nsdfg_node is not None: + if is_in_scope(sdfg.parent_sdfg, sdfg.parent, sdfg.parent_nsdfg_node, + [dtypes.ScheduleType.Default]): + in_default_scope = True + if in_default_scope is False: + eid = sdfg.edge_id(edge) + raise InvalidSDFGInterstateEdgeError( + f'Trying to read an inaccessible data container "{container}" ' + f'(Storage: {sdfg.arrays[container].storage}) in host code interstate edge', sdfg, eid) # Add edge symbols into defined symbols symbols.update(issyms) @@ -219,9 +228,17 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context for memlet in ise_memlets: container = memlet.data if not _accessible(sdfg, container, context): - raise InvalidSDFGInterstateEdgeError( - f'Trying to read an inaccessible data container "{container}" ' - f'(Storage: {sdfg.arrays[container].storage}) in host code interstate edge', sdfg, eid) + # Check context w.r.t. 
maps + if in_default_scope is None: # Lazy-evaluate in_default_scope + in_default_scope = False + if sdfg.parent_nsdfg_node is not None: + if is_in_scope(sdfg.parent_sdfg, sdfg.parent, sdfg.parent_nsdfg_node, + [dtypes.ScheduleType.Default]): + in_default_scope = True + if in_default_scope is False: + raise InvalidSDFGInterstateEdgeError( + f'Trying to read an inaccessible data container "{container}" ' + f'(Storage: {sdfg.arrays[container].storage}) in host code interstate edge', sdfg, eid) except InvalidSDFGError as ex: # If the SDFG is invalid, save it diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index 1b9324546a..71d9e22aca 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -925,7 +925,12 @@ def _candidates( continue # For now we only detect one element + read_set, write_set = nstate.read_and_write_sets() for e in nstate.in_edges(dnode): + if e.data.data not in write_set: + # Skip data which is not in the read and write set of the state -> there also won't be a + # connector + continue # If more than one unique element detected, remove from # candidates if e.data.data in out_candidates: @@ -941,6 +946,10 @@ def _candidates( continue out_candidates[e.data.data] = (e.data, nstate, set(range(len(e.data.subset)))) for e in nstate.out_edges(dnode): + if e.data.data not in read_set: + # Skip data which is not in the read and write set of the state -> there also won't be a + # connector + continue # If more than one unique element detected, remove from # candidates if e.data.data in in_candidates: diff --git a/dace/transformation/passes/array_elimination.py b/dace/transformation/passes/array_elimination.py index e313f7bf66..d1b80c2327 100644 --- a/dace/transformation/passes/array_elimination.py +++ b/dace/transformation/passes/array_elimination.py @@ -170,6 +170,9 @@ def remove_redundant_copies(self, sdfg: SDFG, state: SDFGState, 
removable_data: for anode in access_nodes[aname]: if anode in removed_nodes: continue + if anode not in state.nodes(): + removed_nodes.add(anode) + continue if state.out_degree(anode) == 1: succ = state.successors(anode)[0] diff --git a/tests/sdfg/disallowed_access_test.py b/tests/sdfg/disallowed_access_test.py index 8700e34db5..520481ea46 100644 --- a/tests/sdfg/disallowed_access_test.py +++ b/tests/sdfg/disallowed_access_test.py @@ -40,6 +40,7 @@ def test_gpu_access_on_host_interstate_invalid(): @pytest.mark.gpu def test_gpu_access_on_host_tasklet(): + @dace.program def tester(a: dace.float64[20] @ dace.StorageType.GPU_Global): for i in dace.map[0:20] @ dace.ScheduleType.CPU_Multicore: @@ -49,7 +50,29 @@ def tester(a: dace.float64[20] @ dace.StorageType.GPU_Global): tester.to_sdfg(validate=True) +@pytest.mark.gpu +def test_gpu_access_on_device_interstate_edge_default(): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [20], dace.float64, storage=dace.StorageType.GPU_Global) + state = sdfg.add_state() + + me, mx = state.add_map('test', dict(i='0:20')) + + nsdfg = dace.SDFG('nester') + nsdfg.add_array('A', [20], dace.float64, storage=dace.StorageType.GPU_Global) + state1 = nsdfg.add_state() + state2 = nsdfg.add_state() + nsdfg.add_edge(state1, state2, dace.InterstateEdge(assignments=dict(s='A[4]'))) + + nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'A'}, {}) + state.add_memlet_path(state.add_read('A'), me, nsdfg_node, dst_conn='A', memlet=dace.Memlet('A[0:20]')) + state.add_nedge(nsdfg_node, mx, dace.Memlet()) + + sdfg.validate() + + if __name__ == '__main__': test_gpu_access_on_host_interstate_ok() test_gpu_access_on_host_interstate_invalid() test_gpu_access_on_host_tasklet() + test_gpu_access_on_device_interstate_edge_default() diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py new file mode 100644 index 0000000000..c5cb953c4d --- /dev/null +++ b/tests/sdfg/state_test.py @@ -0,0 +1,24 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. 
All rights reserved.
+import dace
+
+
+def test_read_write_set():
+    """B is written by work1 and then read by work2 via one access node, so it must not be in the read set."""
+    sdfg = dace.SDFG('graph')
+    for name in ('A', 'B', 'C'):
+        sdfg.add_array(name, [10], dace.float64)
+    state = sdfg.add_state('state')
+    task1 = state.add_tasklet('work1', {'A'}, {'B'}, 'B = A + 1')
+    task2 = state.add_tasklet('work2', {'B'}, {'C'}, 'C = B + 1')
+    read_a = state.add_access('A')
+    rw_b = state.add_access('B')
+    write_c = state.add_access('C')
+    state.add_memlet_path(read_a, task1, dst_conn='A', memlet=dace.Memlet('A[2]'))
+    state.add_memlet_path(task1, rw_b, src_conn='B', memlet=dace.Memlet('B[2]'))
+    state.add_memlet_path(rw_b, task2, dst_conn='B', memlet=dace.Memlet('B[2]'))
+    state.add_memlet_path(task2, write_c, src_conn='C', memlet=dace.Memlet('C[2]'))
+    assert 'B' not in state.read_and_write_sets()[0]  # element 0 is the read set
+
+
+if __name__ == '__main__':
+    test_read_write_set()