From d0a25f2122ab568cfe2bfd1fc0f0d0380844ae3d Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Tue, 8 Oct 2024 22:42:07 +0200 Subject: [PATCH 01/29] First draft: only within-state fusion, and only if the ranges exactly match. --- .../dataflow/const_assignment_fusion.py | 149 ++++++++++++++++++ .../const_assignment_fusion_test.py | 110 +++++++++++++ 2 files changed, 259 insertions(+) create mode 100644 dace/transformation/dataflow/const_assignment_fusion.py create mode 100644 tests/transformations/const_assignment_fusion_test.py diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py new file mode 100644 index 0000000000..dac72329f9 --- /dev/null +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -0,0 +1,149 @@ +import ast +from itertools import chain + +import dace.subsets +from dace import transformation, SDFGState, SDFG, Memlet +from dace.sdfg import nodes +from dace.sdfg.nodes import Tasklet, ExitNode +from dace.transformation.dataflow import MapFusion + + +class ConstAssignmentMapFusion(MapFusion): + first_map_exit = transformation.PatternNode(nodes.ExitNode) + array = transformation.PatternNode(nodes.AccessNode) + second_map_entry = transformation.PatternNode(nodes.EntryNode) + + # NOTE: `expression()` is inherited. + + @staticmethod + def consistent_const_assignment_table(graph, en, ex) -> tuple[bool, dict]: + table = {} + for n in graph.all_nodes_between(en, ex): + # Each of the nodes in this map must be... + if not isinstance(n, Tasklet): + # ...a tasklet... + return False, table + if len(n.code.code) != 1 or not isinstance(n.code.code[0], ast.Assign): + # ...that assigns... + return False, table + op = n.code.code[0] + if not isinstance(op.value, ast.Constant) or len(op.targets) != 1: + # ...a constant to a single target. + return False, table + const = op.value.value + for oe in graph.out_edges(n): + dst = oe.data + dst_arr = oe.data.data + if dst_arr in table and table[dst_arr] != const: + # A target array can appear multiple times, but it must always be consistently assigned. + return False, table + table[dst] = const + table[dst_arr] = const + return True, table + + def map_nodes(self, graph: SDFGState): + return (graph.entry_node(self.first_map_exit), self.first_map_exit, + self.second_map_entry, graph.exit_node(self.second_map_entry)) + + def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) + # TODO(pratyai): Make a better check for map compatibility. + if first_entry.map.range != second_entry.map.range or first_entry.map.schedule != second_entry.map.schedule: + # TODO(pratyai): Make it so that a permutation of the ranges, or even an union of the ranges will work. + return False + + # Both maps must have consistent constant assignment for the target arrays. + is_const_assignment, assignments = self.consistent_const_assignment_table(graph, first_entry, first_exit) + if not is_const_assignment: + return False + is_const_assignment, further_assignments = self.consistent_const_assignment_table(graph, second_entry, + second_exit) + if not is_const_assignment: + return False + for k, v in further_assignments.items(): + if k in assignments and v != assignments[k]: + return False + assignments[k] = v + return True + + @staticmethod + def track_access_nodes(graph: SDFGState, first_exit: ExitNode, second_exit: ExitNode): + # Track all the access nodes that will survive the purge. + access_nodes, remove_nodes = {}, set() + dst_nodes = set(e.dst for e in chain(graph.out_edges(first_exit), graph.out_edges(second_exit))) + for n in dst_nodes: + if n.data in access_nodes: + remove_nodes.add(n) + else: + access_nodes[n.data] = n + for n in remove_nodes: + assert n.data in access_nodes + assert access_nodes[n.data] != n + return access_nodes, remove_nodes + + @staticmethod + def make_equivalent_connections(first_exit: ExitNode, second_exit: ExitNode): + # Set up the extra connections on the first node. + conn_map = {} + for c, v in second_exit.in_connectors.items(): + assert c.startswith('IN_') + cbase = c.removeprefix('IN_') + sc = first_exit.next_connector(cbase) + conn_map[f"IN_{cbase}"] = f"IN_{sc}" + conn_map[f"OUT_{cbase}"] = f"OUT_{sc}" + first_exit.add_in_connector(f"IN_{sc}", dtype=v) + first_exit.add_out_connector(f"OUT_{sc}", dtype=v) + for c, v in second_exit.out_connectors.items(): + assert c in conn_map + return conn_map + + def apply(self, graph: SDFGState, sdfg: SDFG): + first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) + + # By now, we know that the two maps are compatible, not reading anything, and just blindly writing constants + # _consistently_. + is_const_assignment, assignments = self.consistent_const_assignment_table(graph, first_entry, first_exit) + assert is_const_assignment + + # Track all the access nodes that will survive the purge. + access_nodes, remove_nodes = self.track_access_nodes(graph, first_exit, second_exit) + + # Set up the extra connections on the first node. + conn_map = self.make_equivalent_connections(first_exit, second_exit) + + # Redirect outgoing edges from exit nodes that are going to be invalidated. + for e in graph.out_edges(first_exit): + array_name = e.dst.data + assert array_name in access_nodes + if access_nodes[array_name] != e.dst: + graph.add_memlet_path(first_exit, access_nodes[array_name], src_conn=e.src_conn, dst_conn=e.dst_conn, + memlet=Memlet(str(e.data))) + graph.remove_edge(e) + for e in graph.out_edges(second_exit): + array_name = e.dst.data + assert array_name in access_nodes + graph.add_memlet_path(first_exit, access_nodes[array_name], src_conn=conn_map[e.src_conn], + dst_conn=e.dst_conn, memlet=Memlet(str(e.data))) + graph.remove_edge(e) + + # Move the tasklets from the second map into the first map. + second_tasklets = graph.all_nodes_between(second_entry, second_exit) + for t in second_tasklets: + for e in graph.in_edges(t): + graph.add_memlet_path(first_entry, t, memlet=Memlet()) + graph.remove_edge(e) + for e in graph.out_edges(t): + graph.add_memlet_path(e.src, first_exit, src_conn=e.src_conn, dst_conn=conn_map[e.dst_conn], + memlet=Memlet(str(e.data))) + graph.remove_edge(e) + + # Redirect any outgoing edges from the nodes to be removed through their surviving counterparts. + for n in remove_nodes: + for e in graph.out_edges(n): + if e.dst != second_entry: + alt_n = access_nodes[n.data] + memlet = Memlet(str(e.data)) if not e.data.is_empty() else Memlet() + graph.add_memlet_path(alt_n, e.dst, src_conn=e.src_conn, dst_conn=e.dst_conn, memlet=memlet) + graph.remove_node(n) + graph.remove_node(second_entry) + graph.remove_node(second_exit) diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py new file mode 100644 index 0000000000..844be07a4b --- /dev/null +++ b/tests/transformations/const_assignment_fusion_test.py @@ -0,0 +1,110 @@ +import os +from copy import deepcopy + +import numpy as np + +import dace +from dace.sdfg import nodes +from dace.transformation.dataflow.const_assignment_fusion import ConstAssignmentMapFusion +from dace.transformation.interstate import StateFusionExtended + +M = dace.symbol('M') +N = dace.symbol('N') + + +@dace.program +def assign_top_row(A: dace.float32[M, N]): + for i in dace.map[0:N]: + A[0, i] = 1 + + +@dace.program +def assign_bottom_row(A: dace.float32[M, N]): + for i in dace.map[0:N]: + A[M - 1, i] = 1 + + +@dace.program +def assign_left_col(A: dace.float32[M, N]): + for i in dace.map[0:M]: + A[i, 0] = 1 + + +@dace.program +def assign_right_col(A: dace.float32[M, N]): + for i in dace.map[0:M]: + A[i, N - 1] = 1 + + +def assign_bounary_sdfg(): + st0 = assign_top_row.to_sdfg(simplify=True, validate=True) + st0.start_block.label = 'st0' + + st1 = assign_bottom_row.to_sdfg(simplify=True, validate=True) + st1.start_block.label = 'st1' + st0.add_edge(st0.start_state, st1.start_state, dace.InterstateEdge()) + + st2 = assign_left_col.to_sdfg(simplify=True, validate=True) + st2.start_block.label = 'st2' + st0.add_edge(st1.start_state, st2.start_state, dace.InterstateEdge()) + + st3 = assign_right_col.to_sdfg(simplify=True, validate=True) + st3.start_block.label = 'st3' + st0.add_edge(st2.start_state, st3.start_state, dace.InterstateEdge()) + + return st0 + + +def find_access_node_by_name(g, name): + """ Finds the first data node by the given name""" + return next((n, s) for n, s in g.all_nodes_recursive() + if isinstance(n, nodes.AccessNode) and name == n.data) + + +def find_map_entry_by_name(g, name): + """ Finds the first map entry node by the given name """ + return next((n, s) for n, s in g.all_nodes_recursive() + if isinstance(n, nodes.MapEntry) and n.label.startswith(name)) + + +def find_map_exit_by_name(g, name): + """ Finds the first map entry node by the given name """ + return next((n, s) for n, s in g.all_nodes_recursive() + if isinstance(n, nodes.MapExit) and n.label.startswith(name)) + + +def test_within_state_fusion(): + A = np.random.uniform(size=(4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = assign_bounary_sdfg() + g.save(os.path.join('_dacegraphs', 'simple-0.sdfg')) + g.validate() + actual_A = deepcopy(A) + g(A=actual_A, M=4, N=5) + + # Fuse the two states so that the const-assignment-fusion is applicable. + g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g.save(os.path.join('_dacegraphs', 'simple-1.sdfg')) + g.validate() + + g.apply_transformations(ConstAssignmentMapFusion) + g.save(os.path.join('_dacegraphs', 'simple-2.sdfg')) + g.validate() + + g.apply_transformations(ConstAssignmentMapFusion) + g.save(os.path.join('_dacegraphs', 'simple-3.sdfg')) + g.validate() + + g.apply_transformations(ConstAssignmentMapFusion) + g.save(os.path.join('_dacegraphs', 'simple-4.sdfg')) + g.validate() + our_A = deepcopy(A) + g(A=our_A, M=4, N=5) + + print(our_A) + assert np.allclose(our_A, actual_A) + + +if __name__ == '__main__': + test_within_state_fusion() From 0c76246ec82476bcc1cfce9cbfd4ceddb9406668 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Tue, 8 Oct 2024 23:39:19 +0200 Subject: [PATCH 02/29] add interstate fusion in there --- .../dataflow/const_assignment_fusion.py | 64 +++++++++++++++++-- .../const_assignment_fusion_test.py | 52 +++++++++------ 2 files changed, 90 insertions(+), 26 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index dac72329f9..e88184a892 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -1,11 +1,21 @@ import ast from itertools import chain +from typing import Optional -import dace.subsets from dace import transformation, SDFGState, SDFG, Memlet from dace.sdfg import nodes -from dace.sdfg.nodes import Tasklet, ExitNode +from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit from dace.transformation.dataflow import MapFusion +from dace.transformation.interstate import StateFusionExtended + + +def unique_map_node(graph: SDFGState) -> Optional[tuple[MapEntry, MapExit]]: + all_nodes = list(graph.all_nodes_recursive()) + en: list[MapEntry] = [n for n, _ in all_nodes if isinstance(n, MapEntry)] + ex = [n for n, _ in all_nodes if isinstance(n, MapExit)] + if len(en) != 1 or len(ex) != 1: + return None + return en[0], ex[0] class ConstAssignmentMapFusion(MapFusion): @@ -16,7 +26,7 @@ class ConstAssignmentMapFusion(MapFusion): # NOTE: `expression()` is inherited. @staticmethod - def consistent_const_assignment_table(graph, en, ex) -> tuple[bool, dict]: + def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> tuple[bool, dict]: table = {} for n in graph.all_nodes_between(en, ex): # Each of the nodes in this map must be... @@ -45,12 +55,18 @@ def map_nodes(self, graph: SDFGState): return (graph.entry_node(self.first_map_exit), self.first_map_exit, self.second_map_entry, graph.exit_node(self.second_map_entry)) - def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: - first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) + @staticmethod + def compatible_range(first_entry: MapEntry, second_entry: MapEntry): # TODO(pratyai): Make a better check for map compatibility. if first_entry.map.range != second_entry.map.range or first_entry.map.schedule != second_entry.map.schedule: # TODO(pratyai): Make it so that a permutation of the ranges, or even an union of the ranges will work. return False + return True + + def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) + if not self.compatible_range(first_entry, second_entry): + return False # Both maps must have consistent constant assignment for the target arrays. is_const_assignment, assignments = self.consistent_const_assignment_table(graph, first_entry, first_exit) @@ -147,3 +163,41 @@ def apply(self, graph: SDFGState, sdfg: SDFG): graph.remove_node(n) graph.remove_node(second_entry) graph.remove_node(second_exit) + + +class ConstAssignmentStateFusion(StateFusionExtended): + first_state = transformation.PatternNode(SDFGState) + second_state = transformation.PatternNode(SDFGState) + + # NOTE: `expression()` is inherited. + + def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + # All the basic rules apply. + if not super().can_be_applied(graph, expr_index, sdfg, permissive): + return False + st0, st1 = self.first_state, self.second_state + + # Moreover, each state must contain just one constant assignment map. + for st in [st0, st1]: + en_ex = unique_map_node(st) + if not en_ex: + return False + en, ex = en_ex + if len(st.in_edges(en)) != 0: + return False + is_const_assignment, assignments = ConstAssignmentMapFusion.consistent_const_assignment_table(st, en, ex) + if not is_const_assignment: + return False + + # Moreover, both states' ranges must be compatible. + if not ConstAssignmentMapFusion.compatible_range(unique_map_node(st0)[0], unique_map_node(st1)[0]): + return False + + return True + + def apply(self, graph: SDFGState, sdfg: SDFG): + # First, fuse the two states. + super().apply(graph, sdfg) + sdfg.validate() + sdfg.apply_transformations_repeated(ConstAssignmentMapFusion) + sdfg.validate() diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 844be07a4b..5dd60c8a6d 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -4,8 +4,7 @@ import numpy as np import dace -from dace.sdfg import nodes -from dace.transformation.dataflow.const_assignment_fusion import ConstAssignmentMapFusion +from dace.transformation.dataflow.const_assignment_fusion import ConstAssignmentMapFusion, ConstAssignmentStateFusion from dace.transformation.interstate import StateFusionExtended M = dace.symbol('M') @@ -55,24 +54,6 @@ def assign_bounary_sdfg(): return st0 -def find_access_node_by_name(g, name): - """ Finds the first data node by the given name""" - return next((n, s) for n, s in g.all_nodes_recursive() - if isinstance(n, nodes.AccessNode) and name == n.data) - - -def find_map_entry_by_name(g, name): - """ Finds the first map entry node by the given name """ - return next((n, s) for n, s in g.all_nodes_recursive() - if isinstance(n, nodes.MapEntry) and n.label.startswith(name)) - - -def find_map_exit_by_name(g, name): - """ Finds the first map entry node by the given name """ - return next((n, s) for n, s in g.all_nodes_recursive() - if isinstance(n, nodes.MapExit) and n.label.startswith(name)) - - def test_within_state_fusion(): A = np.random.uniform(size=(4, 5)).astype(np.float32) @@ -102,9 +83,38 @@ def test_within_state_fusion(): our_A = deepcopy(A) g(A=our_A, M=4, N=5) - print(our_A) + # print(our_A) + assert np.allclose(our_A, actual_A) + + +def test_interstate_fusion(): + A = np.random.uniform(size=(4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = assign_bounary_sdfg() + g.save(os.path.join('_dacegraphs', 'interstate-0.sdfg')) + g.validate() + actual_A = deepcopy(A) + g(A=actual_A, M=4, N=5) + + g.apply_transformations(ConstAssignmentStateFusion) + g.save(os.path.join('_dacegraphs', 'interstate-1.sdfg')) + g.validate() + + g.apply_transformations(ConstAssignmentStateFusion) + g.save(os.path.join('_dacegraphs', 'interstate-2.sdfg')) + g.validate() + + g.apply_transformations(ConstAssignmentStateFusion) + g.save(os.path.join('_dacegraphs', 'interstate-3.sdfg')) + g.validate() + our_A = deepcopy(A) + g(A=our_A, M=4, N=5) + + # print(our_A) assert np.allclose(our_A, actual_A) if __name__ == '__main__': test_within_state_fusion() + test_interstate_fusion() From 1ef574b06346dcd7377324201386bed8d5ca4893 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 9 Oct 2024 00:36:41 +0200 Subject: [PATCH 03/29] Handle the free floating maps too, whenever possible. --- .../dataflow/const_assignment_fusion.py | 30 ++++++++++++++++++- .../const_assignment_fusion_test.py | 29 ++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index e88184a892..24b26dffa5 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -4,7 +4,9 @@ from dace import transformation, SDFGState, SDFG, Memlet from dace.sdfg import nodes +from dace.sdfg.graph import OrderedDiGraph from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit +from dace.sdfg.utils import node_path_graph from dace.transformation.dataflow import MapFusion from dace.transformation.interstate import StateFusionExtended @@ -18,12 +20,24 @@ def unique_map_node(graph: SDFGState) -> Optional[tuple[MapEntry, MapExit]]: return en[0], ex[0] +def free_floating_maps(*args): + g = OrderedDiGraph() + for n in args: + g.add_node(n) + return g + + class ConstAssignmentMapFusion(MapFusion): + first_map_entry = transformation.PatternNode(nodes.EntryNode) first_map_exit = transformation.PatternNode(nodes.ExitNode) array = transformation.PatternNode(nodes.AccessNode) second_map_entry = transformation.PatternNode(nodes.EntryNode) + second_map_exit = transformation.PatternNode(nodes.ExitNode) - # NOTE: `expression()` is inherited. + @classmethod + def expressions(cls): + return [node_path_graph(cls.first_map_exit, cls.array, cls.second_map_entry), + free_floating_maps(cls.first_map_entry, cls.first_map_exit, cls.second_map_entry, cls.second_map_exit)] @staticmethod def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> tuple[bool, dict]: @@ -63,7 +77,21 @@ def compatible_range(first_entry: MapEntry, second_entry: MapEntry): return False return True + def can_be_applied_free_floating(self, graph: SDFGState, sdfg: SDFG, permissive: bool = False) -> bool: + first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) + if first_entry != self.first_map_entry or second_exit != self.second_map_exit: + return False + if not self.compatible_range(first_entry, second_entry): + return False + if graph.all_nodes_between(first_exit, second_entry) or graph.all_nodes_between(second_exit, first_entry): + return False + return True + def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + assert expr_index in (0, 1) + if expr_index == 1: + if not self.can_be_applied_free_floating(graph, sdfg, permissive): + return False first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) if not self.compatible_range(first_entry, second_entry): return False diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 5dd60c8a6d..a2c7ecc23d 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -115,6 +115,35 @@ def test_interstate_fusion(): assert np.allclose(our_A, actual_A) +@dace.program +def assign_bounary_free_floating(A: dace.float32[M, N], B: dace.float32[M, N]): + assign_top_row(A) + assign_bottom_row(B) + + +def test_free_floating_fusion(): + A = np.random.uniform(size=(4, 5)).astype(np.float32) + B = np.random.uniform(size=(4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = assign_bounary_free_floating.to_sdfg(simplify=True) + g.save(os.path.join('_dacegraphs', 'floating-0.sdfg')) + g.validate() + actual_A = deepcopy(A) + actual_B = deepcopy(B) + g(A=actual_A, B=actual_B, M=4, N=5) + + g.apply_transformations(ConstAssignmentMapFusion) + g.save(os.path.join('_dacegraphs', 'floating-1.sdfg')) + g.validate() + our_A = deepcopy(A) + our_B = deepcopy(B) + g(A=our_A, B=our_B, M=4, N=5) + + # print(our_A) + assert np.allclose(our_A, actual_A) + + if __name__ == '__main__': test_within_state_fusion() test_interstate_fusion() From 5121081d732a0e1ea45f5545b5a2a94a191113fb Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 9 Oct 2024 22:11:42 +0200 Subject: [PATCH 04/29] Handle the fact that loop variable names can be different. --- .../dataflow/const_assignment_fusion.py | 7 ++++++- .../const_assignment_fusion_test.py | 16 ++++++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 24b26dffa5..30e8bb42a0 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -149,6 +149,9 @@ def apply(self, graph: SDFGState, sdfg: SDFG): is_const_assignment, assignments = self.consistent_const_assignment_table(graph, first_entry, first_exit) assert is_const_assignment + # Keep track in case loop variables are named differently. + param_map = {p2: p1 for p1, p2 in zip(first_entry.map.params, second_entry.map.params)} + # Track all the access nodes that will survive the purge. access_nodes, remove_nodes = self.track_access_nodes(graph, first_exit, second_exit) @@ -177,8 +180,10 @@ def apply(self, graph: SDFGState, sdfg: SDFG): graph.add_memlet_path(first_entry, t, memlet=Memlet()) graph.remove_edge(e) for e in graph.out_edges(t): + e_data = e.data + e_data.subset.replace(param_map) graph.add_memlet_path(e.src, first_exit, src_conn=e.src_conn, dst_conn=conn_map[e.dst_conn], - memlet=Memlet(str(e.data))) + memlet=Memlet(str(e_data))) graph.remove_edge(e) # Redirect any outgoing edges from the nodes to be removed through their surviving counterparts. diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index a2c7ecc23d..2ce6dba1f3 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -13,26 +13,26 @@ @dace.program def assign_top_row(A: dace.float32[M, N]): - for i in dace.map[0:N]: - A[0, i] = 1 + for t in dace.map[0:N]: + A[0, t] = 1 @dace.program def assign_bottom_row(A: dace.float32[M, N]): - for i in dace.map[0:N]: - A[M - 1, i] = 1 + for b in dace.map[0:N]: + A[M - 1, b] = 1 @dace.program def assign_left_col(A: dace.float32[M, N]): - for i in dace.map[0:M]: - A[i, 0] = 1 + for l in dace.map[0:M]: + A[l, 0] = 1 @dace.program def assign_right_col(A: dace.float32[M, N]): - for i in dace.map[0:M]: - A[i, N - 1] = 1 + for r in dace.map[0:M]: + A[r, N - 1] = 1 def assign_bounary_sdfg(): From 95ea47ac58b40a52117b4bee6f8b5f88ec386d06 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Thu, 10 Oct 2024 10:30:32 +0200 Subject: [PATCH 05/29] handle if the const assignments are behind an if branch that only depends on the map parameters. --- .../dataflow/const_assignment_fusion.py | 106 +++++++++++++++--- .../const_assignment_fusion_test.py | 37 ++++++ 2 files changed, 130 insertions(+), 13 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 30e8bb42a0..d3036f6fdd 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -5,7 +5,8 @@ from dace import transformation, SDFGState, SDFG, Memlet from dace.sdfg import nodes from dace.sdfg.graph import OrderedDiGraph -from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit +from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit, NestedSDFG, Node +from dace.sdfg.state import ControlFlowBlock from dace.sdfg.utils import node_path_graph from dace.transformation.dataflow import MapFusion from dace.transformation.interstate import StateFusionExtended @@ -40,22 +41,61 @@ def expressions(cls): free_floating_maps(cls.first_map_entry, cls.first_map_exit, cls.second_map_entry, cls.second_map_exit)] @staticmethod - def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> tuple[bool, dict]: + def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: table = {} - for n in graph.all_nodes_between(en, ex): - # Each of the nodes in this map must be... + # Basic premise check. + if not isinstance(graph, NestedSDFG): + return False, table + graph: SDFG = graph.sdfg + if not isinstance(graph, ControlFlowBlock): + return False, table + + # Must have exactly 3 nodes, and exactly one of them a source, another a sink. + src, snk = graph.source_nodes(), graph.sink_nodes() + if len(graph.nodes()) != 3 or len(src) != 1 or len(snk) != 1: + return False, table + src, snk = src[0], snk[0] + body = set(graph.nodes()) - {src, snk} + if len(body) != 1: + return False, table + body = list(body)[0] + + # Must have certain structure of outgoing edges. + src_eds = list(graph.out_edges(src)) + if len(src_eds) != 2 or any(e.data.is_unconditional() or e.data.assignments for e in src_eds): + return False, table + tb, el = src_eds + if tb.dst != body: + tb, el = el, tb + if tb.dst != body or el.dst != snk: + return False, table + body_eds = list(graph.out_edges(body)) + if len(body_eds) != 1 or body_eds[0].dst != snk or not body_eds[0].data.is_unconditional() or body_eds[ + 0].data.assignments: + return False, table + + # Branch conditions must depend only on the loop variables. + for b in [tb, el]: + cond = b.data.condition + for c in cond.code: + used = set([ast_node.id for ast_node in ast.walk(c) if isinstance(ast_node, ast.Name)]) + if not used.issubset(graph.free_symbols): + return False, table + + # Body must have only constant assignments. + for n, _ in body.all_nodes_recursive(): + # Each tasklet in this box... if not isinstance(n, Tasklet): - # ...a tasklet... - return False, table + continue if len(n.code.code) != 1 or not isinstance(n.code.code[0], ast.Assign): - # ...that assigns... + # ...must assign... return False, table op = n.code.code[0] if not isinstance(op.value, ast.Constant) or len(op.targets) != 1: # ...a constant to a single target. return False, table const = op.value.value - for oe in graph.out_edges(n): + for oe in body.out_edges(n): dst = oe.data dst_arr = oe.data.data if dst_arr in table and table[dst_arr] != const: @@ -65,6 +105,46 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi table[dst_arr] = const return True, table + @staticmethod + def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> tuple[bool, dict]: + table = {} + for n in graph.all_nodes_between(en, ex): + if isinstance(n, NestedSDFG): + # First handle the case of conditional constant assignment. + is_branch_const_assignment, internal_table = ConstAssignmentMapFusion.consistent_branch_const_assignment_table(n) + if not is_branch_const_assignment: + return False, table + for oe in graph.out_edges(n): + dst = oe.data + dst_arr = oe.data.data + if dst_arr in table and table[dst_arr] != internal_table[oe.src_conn]: + # A target array can appear multiple times, but it must always be consistently assigned. + return False, table + table[dst] = internal_table[oe.src_conn] + table[dst_arr] = internal_table[oe.src_conn] + else: + # Each of the nodes in this map must be... + if not isinstance(n, Tasklet): + # ...a tasklet... + return False, table + if len(n.code.code) != 1 or not isinstance(n.code.code[0], ast.Assign): + # ...that assigns... + return False, table + op = n.code.code[0] + if not isinstance(op.value, ast.Constant) or len(op.targets) != 1: + # ...a constant to a single target. + return False, table + const = op.value.value + for oe in graph.out_edges(n): + dst = oe.data + dst_arr = oe.data.data + if dst_arr in table and table[dst_arr] != const: + # A target array can appear multiple times, but it must always be consistently assigned. + return False, table + table[dst] = const + table[dst_arr] = const + return True, table + def map_nodes(self, graph: SDFGState): return (graph.entry_node(self.first_map_exit), self.first_map_exit, self.second_map_entry, graph.exit_node(self.second_map_entry)) @@ -164,26 +244,26 @@ def apply(self, graph: SDFGState, sdfg: SDFG): assert array_name in access_nodes if access_nodes[array_name] != e.dst: graph.add_memlet_path(first_exit, access_nodes[array_name], src_conn=e.src_conn, dst_conn=e.dst_conn, - memlet=Memlet(str(e.data))) + memlet=Memlet.from_memlet(e.data)) graph.remove_edge(e) for e in graph.out_edges(second_exit): array_name = e.dst.data assert array_name in access_nodes graph.add_memlet_path(first_exit, access_nodes[array_name], src_conn=conn_map[e.src_conn], - dst_conn=e.dst_conn, memlet=Memlet(str(e.data))) + dst_conn=e.dst_conn, memlet=Memlet.from_memlet(e.data)) graph.remove_edge(e) # Move the tasklets from the second map into the first map. second_tasklets = graph.all_nodes_between(second_entry, second_exit) for t in second_tasklets: for e in graph.in_edges(t): - graph.add_memlet_path(first_entry, t, memlet=Memlet()) + graph.add_memlet_path(first_entry, t, memlet=Memlet.from_memlet(e.data)) graph.remove_edge(e) for e in graph.out_edges(t): e_data = e.data e_data.subset.replace(param_map) graph.add_memlet_path(e.src, first_exit, src_conn=e.src_conn, dst_conn=conn_map[e.dst_conn], - memlet=Memlet(str(e_data))) + memlet=Memlet.from_memlet(e.data)) graph.remove_edge(e) # Redirect any outgoing edges from the nodes to be removed through their surviving counterparts. @@ -191,7 +271,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): for e in graph.out_edges(n): if e.dst != second_entry: alt_n = access_nodes[n.data] - memlet = Memlet(str(e.data)) if not e.data.is_empty() else Memlet() + memlet = Memlet.from_memlet(e.data) graph.add_memlet_path(alt_n, e.dst, src_conn=e.src_conn, dst_conn=e.dst_conn, memlet=memlet) graph.remove_node(n) graph.remove_node(second_entry) diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 2ce6dba1f3..09331077a0 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -16,6 +16,12 @@ def assign_top_row(A: dace.float32[M, N]): for t in dace.map[0:N]: A[0, t] = 1 +@dace.program +def assign_top_row_branched(A: dace.float32[M, N]): + for t, in dace.map[0:N]: + if t % 2 == 0: + A[0, t] = 1 + @dace.program def assign_bottom_row(A: dace.float32[M, N]): @@ -144,6 +150,37 @@ def test_free_floating_fusion(): assert np.allclose(our_A, actual_A) +@dace.program +def assign_bounary_with_branch(A: dace.float32[M, N], B: dace.float32[M, N]): + assign_top_row_branched(A) + assign_bottom_row(B) + + +def test_fusion_with_branch(): + A = np.random.uniform(size=(4, 5)).astype(np.float32) + B = np.random.uniform(size=(4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = assign_bounary_with_branch.to_sdfg(simplify=True) + g.save(os.path.join('_dacegraphs', 'branched-0.sdfg')) + g.validate() + actual_A = deepcopy(A) + actual_B = deepcopy(B) + g(A=actual_A, B=actual_B, M=4, N=5) + + g.apply_transformations(ConstAssignmentMapFusion) + g.save(os.path.join('_dacegraphs', 'branched-1.sdfg')) + g.validate() + our_A = deepcopy(A) + our_B = deepcopy(B) + g(A=our_A, B=our_B, M=4, N=5) + + # print(our_A) + assert np.allclose(our_A, actual_A) + + if __name__ == '__main__': test_within_state_fusion() test_interstate_fusion() + test_free_floating_fusion() + test_fusion_with_branch() From e9be112ecaa33a8e09b64f0512895d5d7dde9fd9 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Thu, 10 Oct 2024 11:25:35 +0200 Subject: [PATCH 06/29] small changes and documentation --- .../dataflow/const_assignment_fusion.py | 71 +++++++++++++++---- 1 file changed, 57 insertions(+), 14 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index d3036f6fdd..e7dabfcc17 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -21,7 +21,7 @@ def unique_map_node(graph: SDFGState) -> Optional[tuple[MapEntry, MapExit]]: return en[0], ex[0] -def free_floating_maps(*args): +def floating_nodes_graph(*args): g = OrderedDiGraph() for n in args: g.add_node(n) @@ -29,6 +29,21 @@ def free_floating_maps(*args): class ConstAssignmentMapFusion(MapFusion): + """ + Fuses two maps within a state, where each map: + 1. Either assigns consistent constant values to elements of one or more data arrays. + - Consisency: The values must be the same for all elements in a data array (in both maps). But different data + arrays are allowed to have different values. + 2. Or assigns constant values as described earlier, but _conditionally_. The condition must only depend on the map + Parameters. + + Further conditions: + 1. Range compatibility: The two map must have the exact same range. + # TODO(pratyai): Generalize this in `compatible_range()`. + 2. The maps must have one of the following patterns. + - Exists a path like: MapExit -> AccessNode -> MapEntry + - Neither map is dependent on the other. I.e. There is no dependency path between them. + """ first_map_entry = transformation.PatternNode(nodes.EntryNode) first_map_exit = transformation.PatternNode(nodes.ExitNode) array = transformation.PatternNode(nodes.AccessNode) @@ -38,10 +53,15 @@ class ConstAssignmentMapFusion(MapFusion): @classmethod def expressions(cls): return [node_path_graph(cls.first_map_exit, cls.array, cls.second_map_entry), - free_floating_maps(cls.first_map_entry, cls.first_map_exit, cls.second_map_entry, cls.second_map_exit)] + floating_nodes_graph(cls.first_map_entry, cls.first_map_exit, cls.second_map_entry, + cls.second_map_exit)] @staticmethod def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: + """ + If the graph consists of only conditional consistent constant assignments, produces a table mapping data arrays + and memlets to their consistent constant assignments. See the class docstring for what is considered consistent. + """ table = {} # Basic premise check. if not isinstance(graph, NestedSDFG): @@ -107,6 +127,11 @@ def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: @staticmethod def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> tuple[bool, dict]: + """ + If the graph consists of only (conditional or unconditional) consistent constant assignments, produces a table + mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is + considered consistent. + """ table = {} for n in graph.all_nodes_between(en, ex): if isinstance(n, NestedSDFG): @@ -146,23 +171,26 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi return True, table def map_nodes(self, graph: SDFGState): + """Return the entry and exit nodes of the relevant maps as a tuple: entry_1, exit_1, entry_2, exit_2.""" return (graph.entry_node(self.first_map_exit), self.first_map_exit, self.second_map_entry, graph.exit_node(self.second_map_entry)) @staticmethod - def compatible_range(first_entry: MapEntry, second_entry: MapEntry): - # TODO(pratyai): Make a better check for map compatibility. - if first_entry.map.range != second_entry.map.range or first_entry.map.schedule != second_entry.map.schedule: + def compatible_range(first_entry: MapEntry, second_entry: MapEntry) -> bool: + """Decide if the two ranges are compatible. See the class docstring for what is considered compatible.""" + if first_entry.map.schedule != second_entry.map.schedule: + # If the two maps are not to be scheduled on the same device, don't fuse them. + return False + if first_entry.map.range != second_entry.map.range: # TODO(pratyai): Make it so that a permutation of the ranges, or even an union of the ranges will work. return False return True - def can_be_applied_free_floating(self, graph: SDFGState, sdfg: SDFG, permissive: bool = False) -> bool: + def no_dependency_pattern(self, graph: SDFGState, sdfg: SDFG, permissive: bool = False) -> bool: + """Decide if the two maps are independent of each other.""" first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) if first_entry != self.first_map_entry or second_exit != self.second_map_exit: return False - if not self.compatible_range(first_entry, second_entry): - return False if graph.all_nodes_between(first_exit, second_entry) or graph.all_nodes_between(second_exit, first_entry): return False return True @@ -170,8 +198,10 @@ def can_be_applied_free_floating(self, graph: SDFGState, sdfg: SDFG, permissive: def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: assert expr_index in (0, 1) if expr_index == 1: - if not self.can_be_applied_free_floating(graph, sdfg, permissive): + # Test the rest of the second pattern in the `expressions()`. + if not self.no_dependency_pattern(graph, sdfg, permissive): return False + first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) if not self.compatible_range(first_entry, second_entry): return False @@ -191,8 +221,12 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi return True @staticmethod - def track_access_nodes(graph: SDFGState, first_exit: ExitNode, second_exit: ExitNode): - # Track all the access nodes that will survive the purge. + def track_access_nodes(graph: SDFGState, first_exit: ExitNode, second_exit: ExitNode) -> tuple[dict, set]: + """ + Track all the access nodes that will survive after cleaning up duplicates. Returns a tuple with: + 1. A map: the underlying data array -> the surviving access node. + 2. A set of access nodes to be removed, i.e. has a duplicate in the map described earlier. + """ access_nodes, remove_nodes = {}, set() dst_nodes = set(e.dst for e in chain(graph.out_edges(first_exit), graph.out_edges(second_exit))) for n in dst_nodes: @@ -206,8 +240,11 @@ def track_access_nodes(graph: SDFGState, first_exit: ExitNode, second_exit: Exit return access_nodes, remove_nodes @staticmethod - def make_equivalent_connections(first_exit: ExitNode, second_exit: ExitNode): - # Set up the extra connections on the first node. + def make_equivalent_connectors(first_exit: ExitNode, second_exit: ExitNode): + """ + Create the additional connectors in the first exit node that matches the second exit node (which will be removed + later). + """ conn_map = {} for c, v in second_exit.in_connectors.items(): assert c.startswith('IN_') @@ -236,7 +273,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): access_nodes, remove_nodes = self.track_access_nodes(graph, first_exit, second_exit) # Set up the extra connections on the first node. - conn_map = self.make_equivalent_connections(first_exit, second_exit) + conn_map = self.make_equivalent_connectors(first_exit, second_exit) # Redirect outgoing edges from exit nodes that are going to be invalidated. for e in graph.out_edges(first_exit): @@ -279,6 +316,12 @@ def apply(self, graph: SDFGState, sdfg: SDFG): class ConstAssignmentStateFusion(StateFusionExtended): + """ + If two consecutive states are such that + 1. Each state has just one _constant assigment map_ (see the docstring of `ConstAssignmentMapFusion`). + 2. If those two maps were in the same state `ConstAssignmentMapFusion` would fuse them. + then fuse the two states. + """ first_state = transformation.PatternNode(SDFGState) second_state = transformation.PatternNode(SDFGState) From 4bc5d1411f9ec51905b79cfe1285361b68a64698 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Thu, 10 Oct 2024 14:55:41 +0200 Subject: [PATCH 07/29] Don't use cache in tests. --- .../transformations/const_assignment_fusion_test.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 09331077a0..97f462a9b4 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -16,6 +16,7 @@ def assign_top_row(A: dace.float32[M, N]): for t in dace.map[0:N]: A[0, t] = 1 + @dace.program def assign_top_row_branched(A: dace.float32[M, N]): for t, in dace.map[0:N]: @@ -42,18 +43,18 @@ def assign_right_col(A: dace.float32[M, N]): def assign_bounary_sdfg(): - st0 = assign_top_row.to_sdfg(simplify=True, validate=True) + st0 = assign_top_row.to_sdfg(simplify=True, validate=True, use_cache=False) st0.start_block.label = 'st0' - st1 = assign_bottom_row.to_sdfg(simplify=True, validate=True) + st1 = assign_bottom_row.to_sdfg(simplify=True, validate=True, use_cache=False) st1.start_block.label = 'st1' st0.add_edge(st0.start_state, st1.start_state, dace.InterstateEdge()) - st2 = assign_left_col.to_sdfg(simplify=True, validate=True) + st2 = assign_left_col.to_sdfg(simplify=True, validate=True, use_cache=False) st2.start_block.label = 'st2' st0.add_edge(st1.start_state, st2.start_state, dace.InterstateEdge()) - st3 = assign_right_col.to_sdfg(simplify=True, validate=True) + st3 = assign_right_col.to_sdfg(simplify=True, validate=True, use_cache=False) st3.start_block.label = 'st3' st0.add_edge(st2.start_state, st3.start_state, dace.InterstateEdge()) @@ -132,7 +133,7 @@ def test_free_floating_fusion(): B = np.random.uniform(size=(4, 5)).astype(np.float32) # Construct SDFG with the maps on separate states. - g = assign_bounary_free_floating.to_sdfg(simplify=True) + g = assign_bounary_free_floating.to_sdfg(simplify=True, validate=True, use_cache=False) g.save(os.path.join('_dacegraphs', 'floating-0.sdfg')) g.validate() actual_A = deepcopy(A) @@ -161,7 +162,7 @@ def test_fusion_with_branch(): B = np.random.uniform(size=(4, 5)).astype(np.float32) # Construct SDFG with the maps on separate states. - g = assign_bounary_with_branch.to_sdfg(simplify=True) + g = assign_bounary_with_branch.to_sdfg(simplify=True, validate=True, use_cache=False) g.save(os.path.join('_dacegraphs', 'branched-0.sdfg')) g.validate() actual_A = deepcopy(A) From c8206b1281f372b47fe02644c6d9d019b1551adc Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Thu, 10 Oct 2024 23:38:43 +0200 Subject: [PATCH 08/29] Fix the "no dependency" pattern + a small refactor. --- .../dataflow/const_assignment_fusion.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index e7dabfcc17..226f0ff361 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -1,11 +1,12 @@ import ast +from collections import defaultdict from itertools import chain -from typing import Optional +from typing import Optional, Union -from dace import transformation, SDFGState, SDFG, Memlet +from dace import transformation, SDFGState, SDFG, Memlet, subsets from dace.sdfg import nodes from dace.sdfg.graph import OrderedDiGraph -from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit, NestedSDFG, Node +from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit, NestedSDFG, Node, EntryNode, AccessNode from dace.sdfg.state import ControlFlowBlock from dace.sdfg.utils import node_path_graph from dace.transformation.dataflow import MapFusion @@ -191,7 +192,7 @@ def no_dependency_pattern(self, graph: SDFGState, sdfg: SDFG, permissive: bool = first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) if first_entry != self.first_map_entry or second_exit != self.second_map_exit: return False - if graph.all_nodes_between(first_exit, second_entry) or graph.all_nodes_between(second_exit, first_entry): + if graph.in_edges(first_entry) or graph.in_edges(second_entry): return False return True @@ -228,7 +229,8 @@ def track_access_nodes(graph: SDFGState, first_exit: ExitNode, second_exit: Exit 2. A set of access nodes to be removed, i.e. has a duplicate in the map described earlier. """ access_nodes, remove_nodes = {}, set() - dst_nodes = set(e.dst for e in chain(graph.out_edges(first_exit), graph.out_edges(second_exit))) + dst_nodes = set(e.dst for e in chain(graph.out_edges(first_exit), graph.out_edges(second_exit)) + if isinstance(e.dst, AccessNode)) for n in dst_nodes: if n.data in access_nodes: remove_nodes.add(n) @@ -240,21 +242,21 @@ def track_access_nodes(graph: SDFGState, first_exit: ExitNode, second_exit: Exit return access_nodes, remove_nodes @staticmethod - def make_equivalent_connectors(first_exit: ExitNode, second_exit: ExitNode): + def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]): """ Create the additional connectors in the first exit node that matches the second exit node (which will be removed later). """ - conn_map = {} - for c, v in second_exit.in_connectors.items(): + conn_map = defaultdict() + for c, v in src.in_connectors.items(): assert c.startswith('IN_') cbase = c.removeprefix('IN_') - sc = first_exit.next_connector(cbase) + sc = dst.next_connector(cbase) conn_map[f"IN_{cbase}"] = f"IN_{sc}" conn_map[f"OUT_{cbase}"] = f"OUT_{sc}" - first_exit.add_in_connector(f"IN_{sc}", dtype=v) - first_exit.add_out_connector(f"OUT_{sc}", dtype=v) - for c, v in second_exit.out_connectors.items(): + dst.add_in_connector(f"IN_{sc}", dtype=v) + dst.add_out_connector(f"OUT_{sc}", dtype=v) + for c, v in src.out_connectors.items(): assert c in conn_map return conn_map @@ -273,7 +275,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): access_nodes, remove_nodes = self.track_access_nodes(graph, first_exit, second_exit) # Set up the extra connections on the first node. - conn_map = self.make_equivalent_connectors(first_exit, second_exit) + conn_map = self.add_equivalent_connectors(first_exit, second_exit) # Redirect outgoing edges from exit nodes that are going to be invalidated. for e in graph.out_edges(first_exit): From 378673ae33496c88faa95e529f4833791157db58 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Fri, 11 Oct 2024 13:35:10 +0200 Subject: [PATCH 09/29] if the ranges are not exactly the same, still try to fuse by adding a grid-stride loop as a guard. --- .../dataflow/const_assignment_fusion.py | 300 +++++++++++++----- .../const_assignment_fusion_test.py | 43 +++ 2 files changed, 270 insertions(+), 73 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 226f0ff361..b142a1fc5c 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -1,5 +1,6 @@ import ast from collections import defaultdict +from copy import deepcopy from itertools import chain from typing import Optional, Union @@ -7,7 +8,7 @@ from dace.sdfg import nodes from dace.sdfg.graph import OrderedDiGraph from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit, NestedSDFG, Node, EntryNode, AccessNode -from dace.sdfg.state import ControlFlowBlock +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion from dace.sdfg.utils import node_path_graph from dace.transformation.dataflow import MapFusion from dace.transformation.interstate import StateFusionExtended @@ -53,6 +54,8 @@ class ConstAssignmentMapFusion(MapFusion): @classmethod def expressions(cls): + # TODO(pratyai): Probably a better pattern idea: take any two maps, then check that _every_ path from the first + # map to second map has exactly one access node in the middle and the second edge of the path is empty. return [node_path_graph(cls.first_map_exit, cls.array, cls.second_map_entry), floating_nodes_graph(cls.first_map_entry, cls.first_map_exit, cls.second_map_entry, cls.second_map_exit)] @@ -137,7 +140,8 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi for n in graph.all_nodes_between(en, ex): if isinstance(n, NestedSDFG): # First handle the case of conditional constant assignment. - is_branch_const_assignment, internal_table = ConstAssignmentMapFusion.consistent_branch_const_assignment_table(n) + is_branch_const_assignment, internal_table = ConstAssignmentMapFusion.consistent_branch_const_assignment_table( + n) if not is_branch_const_assignment: return False, table for oe in graph.out_edges(n): @@ -182,7 +186,8 @@ def compatible_range(first_entry: MapEntry, second_entry: MapEntry) -> bool: if first_entry.map.schedule != second_entry.map.schedule: # If the two maps are not to be scheduled on the same device, don't fuse them. return False - if first_entry.map.range != second_entry.map.range: + if len(first_entry.map.range) != len(second_entry.map.range): + # If it's not even possible to take component-wise union of the two map's range, don't fuse them. # TODO(pratyai): Make it so that a permutation of the ranges, or even an union of the ranges will work. return False return True @@ -192,7 +197,11 @@ def no_dependency_pattern(self, graph: SDFGState, sdfg: SDFG, permissive: bool = first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) if first_entry != self.first_map_entry or second_exit != self.second_map_exit: return False - if graph.in_edges(first_entry) or graph.in_edges(second_entry): + if any(not e.data.is_empty() + for e in chain(graph.in_edges(first_entry), graph.in_edges(second_entry))): + return False + if any(not isinstance(e.src, AccessNode) + for e in chain(graph.in_edges(first_entry), graph.in_edges(second_entry))): return False return True @@ -221,26 +230,6 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi assignments[k] = v return True - @staticmethod - def track_access_nodes(graph: SDFGState, first_exit: ExitNode, second_exit: ExitNode) -> tuple[dict, set]: - """ - Track all the access nodes that will survive after cleaning up duplicates. Returns a tuple with: - 1. A map: the underlying data array -> the surviving access node. - 2. A set of access nodes to be removed, i.e. has a duplicate in the map described earlier. - """ - access_nodes, remove_nodes = {}, set() - dst_nodes = set(e.dst for e in chain(graph.out_edges(first_exit), graph.out_edges(second_exit)) - if isinstance(e.dst, AccessNode)) - for n in dst_nodes: - if n.data in access_nodes: - remove_nodes.add(n) - else: - access_nodes[n.data] = n - for n in remove_nodes: - assert n.data in access_nodes - assert access_nodes[n.data] != n - return access_nodes, remove_nodes - @staticmethod def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]): """ @@ -260,6 +249,187 @@ def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryN assert c in conn_map return conn_map + @staticmethod + def connector_counterpart(c: Union[str, None]) -> Union[str, None]: + """If it's an input connector, find the corresponding output connector, and vice versa.""" + if c is None: + return None + assert isinstance(c, str) + if c.startswith('IN_'): + return f"OUT_{c.removeprefix('IN_')}" + elif c.startswith('OUT_'): + return f"IN_{c.removeprefix('OUT_')}" + return None + + @staticmethod + def consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, second_entry: MapEntry): + """ + Remove all the incoming edges of the two maps and add empty edges from the union of the access nodes they + depended on before. + + Preconditions: + 1. All the incoming edges of the two maps must be from an access node and empty (i.e. have existed + only for synchronization). + 2. The two maps must be constistent const assignments (see the class docstring for what is considered + consistent). + """ + # First, construct a table of the dependencies. + table = {} + for en in [first_entry, second_entry]: + for e in graph.in_edges(en): + assert isinstance(e.src, AccessNode) + assert e.data.is_empty() + assert e.src_conn is None and e.dst_conn is None + if e.src.data not in table: + table[e.src.data] = e.src + elif table[e.src.data] in graph.bfs_nodes(e.src): + # If this copy of the node is above the copy we've seen before, use this one instead. + table[e.src.data] = e.src + graph.remove_edge(e) + # Then, if we still have so that any of the map _writes_ to these nodes, we want to just create fresh copies to + # avoid cycles. + alt_table = {} + for k, v in table.items(): + if v in graph.bfs_nodes(first_entry) or v in graph.bfs_nodes(second_entry): + alt_v = deepcopy(v) + graph.add_node(alt_v) + alt_table[k] = alt_v + else: + alt_table[k] = v + # Finally, these nodes should be depended on by _both_ maps. + for en in [first_entry, second_entry]: + for n in alt_table.values(): + graph.add_memlet_path(n, en, memlet=Memlet()) + + @staticmethod + def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit: MapExit): + """ + If the two maps write to the same underlying data array through two access nodes, replace those edges' + destination with a single shared copy. + + Precondition: + 1. The two maps must not depend on each other through an access node (which should be taken care of already by + `consolidate_empty_dependencies()`. + 2. The two maps must be constistent const assignments (see the class docstring for what is considered + consistent). + """ + # First, construct tables of the surviving and all written access nodes. + surviving_nodes, all_written_nodes = {}, set() + for ex in [first_exit, second_exit]: + for e in graph.out_edges(ex): + assert isinstance(e.dst, AccessNode) + assert not e.data.is_empty() + assert e.src_conn is not None and e.dst_conn is None + all_written_nodes.add(e.dst) + if e.dst.data not in surviving_nodes: + surviving_nodes[e.dst.data] = e.dst + elif e.dst in graph.bfs_nodes(surviving_nodes[e.dst.data]): + # If this copy of the node is above the copy we've seen before, use this one instead. + surviving_nodes[e.dst.data] = e.dst + # Then, redirect all the edges toward the surviving copies of the destination access nodes. + for n in all_written_nodes: + for e in graph.in_edges(n): + assert e.src in [first_exit, second_exit] + assert e.dst_conn is None + graph.add_memlet_path(e.src, surviving_nodes[e.dst.data], + src_conn=e.src_conn, dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + for e in graph.out_edges(n): + assert e.src_conn is None + graph.add_memlet_path(surviving_nodes[e.src.data], e.dst, + src_conn=e.src_conn, dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + # Finally, cleanup the orphan nodes. + for n in all_written_nodes: + if graph.degree(n) == 0: + graph.remove_node(n) + + @staticmethod + def consume_map_exactly(graph: SDFGState, dst: tuple[MapEntry, MapExit], src: tuple[MapEntry, MapExit]): + """ + Transfer the entirety of `src` map's body into `dst` map. Only possible when the two maps' ranges are identical. + """ + dst_en, dst_ex = dst + src_en, src_ex = src + + assert all(e.data.is_empty() for e in graph.in_edges(src_en)) + cmap = ConstAssignmentMapFusion.add_equivalent_connectors(dst_en, src_en) + for e in graph.in_edges(src_en): + graph.add_memlet_path(e.src, dst_en, + src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + for e in graph.out_edges(src_en): + graph.add_memlet_path(dst_en, e.dst, + src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + + cmap = ConstAssignmentMapFusion.add_equivalent_connectors(dst_ex, src_ex) + for e in graph.in_edges(src_ex): + graph.add_memlet_path(e.src, dst_ex, + src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + for e in graph.out_edges(src_ex): + graph.add_memlet_path(dst_ex, e.dst, + src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + + graph.remove_node(src_en) + graph.remove_node(src_ex) + + @staticmethod + def consume_map_with_grid_strided_loop(graph: SDFGState, dst: tuple[MapEntry, MapExit], + src: tuple[MapEntry, MapExit]): + """ + Transfer the entirety of `src` map's body into `dst` map, guarded behind a _grid-strided_ loop. + Prerequisite: `dst` map's range must cover `src` map's range in entirety. Statically checking this may not + always be possible. + """ + dst_en, dst_ex = dst + src_en, src_ex = src + + def with_start_and_stride(r, start, stride): + r = list(r) + r[0] = start + r[2] = stride + return tuple(r) + + gsl_ranges = [with_start_and_stride(rd, p, rs[1] + 1) + for p, rs, rd in zip(dst_en.map.params, src_en.map.range.ranges, dst_en.map.range.ranges)] + gsl_params = [f"gsl_{p}" for p in dst_en.map.params] + en, ex = graph.add_map(graph.sdfg._find_new_name('gsl'), + {k: v for k, v in zip(gsl_params, gsl_ranges)}) + # graph.add_memlet_path(dst_en, en, memlet=Memlet()) + ConstAssignmentMapFusion.consume_map_exactly(graph, (en, ex), src) + # graph.add_memlet_path(ex, dst_ex, memlet=Memlet()) + + assert all(e.data.is_empty() for e in graph.in_edges(en)) + cmap = ConstAssignmentMapFusion.add_equivalent_connectors(dst_en, en) + for e in graph.in_edges(en): + graph.add_memlet_path(e.src, dst_en, + src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), + memlet=Memlet.from_memlet(e.data)) + graph.add_memlet_path(dst_en, e.dst, + src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + + cmap = ConstAssignmentMapFusion.add_equivalent_connectors(dst_ex, ex) + for e in graph.out_edges(ex): + graph.add_memlet_path(e.src, dst_ex, + src_conn=e.src_conn, + dst_conn=ConstAssignmentMapFusion.connector_counterpart(cmap.get(e.src_conn)), + memlet=Memlet.from_memlet(e.data)) + graph.add_memlet_path(dst_ex, e.dst, + src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + def apply(self, graph: SDFGState, sdfg: SDFG): first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) @@ -268,53 +438,29 @@ def apply(self, graph: SDFGState, sdfg: SDFG): is_const_assignment, assignments = self.consistent_const_assignment_table(graph, first_entry, first_exit) assert is_const_assignment - # Keep track in case loop variables are named differently. + # Rename in case loop variables are named differently. param_map = {p2: p1 for p1, p2 in zip(first_entry.map.params, second_entry.map.params)} - - # Track all the access nodes that will survive the purge. - access_nodes, remove_nodes = self.track_access_nodes(graph, first_exit, second_exit) - - # Set up the extra connections on the first node. - conn_map = self.add_equivalent_connectors(first_exit, second_exit) - - # Redirect outgoing edges from exit nodes that are going to be invalidated. - for e in graph.out_edges(first_exit): - array_name = e.dst.data - assert array_name in access_nodes - if access_nodes[array_name] != e.dst: - graph.add_memlet_path(first_exit, access_nodes[array_name], src_conn=e.src_conn, dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) - for e in graph.out_edges(second_exit): - array_name = e.dst.data - assert array_name in access_nodes - graph.add_memlet_path(first_exit, access_nodes[array_name], src_conn=conn_map[e.src_conn], - dst_conn=e.dst_conn, memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) - - # Move the tasklets from the second map into the first map. - second_tasklets = graph.all_nodes_between(second_entry, second_exit) - for t in second_tasklets: - for e in graph.in_edges(t): - graph.add_memlet_path(first_entry, t, memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) + for t in graph.all_nodes_between(second_entry, second_exit): for e in graph.out_edges(t): - e_data = e.data - e_data.subset.replace(param_map) - graph.add_memlet_path(e.src, first_exit, src_conn=e.src_conn, dst_conn=conn_map[e.dst_conn], - memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) - - # Redirect any outgoing edges from the nodes to be removed through their surviving counterparts. - for n in remove_nodes: - for e in graph.out_edges(n): - if e.dst != second_entry: - alt_n = access_nodes[n.data] - memlet = Memlet.from_memlet(e.data) - graph.add_memlet_path(alt_n, e.dst, src_conn=e.src_conn, dst_conn=e.dst_conn, memlet=memlet) - graph.remove_node(n) - graph.remove_node(second_entry) - graph.remove_node(second_exit) + e.data.subset.replace(param_map) + second_entry.map.params = first_entry.map.params + + # Consolidate the incoming dependencies of the two maps. + self.consolidate_empty_dependencies(graph, first_entry, second_entry) + + # Consolidate the written access nodes of the two maps. + self.consolidate_written_nodes(graph, first_exit, second_exit) + + # If the ranges are identical, then simply fuse the two maps. Otherwise, use grid-strided loops. + en, ex = graph.add_map(sdfg._find_new_name('map_fusion_wrapper'), + {k: v for k, v in zip(first_entry.map.params, + subsets.union(first_entry.map.range, second_entry.map.range))}, + schedule=first_entry.map.schedule) + for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: + if en.map.range.covers(cur_en.map.range): + self.consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) + else: + self.consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex)) class ConstAssignmentStateFusion(StateFusionExtended): @@ -329,23 +475,30 @@ class ConstAssignmentStateFusion(StateFusionExtended): # NOTE: `expression()` is inherited. - def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: # All the basic rules apply. if not super().can_be_applied(graph, expr_index, sdfg, permissive): return False st0, st1 = self.first_state, self.second_state - # Moreover, each state must contain just one constant assignment map. + # Moreover, the states together must contain a consistent constant assignment map. + assignments = {} for st in [st0, st1]: en_ex = unique_map_node(st) if not en_ex: return False en, ex = en_ex - if len(st.in_edges(en)) != 0: + if any(not e.data.is_empty for e in st.in_edges(en)): return False - is_const_assignment, assignments = ConstAssignmentMapFusion.consistent_const_assignment_table(st, en, ex) + is_const_assignment, further_assignments = ConstAssignmentMapFusion.consistent_const_assignment_table(st, + en, + ex) if not is_const_assignment: return False + for k, v in further_assignments.items(): + if k in assignments and v != assignments[k]: + return False + assignments[k] = v # Moreover, both states' ranges must be compatible. if not ConstAssignmentMapFusion.compatible_range(unique_map_node(st0)[0], unique_map_node(st1)[0]): @@ -357,5 +510,6 @@ def apply(self, graph: SDFGState, sdfg: SDFG): # First, fuse the two states. super().apply(graph, sdfg) sdfg.validate() + # Then, fuse the maps inside. sdfg.apply_transformations_repeated(ConstAssignmentMapFusion) sdfg.validate() diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 97f462a9b4..dba0128ae2 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -7,6 +7,7 @@ from dace.transformation.dataflow.const_assignment_fusion import ConstAssignmentMapFusion, ConstAssignmentStateFusion from dace.transformation.interstate import StateFusionExtended +K = dace.symbol('K') M = dace.symbol('M') N = dace.symbol('N') @@ -151,6 +152,47 @@ def test_free_floating_fusion(): assert np.allclose(our_A, actual_A) +@dace.program +def assign_top_face(A: dace.float32[K, M, N]): + for t1, t2 in dace.map[0:M, 0:N]: + A[0, t1, t2] = 1 + + +@dace.program +def assign_bottom_face(A: dace.float32[K, M, N]): + for t1, t2 in dace.map[0:M, 0:N]: + A[K - 1, t1, t2] = 1 + + +@dace.program +def assign_bounary_3d(A: dace.float32[K, M, N], B: dace.float32[K, M, N]): + assign_top_face(A) + assign_bottom_face(B) + + +def test_fusion_with_multiple_indices(): + A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + B = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = assign_bounary_3d.to_sdfg(simplify=True, validate=True, use_cache=False) + g.save(os.path.join('_dacegraphs', '3d-0.sdfg')) + g.validate() + actual_A = deepcopy(A) + actual_B = deepcopy(B) + g(A=actual_A, B=actual_B, K=3, M=4, N=5) + + g.apply_transformations(ConstAssignmentMapFusion) + g.save(os.path.join('_dacegraphs', '3d-1.sdfg')) + g.validate() + our_A = deepcopy(A) + our_B = deepcopy(B) + g(A=our_A, B=our_B, K=3, M=4, N=5) + + # print(our_A) + assert np.allclose(our_A, actual_A) + + @dace.program def assign_bounary_with_branch(A: dace.float32[M, N], B: dace.float32[M, N]): assign_top_row_branched(A) @@ -185,3 +227,4 @@ def test_fusion_with_branch(): test_interstate_fusion() test_free_floating_fusion() test_fusion_with_branch() + test_fusion_with_multiple_indices() From 06cfcab6b194cce394b22f85ea6f64d163b1fe7b Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Fri, 11 Oct 2024 14:41:53 +0200 Subject: [PATCH 10/29] Fix the bound for grid strided loop. --- dace/transformation/dataflow/const_assignment_fusion.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index b142a1fc5c..2268b14216 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -393,13 +393,14 @@ def consume_map_with_grid_strided_loop(graph: SDFGState, dst: tuple[MapEntry, Ma dst_en, dst_ex = dst src_en, src_ex = src - def with_start_and_stride(r, start, stride): + def range_for_grid_stride(r, val, bound): r = list(r) - r[0] = start - r[2] = stride + r[0] = val + r[1] = bound - 1 + r[2] = bound return tuple(r) - gsl_ranges = [with_start_and_stride(rd, p, rs[1] + 1) + gsl_ranges = [range_for_grid_stride(rd, p, rs[1] + 1) for p, rs, rd in zip(dst_en.map.params, src_en.map.range.ranges, dst_en.map.range.ranges)] gsl_params = [f"gsl_{p}" for p in dst_en.map.params] en, ex = graph.add_map(graph.sdfg._find_new_name('gsl'), From 39ee8439c0b09fbe5f838f43f118bc3e27135ee9 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Fri, 11 Oct 2024 15:08:07 +0200 Subject: [PATCH 11/29] Set the schedule type to the inner GSLs to sequential. --- dace/transformation/dataflow/const_assignment_fusion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 2268b14216..e24bd4d753 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -4,7 +4,7 @@ from itertools import chain from typing import Optional, Union -from dace import transformation, SDFGState, SDFG, Memlet, subsets +from dace import transformation, SDFGState, SDFG, Memlet, subsets, ScheduleType from dace.sdfg import nodes from dace.sdfg.graph import OrderedDiGraph from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit, NestedSDFG, Node, EntryNode, AccessNode @@ -404,7 +404,8 @@ def range_for_grid_stride(r, val, bound): for p, rs, rd in zip(dst_en.map.params, src_en.map.range.ranges, dst_en.map.range.ranges)] gsl_params = [f"gsl_{p}" for p in dst_en.map.params] en, ex = graph.add_map(graph.sdfg._find_new_name('gsl'), - {k: v for k, v in zip(gsl_params, gsl_ranges)}) + {k: v for k, v in zip(gsl_params, gsl_ranges)}, + schedule=ScheduleType.Sequential) # graph.add_memlet_path(dst_en, en, memlet=Memlet()) ConstAssignmentMapFusion.consume_map_exactly(graph, (en, ex), src) # graph.add_memlet_path(ex, dst_ex, memlet=Memlet()) From d2e6f9f3cba70c71d3521305129a92775558362a Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Fri, 11 Oct 2024 15:58:51 +0200 Subject: [PATCH 12/29] Allow multiple nested maps, by starting to look at only the outermost one. --- .../dataflow/const_assignment_fusion.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index e24bd4d753..ddca7dd241 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -14,10 +14,12 @@ from dace.transformation.interstate import StateFusionExtended -def unique_map_node(graph: SDFGState) -> Optional[tuple[MapEntry, MapExit]]: - all_nodes = list(graph.all_nodes_recursive()) - en: list[MapEntry] = [n for n, _ in all_nodes if isinstance(n, MapEntry)] - ex = [n for n, _ in all_nodes if isinstance(n, MapExit)] +def unique_top_level_map_node(graph: SDFGState) -> Optional[tuple[MapEntry, MapExit]]: + all_top_nodes = [n for n, s in graph.scope_dict().items() if s is None] + if not all(isinstance(n, (MapEntry, AccessNode)) for n in all_top_nodes): + return None + en: list[MapEntry] = [n for n in all_top_nodes if isinstance(n, MapEntry)] + ex: list[MapExit] = [graph.exit_node(n) for n in all_top_nodes if isinstance(n, MapEntry)] if len(en) != 1 or len(ex) != 1: return None return en[0], ex[0] @@ -152,6 +154,19 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi return False, table table[dst] = internal_table[oe.src_conn] table[dst_arr] = internal_table[oe.src_conn] + elif isinstance(n, MapEntry): + is_const_assignment, internal_table = ConstAssignmentMapFusion.consistent_const_assignment_table(graph, + n, + graph.exit_node( + n)) + if not is_const_assignment: + return False, table + for k, v in internal_table.items(): + if k in table and v != table[k]: + return False, table + internal_table[k] = v + elif isinstance(n, MapExit): + pass # Handled with `MapEntry` else: # Each of the nodes in this map must be... if not isinstance(n, Tasklet): @@ -486,7 +501,7 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, # Moreover, the states together must contain a consistent constant assignment map. assignments = {} for st in [st0, st1]: - en_ex = unique_map_node(st) + en_ex = unique_top_level_map_node(st) if not en_ex: return False en, ex = en_ex @@ -503,7 +518,8 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, assignments[k] = v # Moreover, both states' ranges must be compatible. - if not ConstAssignmentMapFusion.compatible_range(unique_map_node(st0)[0], unique_map_node(st1)[0]): + if not ConstAssignmentMapFusion.compatible_range(unique_top_level_map_node(st0)[0], + unique_top_level_map_node(st1)[0]): return False return True From 58b291f68e183935e78baeb497994195af1efed0 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Fri, 11 Oct 2024 17:02:32 +0200 Subject: [PATCH 13/29] Rewrite the fusion pattern to allow fusing nested maps if possible. --- .../dataflow/const_assignment_fusion.py | 64 +++++++++++++------ 1 file changed, 43 insertions(+), 21 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index ddca7dd241..7f0baf26aa 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -9,7 +9,6 @@ from dace.sdfg.graph import OrderedDiGraph from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit, NestedSDFG, Node, EntryNode, AccessNode from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion -from dace.sdfg.utils import node_path_graph from dace.transformation.dataflow import MapFusion from dace.transformation.interstate import StateFusionExtended @@ -49,18 +48,13 @@ class ConstAssignmentMapFusion(MapFusion): - Neither map is dependent on the other. I.e. There is no dependency path between them. """ first_map_entry = transformation.PatternNode(nodes.EntryNode) - first_map_exit = transformation.PatternNode(nodes.ExitNode) - array = transformation.PatternNode(nodes.AccessNode) second_map_entry = transformation.PatternNode(nodes.EntryNode) - second_map_exit = transformation.PatternNode(nodes.ExitNode) @classmethod def expressions(cls): - # TODO(pratyai): Probably a better pattern idea: take any two maps, then check that _every_ path from the first - # map to second map has exactly one access node in the middle and the second edge of the path is empty. - return [node_path_graph(cls.first_map_exit, cls.array, cls.second_map_entry), - floating_nodes_graph(cls.first_map_entry, cls.first_map_exit, cls.second_map_entry, - cls.second_map_exit)] + # Take any two maps, then check that _every_ path from the first map to second map has exactly one access node + # in the middle and the second edge of the path is empty. + return [floating_nodes_graph(cls.first_map_entry, cls.second_map_entry)] @staticmethod def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: @@ -192,7 +186,7 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi def map_nodes(self, graph: SDFGState): """Return the entry and exit nodes of the relevant maps as a tuple: entry_1, exit_1, entry_2, exit_2.""" - return (graph.entry_node(self.first_map_exit), self.first_map_exit, + return (self.first_map_entry, graph.exit_node(self.first_map_entry), self.second_map_entry, graph.exit_node(self.second_map_entry)) @staticmethod @@ -205,27 +199,44 @@ def compatible_range(first_entry: MapEntry, second_entry: MapEntry) -> bool: # If it's not even possible to take component-wise union of the two map's range, don't fuse them. # TODO(pratyai): Make it so that a permutation of the ranges, or even an union of the ranges will work. return False + if first_entry.map.schedule == ScheduleType.Sequential: + # For _grid-strided loops_, fuse them only when their ranges are _exactly_ the same. I.e., never put them + # behind another layer of grid-strided loop. + if first_entry.map.range != second_entry.map.range: + return False return True - def no_dependency_pattern(self, graph: SDFGState, sdfg: SDFG, permissive: bool = False) -> bool: + def no_dependency_pattern(self, graph: SDFGState) -> bool: """Decide if the two maps are independent of each other.""" first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) - if first_entry != self.first_map_entry or second_exit != self.second_map_exit: + if graph.scope_dict()[first_entry] != graph.scope_dict()[second_entry]: + return False + if not all(isinstance(n, AccessNode) for n in graph.all_nodes_between(first_exit, second_entry)): + return False + if not all(isinstance(n, AccessNode) for n in graph.all_nodes_between(second_exit, first_entry)): return False if any(not e.data.is_empty() for e in chain(graph.in_edges(first_entry), graph.in_edges(second_entry))): return False - if any(not isinstance(e.src, AccessNode) + if any(not isinstance(e.src, (MapEntry, AccessNode)) for e in chain(graph.in_edges(first_entry), graph.in_edges(second_entry))): return False + if not (all(isinstance(e.src, AccessNode) + for e in chain(graph.in_edges(first_entry), graph.in_edges(second_entry))) + or all(isinstance(e.src, MapEntry) + for e in chain(graph.in_edges(first_entry), graph.in_edges(second_entry)))): + return False + if not (all(isinstance(e.dst, AccessNode) + for e in chain(graph.out_edges(first_exit), graph.out_edges(second_exit))) + or all(isinstance(e.dst, MapExit) + for e in chain(graph.out_edges(first_exit), graph.out_edges(second_exit)))): + return False return True def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: - assert expr_index in (0, 1) - if expr_index == 1: - # Test the rest of the second pattern in the `expressions()`. - if not self.no_dependency_pattern(graph, sdfg, permissive): - return False + # Test the rest of the second pattern in the `expressions()`. + if not self.no_dependency_pattern(graph): + return False first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) if not self.compatible_range(first_entry, second_entry): @@ -292,9 +303,10 @@ def consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, seco table = {} for en in [first_entry, second_entry]: for e in graph.in_edges(en): - assert isinstance(e.src, AccessNode) assert e.data.is_empty() assert e.src_conn is None and e.dst_conn is None + if not isinstance(e.src, AccessNode): + continue if e.src.data not in table: table[e.src.data] = e.src elif table[e.src.data] in graph.bfs_nodes(e.src): @@ -332,9 +344,10 @@ def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit surviving_nodes, all_written_nodes = {}, set() for ex in [first_exit, second_exit]: for e in graph.out_edges(ex): - assert isinstance(e.dst, AccessNode) assert not e.data.is_empty() - assert e.src_conn is not None and e.dst_conn is None + assert e.src_conn is not None and ((e.dst_conn is None) == isinstance(e.dst, AccessNode)) + if not isinstance(e.dst, AccessNode): + continue all_written_nodes.add(e.dst) if e.dst.data not in surviving_nodes: surviving_nodes[e.dst.data] = e.dst @@ -479,6 +492,15 @@ def apply(self, graph: SDFGState, sdfg: SDFG): else: self.consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex)) + # Cleanup: remove duplicate empty dependencies. + seen = set() + for e in graph.in_edges(en): + assert e.data.is_empty() + if e.src not in seen: + seen.add(e.src) + else: + graph.remove_edge(e) + class ConstAssignmentStateFusion(StateFusionExtended): """ From b2f950b59d257ab845dc56bc3ae9481769f9f0f5 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Fri, 11 Oct 2024 23:49:34 +0200 Subject: [PATCH 14/29] move out the `@staticmethod`s out of the class definiton. --- .../dataflow/const_assignment_fusion.py | 721 +++++++++--------- 1 file changed, 356 insertions(+), 365 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 7f0baf26aa..b37364c99e 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -5,7 +5,6 @@ from typing import Optional, Union from dace import transformation, SDFGState, SDFG, Memlet, subsets, ScheduleType -from dace.sdfg import nodes from dace.sdfg.graph import OrderedDiGraph from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit, NestedSDFG, Node, EntryNode, AccessNode from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion @@ -31,91 +30,120 @@ def floating_nodes_graph(*args): return g -class ConstAssignmentMapFusion(MapFusion): +def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: """ - Fuses two maps within a state, where each map: - 1. Either assigns consistent constant values to elements of one or more data arrays. - - Consisency: The values must be the same for all elements in a data array (in both maps). But different data - arrays are allowed to have different values. - 2. Or assigns constant values as described earlier, but _conditionally_. The condition must only depend on the map - Parameters. - - Further conditions: - 1. Range compatibility: The two map must have the exact same range. - # TODO(pratyai): Generalize this in `compatible_range()`. - 2. The maps must have one of the following patterns. - - Exists a path like: MapExit -> AccessNode -> MapEntry - - Neither map is dependent on the other. I.e. There is no dependency path between them. + If the graph consists of only conditional consistent constant assignments, produces a table mapping data arrays + and memlets to their consistent constant assignments. See the class docstring for what is considered consistent. """ - first_map_entry = transformation.PatternNode(nodes.EntryNode) - second_map_entry = transformation.PatternNode(nodes.EntryNode) - - @classmethod - def expressions(cls): - # Take any two maps, then check that _every_ path from the first map to second map has exactly one access node - # in the middle and the second edge of the path is empty. - return [floating_nodes_graph(cls.first_map_entry, cls.second_map_entry)] - - @staticmethod - def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: - """ - If the graph consists of only conditional consistent constant assignments, produces a table mapping data arrays - and memlets to their consistent constant assignments. See the class docstring for what is considered consistent. - """ - table = {} - # Basic premise check. - if not isinstance(graph, NestedSDFG): - return False, table - graph: SDFG = graph.sdfg - if not isinstance(graph, ControlFlowBlock): - return False, table + table = {} + # Basic premise check. + if not isinstance(graph, NestedSDFG): + return False, table + graph: SDFG = graph.sdfg + if not isinstance(graph, ControlFlowBlock): + return False, table + + # Must have exactly 3 nodes, and exactly one of them a source, another a sink. + src, snk = graph.source_nodes(), graph.sink_nodes() + if len(graph.nodes()) != 3 or len(src) != 1 or len(snk) != 1: + return False, table + src, snk = src[0], snk[0] + body = set(graph.nodes()) - {src, snk} + if len(body) != 1: + return False, table + body = list(body)[0] + + # Must have certain structure of outgoing edges. + src_eds = list(graph.out_edges(src)) + if len(src_eds) != 2 or any(e.data.is_unconditional() or e.data.assignments for e in src_eds): + return False, table + tb, el = src_eds + if tb.dst != body: + tb, el = el, tb + if tb.dst != body or el.dst != snk: + return False, table + body_eds = list(graph.out_edges(body)) + if len(body_eds) != 1 or body_eds[0].dst != snk or not body_eds[0].data.is_unconditional() or body_eds[ + 0].data.assignments: + return False, table + + # Branch conditions must depend only on the loop variables. + for b in [tb, el]: + cond = b.data.condition + for c in cond.code: + used = set([ast_node.id for ast_node in ast.walk(c) if isinstance(ast_node, ast.Name)]) + if not used.issubset(graph.free_symbols): + return False, table - # Must have exactly 3 nodes, and exactly one of them a source, another a sink. - src, snk = graph.source_nodes(), graph.sink_nodes() - if len(graph.nodes()) != 3 or len(src) != 1 or len(snk) != 1: + # Body must have only constant assignments. + for n, _ in body.all_nodes_recursive(): + # Each tasklet in this box... + if not isinstance(n, Tasklet): + continue + if len(n.code.code) != 1 or not isinstance(n.code.code[0], ast.Assign): + # ...must assign... return False, table - src, snk = src[0], snk[0] - body = set(graph.nodes()) - {src, snk} - if len(body) != 1: + op = n.code.code[0] + if not isinstance(op.value, ast.Constant) or len(op.targets) != 1: + # ...a constant to a single target. return False, table - body = list(body)[0] + const = op.value.value + for oe in body.out_edges(n): + dst = oe.data + dst_arr = oe.data.data + if dst_arr in table and table[dst_arr] != const: + # A target array can appear multiple times, but it must always be consistently assigned. + return False, table + table[dst] = const + table[dst_arr] = const + return True, table - # Must have certain structure of outgoing edges. - src_eds = list(graph.out_edges(src)) - if len(src_eds) != 2 or any(e.data.is_unconditional() or e.data.assignments for e in src_eds): - return False, table - tb, el = src_eds - if tb.dst != body: - tb, el = el, tb - if tb.dst != body or el.dst != snk: - return False, table - body_eds = list(graph.out_edges(body)) - if len(body_eds) != 1 or body_eds[0].dst != snk or not body_eds[0].data.is_unconditional() or body_eds[ - 0].data.assignments: - return False, table - # Branch conditions must depend only on the loop variables. - for b in [tb, el]: - cond = b.data.condition - for c in cond.code: - used = set([ast_node.id for ast_node in ast.walk(c) if isinstance(ast_node, ast.Name)]) - if not used.issubset(graph.free_symbols): +def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> tuple[bool, dict]: + """ + If the graph consists of only (conditional or unconditional) consistent constant assignments, produces a table + mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is + considered consistent. + """ + table = {} + for n in graph.all_nodes_between(en, ex): + if isinstance(n, NestedSDFG): + # First handle the case of conditional constant assignment. + is_branch_const_assignment, internal_table = consistent_branch_const_assignment_table(n) + if not is_branch_const_assignment: + return False, table + for oe in graph.out_edges(n): + dst = oe.data + dst_arr = oe.data.data + if dst_arr in table and table[dst_arr] != internal_table[oe.src_conn]: + # A target array can appear multiple times, but it must always be consistently assigned. return False, table - - # Body must have only constant assignments. - for n, _ in body.all_nodes_recursive(): - # Each tasklet in this box... + table[dst] = internal_table[oe.src_conn] + table[dst_arr] = internal_table[oe.src_conn] + elif isinstance(n, MapEntry): + is_const_assignment, internal_table = consistent_const_assignment_table(graph, n, graph.exit_node(n)) + if not is_const_assignment: + return False, table + for k, v in internal_table.items(): + if k in table and v != table[k]: + return False, table + internal_table[k] = v + elif isinstance(n, MapExit): + pass # Handled with `MapEntry` + else: + # Each of the nodes in this map must be... if not isinstance(n, Tasklet): - continue + # ...a tasklet... + return False, table if len(n.code.code) != 1 or not isinstance(n.code.code[0], ast.Assign): - # ...must assign... + # ...that assigns... return False, table op = n.code.code[0] if not isinstance(op.value, ast.Constant) or len(op.targets) != 1: # ...a constant to a single target. return False, table const = op.value.value - for oe in body.out_edges(n): + for oe in graph.out_edges(n): dst = oe.data dst_arr = oe.data.data if dst_arr in table and table[dst_arr] != const: @@ -123,89 +151,260 @@ def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: return False, table table[dst] = const table[dst_arr] = const - return True, table - - @staticmethod - def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> tuple[bool, dict]: - """ - If the graph consists of only (conditional or unconditional) consistent constant assignments, produces a table - mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is - considered consistent. - """ - table = {} - for n in graph.all_nodes_between(en, ex): - if isinstance(n, NestedSDFG): - # First handle the case of conditional constant assignment. - is_branch_const_assignment, internal_table = ConstAssignmentMapFusion.consistent_branch_const_assignment_table( - n) - if not is_branch_const_assignment: - return False, table - for oe in graph.out_edges(n): - dst = oe.data - dst_arr = oe.data.data - if dst_arr in table and table[dst_arr] != internal_table[oe.src_conn]: - # A target array can appear multiple times, but it must always be consistently assigned. - return False, table - table[dst] = internal_table[oe.src_conn] - table[dst_arr] = internal_table[oe.src_conn] - elif isinstance(n, MapEntry): - is_const_assignment, internal_table = ConstAssignmentMapFusion.consistent_const_assignment_table(graph, - n, - graph.exit_node( - n)) - if not is_const_assignment: - return False, table - for k, v in internal_table.items(): - if k in table and v != table[k]: - return False, table - internal_table[k] = v - elif isinstance(n, MapExit): - pass # Handled with `MapEntry` - else: - # Each of the nodes in this map must be... - if not isinstance(n, Tasklet): - # ...a tasklet... - return False, table - if len(n.code.code) != 1 or not isinstance(n.code.code[0], ast.Assign): - # ...that assigns... - return False, table - op = n.code.code[0] - if not isinstance(op.value, ast.Constant) or len(op.targets) != 1: - # ...a constant to a single target. - return False, table - const = op.value.value - for oe in graph.out_edges(n): - dst = oe.data - dst_arr = oe.data.data - if dst_arr in table and table[dst_arr] != const: - # A target array can appear multiple times, but it must always be consistently assigned. - return False, table - table[dst] = const - table[dst_arr] = const - return True, table + return True, table + + +def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]): + """ + Create the additional connectors in the first exit node that matches the second exit node (which will be removed + later). + """ + conn_map = defaultdict() + for c, v in src.in_connectors.items(): + assert c.startswith('IN_') + cbase = c.removeprefix('IN_') + sc = dst.next_connector(cbase) + conn_map[f"IN_{cbase}"] = f"IN_{sc}" + conn_map[f"OUT_{cbase}"] = f"OUT_{sc}" + dst.add_in_connector(f"IN_{sc}", dtype=v) + dst.add_out_connector(f"OUT_{sc}", dtype=v) + for c, v in src.out_connectors.items(): + assert c in conn_map + return conn_map + + +def connector_counterpart(c: Union[str, None]) -> Union[str, None]: + """If it's an input connector, find the corresponding output connector, and vice versa.""" + if c is None: + return None + assert isinstance(c, str) + if c.startswith('IN_'): + return f"OUT_{c.removeprefix('IN_')}" + elif c.startswith('OUT_'): + return f"IN_{c.removeprefix('OUT_')}" + return None + + +def consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, second_entry: MapEntry): + """ + Remove all the incoming edges of the two maps and add empty edges from the union of the access nodes they + depended on before. + + Preconditions: + 1. All the incoming edges of the two maps must be from an access node and empty (i.e. have existed + only for synchronization). + 2. The two maps must be constistent const assignments (see the class docstring for what is considered + consistent). + """ + # First, construct a table of the dependencies. + table = {} + for en in [first_entry, second_entry]: + for e in graph.in_edges(en): + assert e.data.is_empty() + assert e.src_conn is None and e.dst_conn is None + if not isinstance(e.src, AccessNode): + continue + if e.src.data not in table: + table[e.src.data] = e.src + elif table[e.src.data] in graph.bfs_nodes(e.src): + # If this copy of the node is above the copy we've seen before, use this one instead. + table[e.src.data] = e.src + graph.remove_edge(e) + # Then, if we still have so that any of the map _writes_ to these nodes, we want to just create fresh copies to + # avoid cycles. + alt_table = {} + for k, v in table.items(): + if v in graph.bfs_nodes(first_entry) or v in graph.bfs_nodes(second_entry): + alt_v = deepcopy(v) + graph.add_node(alt_v) + alt_table[k] = alt_v + else: + alt_table[k] = v + # Finally, these nodes should be depended on by _both_ maps. + for en in [first_entry, second_entry]: + for n in alt_table.values(): + graph.add_memlet_path(n, en, memlet=Memlet()) + + +def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit: MapExit): + """ + If the two maps write to the same underlying data array through two access nodes, replace those edges' + destination with a single shared copy. + + Precondition: + 1. The two maps must not depend on each other through an access node, which should be taken care of already by + `consolidate_empty_dependencies()`. + 2. The two maps must be constistent const assignments (see the class docstring for what is considered + consistent). + """ + # First, construct tables of the surviving and all written access nodes. + surviving_nodes, all_written_nodes = {}, set() + for ex in [first_exit, second_exit]: + for e in graph.out_edges(ex): + assert not e.data.is_empty() + assert e.src_conn is not None and ((e.dst_conn is None) == isinstance(e.dst, AccessNode)) + if not isinstance(e.dst, AccessNode): + continue + all_written_nodes.add(e.dst) + if e.dst.data not in surviving_nodes: + surviving_nodes[e.dst.data] = e.dst + elif e.dst in graph.bfs_nodes(surviving_nodes[e.dst.data]): + # If this copy of the node is above the copy we've seen before, use this one instead. + surviving_nodes[e.dst.data] = e.dst + # Then, redirect all the edges toward the surviving copies of the destination access nodes. + for n in all_written_nodes: + for e in graph.in_edges(n): + assert e.src in [first_exit, second_exit] + assert e.dst_conn is None + graph.add_memlet_path(e.src, surviving_nodes[e.dst.data], + src_conn=e.src_conn, dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + for e in graph.out_edges(n): + assert e.src_conn is None + graph.add_memlet_path(surviving_nodes[e.src.data], e.dst, + src_conn=e.src_conn, dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + # Finally, cleanup the orphan nodes. + for n in all_written_nodes: + if graph.degree(n) == 0: + graph.remove_node(n) + + +def consume_map_exactly(graph: SDFGState, dst: tuple[MapEntry, MapExit], src: tuple[MapEntry, MapExit]): + """ + Transfer the entirety of `src` map's body into `dst` map. Only possible when the two maps' ranges are identical. + """ + dst_en, dst_ex = dst + src_en, src_ex = src + + assert all(e.data.is_empty() for e in graph.in_edges(src_en)) + cmap = add_equivalent_connectors(dst_en, src_en) + for e in graph.in_edges(src_en): + graph.add_memlet_path(e.src, dst_en, + src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + for e in graph.out_edges(src_en): + graph.add_memlet_path(dst_en, e.dst, + src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + + cmap = add_equivalent_connectors(dst_ex, src_ex) + for e in graph.in_edges(src_ex): + graph.add_memlet_path(e.src, dst_ex, + src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + for e in graph.out_edges(src_ex): + graph.add_memlet_path(dst_ex, e.dst, + src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + + graph.remove_node(src_en) + graph.remove_node(src_ex) + + +def consume_map_with_grid_strided_loop(graph: SDFGState, dst: tuple[MapEntry, MapExit], + src: tuple[MapEntry, MapExit]): + """ + Transfer the entirety of `src` map's body into `dst` map, guarded behind a _grid-strided_ loop. + Prerequisite: `dst` map's range must cover `src` map's range in entirety. Statically checking this may not + always be possible. + """ + dst_en, dst_ex = dst + src_en, src_ex = src + + def range_for_grid_stride(r, val, bound): + r = list(r) + r[0] = val + r[1] = bound - 1 + r[2] = bound + return tuple(r) + + gsl_ranges = [range_for_grid_stride(rd, p, rs[1] + 1) + for p, rs, rd in zip(dst_en.map.params, src_en.map.range.ranges, dst_en.map.range.ranges)] + gsl_params = [f"gsl_{p}" for p in dst_en.map.params] + en, ex = graph.add_map(graph.sdfg._find_new_name('gsl'), + {k: v for k, v in zip(gsl_params, gsl_ranges)}, + schedule=ScheduleType.Sequential) + # graph.add_memlet_path(dst_en, en, memlet=Memlet()) + consume_map_exactly(graph, (en, ex), src) + # graph.add_memlet_path(ex, dst_ex, memlet=Memlet()) + + assert all(e.data.is_empty() for e in graph.in_edges(en)) + cmap = add_equivalent_connectors(dst_en, en) + for e in graph.in_edges(en): + graph.add_memlet_path(e.src, dst_en, + src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), + memlet=Memlet.from_memlet(e.data)) + graph.add_memlet_path(dst_en, e.dst, + src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + + cmap = add_equivalent_connectors(dst_ex, ex) + for e in graph.out_edges(ex): + graph.add_memlet_path(e.src, dst_ex, + src_conn=e.src_conn, + dst_conn=connector_counterpart(cmap.get(e.src_conn)), + memlet=Memlet.from_memlet(e.data)) + graph.add_memlet_path(dst_ex, e.dst, + src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, + memlet=Memlet.from_memlet(e.data)) + graph.remove_edge(e) + + +def compatible_range(first_entry: MapEntry, second_entry: MapEntry) -> bool: + """Decide if the two ranges are compatible. See the class docstring for what is considered compatible.""" + if first_entry.map.schedule != second_entry.map.schedule: + # If the two maps are not to be scheduled on the same device, don't fuse them. + return False + if len(first_entry.map.range) != len(second_entry.map.range): + # If it's not even possible to take component-wise union of the two map's range, don't fuse them. + # TODO(pratyai): Make it so that a permutation of the ranges, or even an union of the ranges will work. + return False + if first_entry.map.schedule == ScheduleType.Sequential: + # For _grid-strided loops_, fuse them only when their ranges are _exactly_ the same. I.e., never put them + # behind another layer of grid-strided loop. + if first_entry.map.range != second_entry.map.range: + return False + return True + + +class ConstAssignmentMapFusion(MapFusion): + """ + Fuses two maps within a state, where each map: + 1. Either assigns consistent constant values to elements of one or more data arrays. + - Consisency: The values must be the same for all elements in a data array (in both maps). But different data + arrays are allowed to have different values. + 2. Or assigns constant values as described earlier, but _conditionally_. The condition must only depend on the map + Parameters. + + Further conditions: + 1. Range compatibility: The two map must have the exact same range. + # TODO(pratyai): Generalize this in `compatible_range()`. + 2. The maps must have one of the following patterns. + - Exists a path like: MapExit -> AccessNode -> MapEntry + - Neither map is dependent on the other. I.e. There is no dependency path between them. + """ + first_map_entry = transformation.PatternNode(MapEntry) + second_map_entry = transformation.PatternNode(MapEntry) + + @classmethod + def expressions(cls): + # Take any two maps, then check that _every_ path from the first map to second map has exactly one access node + # in the middle and the second edge of the path is empty. + return [floating_nodes_graph(cls.first_map_entry, cls.second_map_entry)] def map_nodes(self, graph: SDFGState): """Return the entry and exit nodes of the relevant maps as a tuple: entry_1, exit_1, entry_2, exit_2.""" return (self.first_map_entry, graph.exit_node(self.first_map_entry), self.second_map_entry, graph.exit_node(self.second_map_entry)) - @staticmethod - def compatible_range(first_entry: MapEntry, second_entry: MapEntry) -> bool: - """Decide if the two ranges are compatible. See the class docstring for what is considered compatible.""" - if first_entry.map.schedule != second_entry.map.schedule: - # If the two maps are not to be scheduled on the same device, don't fuse them. - return False - if len(first_entry.map.range) != len(second_entry.map.range): - # If it's not even possible to take component-wise union of the two map's range, don't fuse them. - # TODO(pratyai): Make it so that a permutation of the ranges, or even an union of the ranges will work. - return False - if first_entry.map.schedule == ScheduleType.Sequential: - # For _grid-strided loops_, fuse them only when their ranges are _exactly_ the same. I.e., never put them - # behind another layer of grid-strided loop. - if first_entry.map.range != second_entry.map.range: - return False - return True - def no_dependency_pattern(self, graph: SDFGState) -> bool: """Decide if the two maps are independent of each other.""" first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) @@ -239,15 +438,14 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi return False first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) - if not self.compatible_range(first_entry, second_entry): + if not compatible_range(first_entry, second_entry): return False # Both maps must have consistent constant assignment for the target arrays. - is_const_assignment, assignments = self.consistent_const_assignment_table(graph, first_entry, first_exit) + is_const_assignment, assignments = consistent_const_assignment_table(graph, first_entry, first_exit) if not is_const_assignment: return False - is_const_assignment, further_assignments = self.consistent_const_assignment_table(graph, second_entry, - second_exit) + is_const_assignment, further_assignments = consistent_const_assignment_table(graph, second_entry, second_exit) if not is_const_assignment: return False for k, v in further_assignments.items(): @@ -256,216 +454,12 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi assignments[k] = v return True - @staticmethod - def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]): - """ - Create the additional connectors in the first exit node that matches the second exit node (which will be removed - later). - """ - conn_map = defaultdict() - for c, v in src.in_connectors.items(): - assert c.startswith('IN_') - cbase = c.removeprefix('IN_') - sc = dst.next_connector(cbase) - conn_map[f"IN_{cbase}"] = f"IN_{sc}" - conn_map[f"OUT_{cbase}"] = f"OUT_{sc}" - dst.add_in_connector(f"IN_{sc}", dtype=v) - dst.add_out_connector(f"OUT_{sc}", dtype=v) - for c, v in src.out_connectors.items(): - assert c in conn_map - return conn_map - - @staticmethod - def connector_counterpart(c: Union[str, None]) -> Union[str, None]: - """If it's an input connector, find the corresponding output connector, and vice versa.""" - if c is None: - return None - assert isinstance(c, str) - if c.startswith('IN_'): - return f"OUT_{c.removeprefix('IN_')}" - elif c.startswith('OUT_'): - return f"IN_{c.removeprefix('OUT_')}" - return None - - @staticmethod - def consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, second_entry: MapEntry): - """ - Remove all the incoming edges of the two maps and add empty edges from the union of the access nodes they - depended on before. - - Preconditions: - 1. All the incoming edges of the two maps must be from an access node and empty (i.e. have existed - only for synchronization). - 2. The two maps must be constistent const assignments (see the class docstring for what is considered - consistent). - """ - # First, construct a table of the dependencies. - table = {} - for en in [first_entry, second_entry]: - for e in graph.in_edges(en): - assert e.data.is_empty() - assert e.src_conn is None and e.dst_conn is None - if not isinstance(e.src, AccessNode): - continue - if e.src.data not in table: - table[e.src.data] = e.src - elif table[e.src.data] in graph.bfs_nodes(e.src): - # If this copy of the node is above the copy we've seen before, use this one instead. - table[e.src.data] = e.src - graph.remove_edge(e) - # Then, if we still have so that any of the map _writes_ to these nodes, we want to just create fresh copies to - # avoid cycles. - alt_table = {} - for k, v in table.items(): - if v in graph.bfs_nodes(first_entry) or v in graph.bfs_nodes(second_entry): - alt_v = deepcopy(v) - graph.add_node(alt_v) - alt_table[k] = alt_v - else: - alt_table[k] = v - # Finally, these nodes should be depended on by _both_ maps. - for en in [first_entry, second_entry]: - for n in alt_table.values(): - graph.add_memlet_path(n, en, memlet=Memlet()) - - @staticmethod - def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit: MapExit): - """ - If the two maps write to the same underlying data array through two access nodes, replace those edges' - destination with a single shared copy. - - Precondition: - 1. The two maps must not depend on each other through an access node (which should be taken care of already by - `consolidate_empty_dependencies()`. - 2. The two maps must be constistent const assignments (see the class docstring for what is considered - consistent). - """ - # First, construct tables of the surviving and all written access nodes. - surviving_nodes, all_written_nodes = {}, set() - for ex in [first_exit, second_exit]: - for e in graph.out_edges(ex): - assert not e.data.is_empty() - assert e.src_conn is not None and ((e.dst_conn is None) == isinstance(e.dst, AccessNode)) - if not isinstance(e.dst, AccessNode): - continue - all_written_nodes.add(e.dst) - if e.dst.data not in surviving_nodes: - surviving_nodes[e.dst.data] = e.dst - elif e.dst in graph.bfs_nodes(surviving_nodes[e.dst.data]): - # If this copy of the node is above the copy we've seen before, use this one instead. - surviving_nodes[e.dst.data] = e.dst - # Then, redirect all the edges toward the surviving copies of the destination access nodes. - for n in all_written_nodes: - for e in graph.in_edges(n): - assert e.src in [first_exit, second_exit] - assert e.dst_conn is None - graph.add_memlet_path(e.src, surviving_nodes[e.dst.data], - src_conn=e.src_conn, dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) - for e in graph.out_edges(n): - assert e.src_conn is None - graph.add_memlet_path(surviving_nodes[e.src.data], e.dst, - src_conn=e.src_conn, dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) - # Finally, cleanup the orphan nodes. - for n in all_written_nodes: - if graph.degree(n) == 0: - graph.remove_node(n) - - @staticmethod - def consume_map_exactly(graph: SDFGState, dst: tuple[MapEntry, MapExit], src: tuple[MapEntry, MapExit]): - """ - Transfer the entirety of `src` map's body into `dst` map. Only possible when the two maps' ranges are identical. - """ - dst_en, dst_ex = dst - src_en, src_ex = src - - assert all(e.data.is_empty() for e in graph.in_edges(src_en)) - cmap = ConstAssignmentMapFusion.add_equivalent_connectors(dst_en, src_en) - for e in graph.in_edges(src_en): - graph.add_memlet_path(e.src, dst_en, - src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), - memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) - for e in graph.out_edges(src_en): - graph.add_memlet_path(dst_en, e.dst, - src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) - - cmap = ConstAssignmentMapFusion.add_equivalent_connectors(dst_ex, src_ex) - for e in graph.in_edges(src_ex): - graph.add_memlet_path(e.src, dst_ex, - src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), - memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) - for e in graph.out_edges(src_ex): - graph.add_memlet_path(dst_ex, e.dst, - src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) - - graph.remove_node(src_en) - graph.remove_node(src_ex) - - @staticmethod - def consume_map_with_grid_strided_loop(graph: SDFGState, dst: tuple[MapEntry, MapExit], - src: tuple[MapEntry, MapExit]): - """ - Transfer the entirety of `src` map's body into `dst` map, guarded behind a _grid-strided_ loop. - Prerequisite: `dst` map's range must cover `src` map's range in entirety. Statically checking this may not - always be possible. - """ - dst_en, dst_ex = dst - src_en, src_ex = src - - def range_for_grid_stride(r, val, bound): - r = list(r) - r[0] = val - r[1] = bound - 1 - r[2] = bound - return tuple(r) - - gsl_ranges = [range_for_grid_stride(rd, p, rs[1] + 1) - for p, rs, rd in zip(dst_en.map.params, src_en.map.range.ranges, dst_en.map.range.ranges)] - gsl_params = [f"gsl_{p}" for p in dst_en.map.params] - en, ex = graph.add_map(graph.sdfg._find_new_name('gsl'), - {k: v for k, v in zip(gsl_params, gsl_ranges)}, - schedule=ScheduleType.Sequential) - # graph.add_memlet_path(dst_en, en, memlet=Memlet()) - ConstAssignmentMapFusion.consume_map_exactly(graph, (en, ex), src) - # graph.add_memlet_path(ex, dst_ex, memlet=Memlet()) - - assert all(e.data.is_empty() for e in graph.in_edges(en)) - cmap = ConstAssignmentMapFusion.add_equivalent_connectors(dst_en, en) - for e in graph.in_edges(en): - graph.add_memlet_path(e.src, dst_en, - src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), - memlet=Memlet.from_memlet(e.data)) - graph.add_memlet_path(dst_en, e.dst, - src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) - - cmap = ConstAssignmentMapFusion.add_equivalent_connectors(dst_ex, ex) - for e in graph.out_edges(ex): - graph.add_memlet_path(e.src, dst_ex, - src_conn=e.src_conn, - dst_conn=ConstAssignmentMapFusion.connector_counterpart(cmap.get(e.src_conn)), - memlet=Memlet.from_memlet(e.data)) - graph.add_memlet_path(dst_ex, e.dst, - src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) - graph.remove_edge(e) - def apply(self, graph: SDFGState, sdfg: SDFG): first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) # By now, we know that the two maps are compatible, not reading anything, and just blindly writing constants # _consistently_. - is_const_assignment, assignments = self.consistent_const_assignment_table(graph, first_entry, first_exit) + is_const_assignment, assignments = consistent_const_assignment_table(graph, first_entry, first_exit) assert is_const_assignment # Rename in case loop variables are named differently. @@ -476,10 +470,10 @@ def apply(self, graph: SDFGState, sdfg: SDFG): second_entry.map.params = first_entry.map.params # Consolidate the incoming dependencies of the two maps. - self.consolidate_empty_dependencies(graph, first_entry, second_entry) + consolidate_empty_dependencies(graph, first_entry, second_entry) # Consolidate the written access nodes of the two maps. - self.consolidate_written_nodes(graph, first_exit, second_exit) + consolidate_written_nodes(graph, first_exit, second_exit) # If the ranges are identical, then simply fuse the two maps. Otherwise, use grid-strided loops. en, ex = graph.add_map(sdfg._find_new_name('map_fusion_wrapper'), @@ -488,9 +482,9 @@ def apply(self, graph: SDFGState, sdfg: SDFG): schedule=first_entry.map.schedule) for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: if en.map.range.covers(cur_en.map.range): - self.consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) + consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) else: - self.consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex)) + consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex)) # Cleanup: remove duplicate empty dependencies. seen = set() @@ -529,9 +523,7 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, en, ex = en_ex if any(not e.data.is_empty for e in st.in_edges(en)): return False - is_const_assignment, further_assignments = ConstAssignmentMapFusion.consistent_const_assignment_table(st, - en, - ex) + is_const_assignment, further_assignments = consistent_const_assignment_table(st, en, ex) if not is_const_assignment: return False for k, v in further_assignments.items(): @@ -540,8 +532,7 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, assignments[k] = v # Moreover, both states' ranges must be compatible. - if not ConstAssignmentMapFusion.compatible_range(unique_top_level_map_node(st0)[0], - unique_top_level_map_node(st1)[0]): + if not compatible_range(unique_top_level_map_node(st0)[0], unique_top_level_map_node(st1)[0]): return False return True From 6ef22b793f9ee0dbc3212710dc7c8f6f8f76cc83 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Sun, 13 Oct 2024 21:50:09 +0200 Subject: [PATCH 15/29] Make the "grid-strided loop" a configurable option. --- .../dataflow/const_assignment_fusion.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index b37364c99e..4fcd722519 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -5,6 +5,7 @@ from typing import Optional, Union from dace import transformation, SDFGState, SDFG, Memlet, subsets, ScheduleType +from dace.properties import make_properties, Property from dace.sdfg.graph import OrderedDiGraph from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit, NestedSDFG, Node, EntryNode, AccessNode from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion @@ -358,7 +359,7 @@ def range_for_grid_stride(r, val, bound): graph.remove_edge(e) -def compatible_range(first_entry: MapEntry, second_entry: MapEntry) -> bool: +def compatible_range(first_entry: MapEntry, second_entry: MapEntry, use_grid_strided_loops: bool) -> bool: """Decide if the two ranges are compatible. See the class docstring for what is considered compatible.""" if first_entry.map.schedule != second_entry.map.schedule: # If the two maps are not to be scheduled on the same device, don't fuse them. @@ -367,6 +368,10 @@ def compatible_range(first_entry: MapEntry, second_entry: MapEntry) -> bool: # If it's not even possible to take component-wise union of the two map's range, don't fuse them. # TODO(pratyai): Make it so that a permutation of the ranges, or even an union of the ranges will work. return False + if not use_grid_strided_loops: + # If we don't use grid-strided loops, the two maps' ranges must be identical. + if first_entry.map.range != second_entry.map.range: + return False if first_entry.map.schedule == ScheduleType.Sequential: # For _grid-strided loops_, fuse them only when their ranges are _exactly_ the same. I.e., never put them # behind another layer of grid-strided loop. @@ -375,6 +380,7 @@ def compatible_range(first_entry: MapEntry, second_entry: MapEntry) -> bool: return True +@make_properties class ConstAssignmentMapFusion(MapFusion): """ Fuses two maps within a state, where each map: @@ -394,6 +400,9 @@ class ConstAssignmentMapFusion(MapFusion): first_map_entry = transformation.PatternNode(MapEntry) second_map_entry = transformation.PatternNode(MapEntry) + use_grid_strided_loops = Property(dtype=bool, default=False, + desc='Set to use grid strided loops to use two maps with non-idential ranges.') + @classmethod def expressions(cls): # Take any two maps, then check that _every_ path from the first map to second map has exactly one access node @@ -438,7 +447,7 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi return False first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) - if not compatible_range(first_entry, second_entry): + if not compatible_range(first_entry, second_entry, use_grid_strided_loops=self.use_grid_strided_loops): return False # Both maps must have consistent constant assignment for the target arrays. @@ -483,7 +492,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: if en.map.range.covers(cur_en.map.range): consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) - else: + elif self.use_grid_strided_loops: consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex)) # Cleanup: remove duplicate empty dependencies. @@ -496,6 +505,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): graph.remove_edge(e) +@make_properties class ConstAssignmentStateFusion(StateFusionExtended): """ If two consecutive states are such that @@ -506,6 +516,9 @@ class ConstAssignmentStateFusion(StateFusionExtended): first_state = transformation.PatternNode(SDFGState) second_state = transformation.PatternNode(SDFGState) + use_grid_strided_loops = Property(dtype=bool, default=False, + desc='Set to use grid strided loops to use two maps with non-idential ranges.') + # NOTE: `expression()` is inherited. def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: @@ -532,7 +545,8 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, assignments[k] = v # Moreover, both states' ranges must be compatible. - if not compatible_range(unique_top_level_map_node(st0)[0], unique_top_level_map_node(st1)[0]): + if not compatible_range(unique_top_level_map_node(st0)[0], unique_top_level_map_node(st1)[0], + use_grid_strided_loops=self.use_grid_strided_loops): return False return True @@ -542,5 +556,6 @@ def apply(self, graph: SDFGState, sdfg: SDFG): super().apply(graph, sdfg) sdfg.validate() # Then, fuse the maps inside. - sdfg.apply_transformations_repeated(ConstAssignmentMapFusion) + sdfg.apply_transformations_repeated(ConstAssignmentMapFusion, + options={'use_grid_strided_loops': self.use_grid_strided_loops}) sdfg.validate() From 868266e3cb1ee0cfdd054914933a24da4026b342 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Sun, 13 Oct 2024 23:23:04 +0200 Subject: [PATCH 16/29] Fix the range fusion (`subsets.union()` only gives a bounding box with stride 1) --- .../dataflow/const_assignment_fusion.py | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 4fcd722519..c16cb1f485 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -4,11 +4,12 @@ from itertools import chain from typing import Optional, Union -from dace import transformation, SDFGState, SDFG, Memlet, subsets, ScheduleType +from dace import transformation, SDFGState, SDFG, Memlet, ScheduleType, subsets from dace.properties import make_properties, Property from dace.sdfg.graph import OrderedDiGraph from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit, NestedSDFG, Node, EntryNode, AccessNode from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion +from dace.subsets import Range from dace.transformation.dataflow import MapFusion from dace.transformation.interstate import StateFusionExtended @@ -330,11 +331,9 @@ def range_for_grid_stride(r, val, bound): for p, rs, rd in zip(dst_en.map.params, src_en.map.range.ranges, dst_en.map.range.ranges)] gsl_params = [f"gsl_{p}" for p in dst_en.map.params] en, ex = graph.add_map(graph.sdfg._find_new_name('gsl'), - {k: v for k, v in zip(gsl_params, gsl_ranges)}, + ndrange={k: v for k, v in zip(gsl_params, gsl_ranges)}, schedule=ScheduleType.Sequential) - # graph.add_memlet_path(dst_en, en, memlet=Memlet()) consume_map_exactly(graph, (en, ex), src) - # graph.add_memlet_path(ex, dst_ex, memlet=Memlet()) assert all(e.data.is_empty() for e in graph.in_edges(en)) cmap = add_equivalent_connectors(dst_en, en) @@ -359,7 +358,26 @@ def range_for_grid_stride(r, val, bound): graph.remove_edge(e) -def compatible_range(first_entry: MapEntry, second_entry: MapEntry, use_grid_strided_loops: bool) -> bool: +def fused_range(r1: Range, r2: Range) -> Optional[Range]: + if r1 == r2: + return r1 + if len(r1) != len(r2): + return None + r = [] + bb = subsets.union(r1, r2).ndrange() + for i in range(len(r1)): + if r1.strides()[i] != r2.strides()[i]: + return None + if r1.strides()[i] == 1: + r.append(bb[i]) + elif r1.ranges[i] == r2.ranges[i]: + r.append(bb[i]) + else: + return None + return r + + +def maps_have_compatible_ranges(first_entry: MapEntry, second_entry: MapEntry, use_grid_strided_loops: bool) -> bool: """Decide if the two ranges are compatible. See the class docstring for what is considered compatible.""" if first_entry.map.schedule != second_entry.map.schedule: # If the two maps are not to be scheduled on the same device, don't fuse them. @@ -447,7 +465,8 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi return False first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) - if not compatible_range(first_entry, second_entry, use_grid_strided_loops=self.use_grid_strided_loops): + if not maps_have_compatible_ranges(first_entry, second_entry, + use_grid_strided_loops=self.use_grid_strided_loops): return False # Both maps must have consistent constant assignment for the target arrays. @@ -485,9 +504,11 @@ def apply(self, graph: SDFGState, sdfg: SDFG): consolidate_written_nodes(graph, first_exit, second_exit) # If the ranges are identical, then simply fuse the two maps. Otherwise, use grid-strided loops. + assert fused_range(first_entry.map.range, second_entry.map.range) is not None en, ex = graph.add_map(sdfg._find_new_name('map_fusion_wrapper'), - {k: v for k, v in zip(first_entry.map.params, - subsets.union(first_entry.map.range, second_entry.map.range))}, + ndrange={k: v for k, v in zip(first_entry.map.params, + fused_range(first_entry.map.range, + second_entry.map.range))}, schedule=first_entry.map.schedule) for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: if en.map.range.covers(cur_en.map.range): @@ -545,8 +566,8 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, assignments[k] = v # Moreover, both states' ranges must be compatible. - if not compatible_range(unique_top_level_map_node(st0)[0], unique_top_level_map_node(st1)[0], - use_grid_strided_loops=self.use_grid_strided_loops): + if not maps_have_compatible_ranges(unique_top_level_map_node(st0)[0], unique_top_level_map_node(st1)[0], + use_grid_strided_loops=self.use_grid_strided_loops): return False return True From 8568c9fae5845f19e821c8cc82e154f296b79545 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Mon, 14 Oct 2024 11:08:45 +0200 Subject: [PATCH 17/29] Add checks to make sure the transformations are actually happening (or not happening) --- .../const_assignment_fusion_test.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index dba0128ae2..962942e67b 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -77,15 +77,19 @@ def test_within_state_fusion(): g.save(os.path.join('_dacegraphs', 'simple-1.sdfg')) g.validate() - g.apply_transformations(ConstAssignmentMapFusion) + assert g.apply_transformations(ConstAssignmentMapFusion, + options={'use_grid_strided_loops': True}) == 1 g.save(os.path.join('_dacegraphs', 'simple-2.sdfg')) g.validate() - g.apply_transformations(ConstAssignmentMapFusion) + assert g.apply_transformations(ConstAssignmentMapFusion, + options={'use_grid_strided_loops': True}) == 1 g.save(os.path.join('_dacegraphs', 'simple-3.sdfg')) g.validate() - g.apply_transformations(ConstAssignmentMapFusion) + assert g.apply_transformations(ConstAssignmentMapFusion) == 0 + assert g.apply_transformations(ConstAssignmentMapFusion, + options={'use_grid_strided_loops': True}) == 1 g.save(os.path.join('_dacegraphs', 'simple-4.sdfg')) g.validate() our_A = deepcopy(A) @@ -105,15 +109,19 @@ def test_interstate_fusion(): actual_A = deepcopy(A) g(A=actual_A, M=4, N=5) - g.apply_transformations(ConstAssignmentStateFusion) + assert g.apply_transformations(ConstAssignmentStateFusion, + options={'use_grid_strided_loops': True}) == 1 g.save(os.path.join('_dacegraphs', 'interstate-1.sdfg')) g.validate() - g.apply_transformations(ConstAssignmentStateFusion) + assert g.apply_transformations(ConstAssignmentStateFusion, + options={'use_grid_strided_loops': True}) == 1 g.save(os.path.join('_dacegraphs', 'interstate-2.sdfg')) g.validate() - g.apply_transformations(ConstAssignmentStateFusion) + assert g.apply_transformations(ConstAssignmentStateFusion) == 0 + assert g.apply_transformations(ConstAssignmentStateFusion, + options={'use_grid_strided_loops': True}) == 1 g.save(os.path.join('_dacegraphs', 'interstate-3.sdfg')) g.validate() our_A = deepcopy(A) @@ -141,7 +149,7 @@ def test_free_floating_fusion(): actual_B = deepcopy(B) g(A=actual_A, B=actual_B, M=4, N=5) - g.apply_transformations(ConstAssignmentMapFusion) + assert g.apply_transformations(ConstAssignmentMapFusion) == 1 g.save(os.path.join('_dacegraphs', 'floating-1.sdfg')) g.validate() our_A = deepcopy(A) @@ -182,7 +190,7 @@ def test_fusion_with_multiple_indices(): actual_B = deepcopy(B) g(A=actual_A, B=actual_B, K=3, M=4, N=5) - g.apply_transformations(ConstAssignmentMapFusion) + assert g.apply_transformations(ConstAssignmentMapFusion) == 1 g.save(os.path.join('_dacegraphs', '3d-1.sdfg')) g.validate() our_A = deepcopy(A) @@ -211,7 +219,7 @@ def test_fusion_with_branch(): actual_B = deepcopy(B) g(A=actual_A, B=actual_B, M=4, N=5) - g.apply_transformations(ConstAssignmentMapFusion) + assert g.apply_transformations(ConstAssignmentMapFusion) == 1 g.save(os.path.join('_dacegraphs', 'branched-1.sdfg')) g.validate() our_A = deepcopy(A) From 113d2b3bb158a84ec5a9894bcd0dd949d2590fe9 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Mon, 14 Oct 2024 14:16:27 +0200 Subject: [PATCH 18/29] If there was no path established between the parent and the strided maps, create empty ones. --- .../dataflow/const_assignment_fusion.py | 5 +++ .../const_assignment_fusion_test.py | 39 ++++++++++++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index c16cb1f485..8292b8c218 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -356,6 +356,10 @@ def range_for_grid_stride(r, val, bound): src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, memlet=Memlet.from_memlet(e.data)) graph.remove_edge(e) + if len(graph.in_edges(en)) == 0: + graph.add_memlet_path(dst_en, en, memlet=Memlet()) + if len(graph.out_edges(ex)) == 0: + graph.add_memlet_path(ex, dst_ex, memlet=Memlet()) def fused_range(r1: Range, r2: Range) -> Optional[Range]: @@ -511,6 +515,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): second_entry.map.range))}, schedule=first_entry.map.schedule) for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: + assert en.map.range.covers(cur_en.map.range) or self.use_grid_strided_loops if en.map.range.covers(cur_en.map.range): consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) elif self.use_grid_strided_loops: diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 962942e67b..8b3e072232 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -172,10 +172,38 @@ def assign_bottom_face(A: dace.float32[K, M, N]): A[K - 1, t1, t2] = 1 +@dace.program +def assign_front_face(A: dace.float32[K, M, N]): + for t1, t2 in dace.map[0:K, 0:N]: + A[t1, 0, t2] = 1 + + +@dace.program +def assign_back_face(A: dace.float32[K, M, N]): + for t1, t2 in dace.map[0:K, 0:N]: + A[t1, M - 1, t2] = 1 + + +@dace.program +def assign_left_face(A: dace.float32[K, M, N]): + for t1, t2 in dace.map[0:K, 0:M]: + A[t1, t2, 0] = 1 + + +@dace.program +def assign_right_face(A: dace.float32[K, M, N]): + for t1, t2 in dace.map[0:K, 0:M]: + A[t1, t2, N - 1] = 1 + + @dace.program def assign_bounary_3d(A: dace.float32[K, M, N], B: dace.float32[K, M, N]): assign_top_face(A) assign_bottom_face(B) + assign_front_face(A) + assign_back_face(B) + assign_left_face(A) + assign_right_face(B) def test_fusion_with_multiple_indices(): @@ -190,13 +218,22 @@ def test_fusion_with_multiple_indices(): actual_B = deepcopy(B) g(A=actual_A, B=actual_B, K=3, M=4, N=5) - assert g.apply_transformations(ConstAssignmentMapFusion) == 1 + assert g.apply_transformations_repeated(ConstAssignmentMapFusion, options={'use_grid_strided_loops': True}) == 3 g.save(os.path.join('_dacegraphs', '3d-1.sdfg')) g.validate() our_A = deepcopy(A) our_B = deepcopy(B) g(A=our_A, B=our_B, K=3, M=4, N=5) + assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 + g.simplify() + assert g.apply_transformations_repeated(ConstAssignmentStateFusion, options={'use_grid_strided_loops': True}) == 2 + g.save(os.path.join('_dacegraphs', '3d-2.sdfg')) + g.validate() + our_A = deepcopy(A) + our_B = deepcopy(B) + g(A=our_A, B=our_B, K=3, M=4, N=5) + # print(our_A) assert np.allclose(our_A, actual_A) From e60fe6de9d7323a0fcc8e158db6efca7034dae76 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Mon, 14 Oct 2024 16:27:30 +0200 Subject: [PATCH 19/29] Be slightly more explicit about when to use grid strided loop. --- .../dataflow/const_assignment_fusion.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 8292b8c218..2c08754891 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -514,12 +514,16 @@ def apply(self, graph: SDFGState, sdfg: SDFG): fused_range(first_entry.map.range, second_entry.map.range))}, schedule=first_entry.map.schedule) - for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: - assert en.map.range.covers(cur_en.map.range) or self.use_grid_strided_loops - if en.map.range.covers(cur_en.map.range): + if first_entry.map.range == second_entry.map.range: + for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) - elif self.use_grid_strided_loops: - consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex)) + elif self.use_grid_strided_loops: + assert ScheduleType.Sequential not in [first_entry.map.schedule, second_entry.map.schedule] + for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: + if en.map.range == cur_en.map.range: + consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) + else: + consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex)) # Cleanup: remove duplicate empty dependencies. seen = set() From cb5398d83c7dcfdd4ac409a371893ec4e11df44c Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Tue, 15 Oct 2024 16:50:59 +0200 Subject: [PATCH 20/29] Handle the AST API differences between python versions. --- .../dataflow/const_assignment_fusion.py | 45 +++++++++++++------ 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 2c08754891..ab5027a4f4 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -2,7 +2,7 @@ from collections import defaultdict from copy import deepcopy from itertools import chain -from typing import Optional, Union +from typing import Optional, Union, Tuple from dace import transformation, SDFGState, SDFG, Memlet, ScheduleType, subsets from dace.properties import make_properties, Property @@ -14,7 +14,7 @@ from dace.transformation.interstate import StateFusionExtended -def unique_top_level_map_node(graph: SDFGState) -> Optional[tuple[MapEntry, MapExit]]: +def unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapExit]]: all_top_nodes = [n for n, s in graph.scope_dict().items() if s is None] if not all(isinstance(n, (MapEntry, AccessNode)) for n in all_top_nodes): return None @@ -32,7 +32,7 @@ def floating_nodes_graph(*args): return g -def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: +def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]: """ If the graph consists of only conditional consistent constant assignments, produces a table mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is considered consistent. @@ -86,10 +86,10 @@ def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: # ...must assign... return False, table op = n.code.code[0] - if not isinstance(op.value, ast.Constant) or len(op.targets) != 1: + if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: # ...a constant to a single target. return False, table - const = op.value.value + const = value_of_constant_or_numerical_literal(op.value) for oe in body.out_edges(n): dst = oe.data dst_arr = oe.data.data @@ -101,7 +101,17 @@ def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: return True, table -def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> tuple[bool, dict]: +def is_constant_or_numerical_literal(n: ast.Expr): + """Work around the API differences between Python versions (e.g., 3.7 and 3.12)""" + return isinstance(n, (ast.Constant, ast.Num)) + + +def value_of_constant_or_numerical_literal(n: ast.Expr): + """Work around the API differences between Python versions (e.g., 3.7 and 3.12)""" + return n.value if isinstance(n, ast.Constant) else n.n + + +def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> Tuple[bool, dict]: """ If the graph consists of only (conditional or unconditional) consistent constant assignments, produces a table mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is @@ -141,10 +151,10 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi # ...that assigns... return False, table op = n.code.code[0] - if not isinstance(op.value, ast.Constant) or len(op.targets) != 1: + if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: # ...a constant to a single target. return False, table - const = op.value.value + const = value_of_constant_or_numerical_literal(op.value) for oe in graph.out_edges(n): dst = oe.data dst_arr = oe.data.data @@ -156,6 +166,13 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi return True, table +def removeprefix(c: str, p: str): + """Since `str.removeprefix()` wasn't added until Python 3.9""" + if not c.startswith(p): + return c + return c[len(p):] + + def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]): """ Create the additional connectors in the first exit node that matches the second exit node (which will be removed @@ -164,7 +181,7 @@ def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryN conn_map = defaultdict() for c, v in src.in_connectors.items(): assert c.startswith('IN_') - cbase = c.removeprefix('IN_') + cbase = removeprefix(c, 'IN_') sc = dst.next_connector(cbase) conn_map[f"IN_{cbase}"] = f"IN_{sc}" conn_map[f"OUT_{cbase}"] = f"OUT_{sc}" @@ -181,9 +198,9 @@ def connector_counterpart(c: Union[str, None]) -> Union[str, None]: return None assert isinstance(c, str) if c.startswith('IN_'): - return f"OUT_{c.removeprefix('IN_')}" + return f"OUT_{removeprefix(c, 'IN_')}" elif c.startswith('OUT_'): - return f"IN_{c.removeprefix('OUT_')}" + return f"IN_{removeprefix(c, 'OUT_')}" return None @@ -274,7 +291,7 @@ def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit graph.remove_node(n) -def consume_map_exactly(graph: SDFGState, dst: tuple[MapEntry, MapExit], src: tuple[MapEntry, MapExit]): +def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tuple[MapEntry, MapExit]): """ Transfer the entirety of `src` map's body into `dst` map. Only possible when the two maps' ranges are identical. """ @@ -310,8 +327,8 @@ def consume_map_exactly(graph: SDFGState, dst: tuple[MapEntry, MapExit], src: tu graph.remove_node(src_ex) -def consume_map_with_grid_strided_loop(graph: SDFGState, dst: tuple[MapEntry, MapExit], - src: tuple[MapEntry, MapExit]): +def consume_map_with_grid_strided_loop(graph: SDFGState, dst: Tuple[MapEntry, MapExit], + src: Tuple[MapEntry, MapExit]): """ Transfer the entirety of `src` map's body into `dst` map, guarded behind a _grid-strided_ loop. Prerequisite: `dst` map's range must cover `src` map's range in entirety. Statically checking this may not From a4e6071c92d82111e0358a372618768eee67830f Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Mon, 28 Oct 2024 18:35:23 +0100 Subject: [PATCH 21/29] Add two designated "negative tests". --- .../const_assignment_fusion_test.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 8b3e072232..f435050a82 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -267,6 +267,68 @@ def test_fusion_with_branch(): assert np.allclose(our_A, actual_A) +@dace.program +def assign_bottom_face_flipped(A: dace.float32[K, M, N]): + for t2, t1 in dace.map[0:N, 0:M]: + A[K - 1, t1, t2] = 1 + + +@dace.program +def assign_bounary_3d_with_flip(A: dace.float32[K, M, N], B: dace.float32[K, M, N]): + assign_top_face(A) + assign_bottom_face_flipped(B) + + +def test_does_not_permute_to_fuse(): + """ Negative test """ + A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + B = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = assign_bounary_3d_with_flip.to_sdfg(simplify=True, validate=True, use_cache=False) + g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g.save(os.path.join('_dacegraphs', '3d-flip-0.sdfg')) + g.validate() + actual_A = deepcopy(A) + actual_B = deepcopy(B) + g(A=actual_A, B=actual_B, K=3, M=4, N=5) + + assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 + g.save(os.path.join('_dacegraphs', '3d-flip-1.sdfg')) + g.validate() + our_A = deepcopy(A) + our_B = deepcopy(B) + g(A=our_A, B=our_B, K=3, M=4, N=5) + + +@dace.program +def assign_mixed_dims(A: dace.float32[K, M, N], B: dace.float32[K, M, N]): + assign_top_face(A) + assign_left_col(B[0, :, :]) + + +def test_does_not_extend_to_fuse(): + """ Negative test """ + A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + B = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = assign_mixed_dims.to_sdfg(simplify=True, validate=True, use_cache=False) + g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g.save(os.path.join('_dacegraphs', '3d-mixed-0.sdfg')) + g.validate() + actual_A = deepcopy(A) + actual_B = deepcopy(B) + g(A=actual_A, B=actual_B, K=3, M=4, N=5) + + assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 + g.save(os.path.join('_dacegraphs', '3d-mixed-1.sdfg')) + g.validate() + our_A = deepcopy(A) + our_B = deepcopy(B) + g(A=our_A, B=our_B, K=3, M=4, N=5) + + if __name__ == '__main__': test_within_state_fusion() test_interstate_fusion() From 4e8ca45864dc22199d34b423a2bf1f3a62bb14e1 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Tue, 29 Oct 2024 16:29:52 +0100 Subject: [PATCH 22/29] Privatize the helper functions. --- .../dataflow/const_assignment_fusion.py | 110 +++++++++--------- 1 file changed, 55 insertions(+), 55 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index ab5027a4f4..027a8cec91 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -14,7 +14,7 @@ from dace.transformation.interstate import StateFusionExtended -def unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapExit]]: +def _unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapExit]]: all_top_nodes = [n for n, s in graph.scope_dict().items() if s is None] if not all(isinstance(n, (MapEntry, AccessNode)) for n in all_top_nodes): return None @@ -25,14 +25,14 @@ def unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapE return en[0], ex[0] -def floating_nodes_graph(*args): +def _floating_nodes_graph(*args): g = OrderedDiGraph() for n in args: g.add_node(n) return g -def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]: +def _consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]: """ If the graph consists of only conditional consistent constant assignments, produces a table mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is considered consistent. @@ -86,10 +86,10 @@ def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]: # ...must assign... return False, table op = n.code.code[0] - if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: + if not _is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: # ...a constant to a single target. return False, table - const = value_of_constant_or_numerical_literal(op.value) + const = _value_of_constant_or_numerical_literal(op.value) for oe in body.out_edges(n): dst = oe.data dst_arr = oe.data.data @@ -101,17 +101,17 @@ def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]: return True, table -def is_constant_or_numerical_literal(n: ast.Expr): +def _is_constant_or_numerical_literal(n: ast.Expr): """Work around the API differences between Python versions (e.g., 3.7 and 3.12)""" return isinstance(n, (ast.Constant, ast.Num)) -def value_of_constant_or_numerical_literal(n: ast.Expr): +def _value_of_constant_or_numerical_literal(n: ast.Expr): """Work around the API differences between Python versions (e.g., 3.7 and 3.12)""" return n.value if isinstance(n, ast.Constant) else n.n -def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> Tuple[bool, dict]: +def _consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> Tuple[bool, dict]: """ If the graph consists of only (conditional or unconditional) consistent constant assignments, produces a table mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is @@ -121,7 +121,7 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi for n in graph.all_nodes_between(en, ex): if isinstance(n, NestedSDFG): # First handle the case of conditional constant assignment. - is_branch_const_assignment, internal_table = consistent_branch_const_assignment_table(n) + is_branch_const_assignment, internal_table = _consistent_branch_const_assignment_table(n) if not is_branch_const_assignment: return False, table for oe in graph.out_edges(n): @@ -133,7 +133,7 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi table[dst] = internal_table[oe.src_conn] table[dst_arr] = internal_table[oe.src_conn] elif isinstance(n, MapEntry): - is_const_assignment, internal_table = consistent_const_assignment_table(graph, n, graph.exit_node(n)) + is_const_assignment, internal_table = _consistent_const_assignment_table(graph, n, graph.exit_node(n)) if not is_const_assignment: return False, table for k, v in internal_table.items(): @@ -151,10 +151,10 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi # ...that assigns... return False, table op = n.code.code[0] - if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: + if not _is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: # ...a constant to a single target. return False, table - const = value_of_constant_or_numerical_literal(op.value) + const = _value_of_constant_or_numerical_literal(op.value) for oe in graph.out_edges(n): dst = oe.data dst_arr = oe.data.data @@ -166,14 +166,14 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi return True, table -def removeprefix(c: str, p: str): +def _removeprefix(c: str, p: str): """Since `str.removeprefix()` wasn't added until Python 3.9""" if not c.startswith(p): return c return c[len(p):] -def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]): +def _add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]): """ Create the additional connectors in the first exit node that matches the second exit node (which will be removed later). @@ -181,7 +181,7 @@ def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryN conn_map = defaultdict() for c, v in src.in_connectors.items(): assert c.startswith('IN_') - cbase = removeprefix(c, 'IN_') + cbase = _removeprefix(c, 'IN_') sc = dst.next_connector(cbase) conn_map[f"IN_{cbase}"] = f"IN_{sc}" conn_map[f"OUT_{cbase}"] = f"OUT_{sc}" @@ -192,19 +192,19 @@ def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryN return conn_map -def connector_counterpart(c: Union[str, None]) -> Union[str, None]: +def _connector_counterpart(c: Union[str, None]) -> Union[str, None]: """If it's an input connector, find the corresponding output connector, and vice versa.""" if c is None: return None assert isinstance(c, str) if c.startswith('IN_'): - return f"OUT_{removeprefix(c, 'IN_')}" + return f"OUT_{_removeprefix(c, 'IN_')}" elif c.startswith('OUT_'): - return f"IN_{removeprefix(c, 'OUT_')}" + return f"IN_{_removeprefix(c, 'OUT_')}" return None -def consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, second_entry: MapEntry): +def _consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, second_entry: MapEntry): """ Remove all the incoming edges of the two maps and add empty edges from the union of the access nodes they depended on before. @@ -245,7 +245,7 @@ def consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, seco graph.add_memlet_path(n, en, memlet=Memlet()) -def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit: MapExit): +def _consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit: MapExit): """ If the two maps write to the same underlying data array through two access nodes, replace those edges' destination with a single shared copy. @@ -291,7 +291,7 @@ def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit graph.remove_node(n) -def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tuple[MapEntry, MapExit]): +def _consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tuple[MapEntry, MapExit]): """ Transfer the entirety of `src` map's body into `dst` map. Only possible when the two maps' ranges are identical. """ @@ -299,7 +299,7 @@ def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tu src_en, src_ex = src assert all(e.data.is_empty() for e in graph.in_edges(src_en)) - cmap = add_equivalent_connectors(dst_en, src_en) + cmap = _add_equivalent_connectors(dst_en, src_en) for e in graph.in_edges(src_en): graph.add_memlet_path(e.src, dst_en, src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), @@ -311,7 +311,7 @@ def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tu memlet=Memlet.from_memlet(e.data)) graph.remove_edge(e) - cmap = add_equivalent_connectors(dst_ex, src_ex) + cmap = _add_equivalent_connectors(dst_ex, src_ex) for e in graph.in_edges(src_ex): graph.add_memlet_path(e.src, dst_ex, src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), @@ -327,8 +327,8 @@ def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tu graph.remove_node(src_ex) -def consume_map_with_grid_strided_loop(graph: SDFGState, dst: Tuple[MapEntry, MapExit], - src: Tuple[MapEntry, MapExit]): +def _consume_map_with_grid_strided_loop(graph: SDFGState, dst: Tuple[MapEntry, MapExit], + src: Tuple[MapEntry, MapExit]): """ Transfer the entirety of `src` map's body into `dst` map, guarded behind a _grid-strided_ loop. Prerequisite: `dst` map's range must cover `src` map's range in entirety. Statically checking this may not @@ -350,10 +350,10 @@ def range_for_grid_stride(r, val, bound): en, ex = graph.add_map(graph.sdfg._find_new_name('gsl'), ndrange={k: v for k, v in zip(gsl_params, gsl_ranges)}, schedule=ScheduleType.Sequential) - consume_map_exactly(graph, (en, ex), src) + _consume_map_exactly(graph, (en, ex), src) assert all(e.data.is_empty() for e in graph.in_edges(en)) - cmap = add_equivalent_connectors(dst_en, en) + cmap = _add_equivalent_connectors(dst_en, en) for e in graph.in_edges(en): graph.add_memlet_path(e.src, dst_en, src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), @@ -363,11 +363,11 @@ def range_for_grid_stride(r, val, bound): memlet=Memlet.from_memlet(e.data)) graph.remove_edge(e) - cmap = add_equivalent_connectors(dst_ex, ex) + cmap = _add_equivalent_connectors(dst_ex, ex) for e in graph.out_edges(ex): graph.add_memlet_path(e.src, dst_ex, src_conn=e.src_conn, - dst_conn=connector_counterpart(cmap.get(e.src_conn)), + dst_conn=_connector_counterpart(cmap.get(e.src_conn)), memlet=Memlet.from_memlet(e.data)) graph.add_memlet_path(dst_ex, e.dst, src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, @@ -379,7 +379,7 @@ def range_for_grid_stride(r, val, bound): graph.add_memlet_path(ex, dst_ex, memlet=Memlet()) -def fused_range(r1: Range, r2: Range) -> Optional[Range]: +def _fused_range(r1: Range, r2: Range) -> Optional[Range]: if r1 == r2: return r1 if len(r1) != len(r2): @@ -398,7 +398,7 @@ def fused_range(r1: Range, r2: Range) -> Optional[Range]: return r -def maps_have_compatible_ranges(first_entry: MapEntry, second_entry: MapEntry, use_grid_strided_loops: bool) -> bool: +def _maps_have_compatible_ranges(first_entry: MapEntry, second_entry: MapEntry, use_grid_strided_loops: bool) -> bool: """Decide if the two ranges are compatible. See the class docstring for what is considered compatible.""" if first_entry.map.schedule != second_entry.map.schedule: # If the two maps are not to be scheduled on the same device, don't fuse them. @@ -446,16 +446,16 @@ class ConstAssignmentMapFusion(MapFusion): def expressions(cls): # Take any two maps, then check that _every_ path from the first map to second map has exactly one access node # in the middle and the second edge of the path is empty. - return [floating_nodes_graph(cls.first_map_entry, cls.second_map_entry)] + return [_floating_nodes_graph(cls.first_map_entry, cls.second_map_entry)] - def map_nodes(self, graph: SDFGState): + def _map_nodes(self, graph: SDFGState): """Return the entry and exit nodes of the relevant maps as a tuple: entry_1, exit_1, entry_2, exit_2.""" return (self.first_map_entry, graph.exit_node(self.first_map_entry), self.second_map_entry, graph.exit_node(self.second_map_entry)) - def no_dependency_pattern(self, graph: SDFGState) -> bool: + def _no_dependency_pattern(self, graph: SDFGState) -> bool: """Decide if the two maps are independent of each other.""" - first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) + first_entry, first_exit, second_entry, second_exit = self._map_nodes(graph) if graph.scope_dict()[first_entry] != graph.scope_dict()[second_entry]: return False if not all(isinstance(n, AccessNode) for n in graph.all_nodes_between(first_exit, second_entry)): @@ -482,19 +482,19 @@ def no_dependency_pattern(self, graph: SDFGState) -> bool: def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: # Test the rest of the second pattern in the `expressions()`. - if not self.no_dependency_pattern(graph): + if not self._no_dependency_pattern(graph): return False - first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) - if not maps_have_compatible_ranges(first_entry, second_entry, - use_grid_strided_loops=self.use_grid_strided_loops): + first_entry, first_exit, second_entry, second_exit = self._map_nodes(graph) + if not _maps_have_compatible_ranges(first_entry, second_entry, + use_grid_strided_loops=self.use_grid_strided_loops): return False # Both maps must have consistent constant assignment for the target arrays. - is_const_assignment, assignments = consistent_const_assignment_table(graph, first_entry, first_exit) + is_const_assignment, assignments = _consistent_const_assignment_table(graph, first_entry, first_exit) if not is_const_assignment: return False - is_const_assignment, further_assignments = consistent_const_assignment_table(graph, second_entry, second_exit) + is_const_assignment, further_assignments = _consistent_const_assignment_table(graph, second_entry, second_exit) if not is_const_assignment: return False for k, v in further_assignments.items(): @@ -504,11 +504,11 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi return True def apply(self, graph: SDFGState, sdfg: SDFG): - first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) + first_entry, first_exit, second_entry, second_exit = self._map_nodes(graph) # By now, we know that the two maps are compatible, not reading anything, and just blindly writing constants # _consistently_. - is_const_assignment, assignments = consistent_const_assignment_table(graph, first_entry, first_exit) + is_const_assignment, assignments = _consistent_const_assignment_table(graph, first_entry, first_exit) assert is_const_assignment # Rename in case loop variables are named differently. @@ -519,28 +519,28 @@ def apply(self, graph: SDFGState, sdfg: SDFG): second_entry.map.params = first_entry.map.params # Consolidate the incoming dependencies of the two maps. - consolidate_empty_dependencies(graph, first_entry, second_entry) + _consolidate_empty_dependencies(graph, first_entry, second_entry) # Consolidate the written access nodes of the two maps. - consolidate_written_nodes(graph, first_exit, second_exit) + _consolidate_written_nodes(graph, first_exit, second_exit) # If the ranges are identical, then simply fuse the two maps. Otherwise, use grid-strided loops. - assert fused_range(first_entry.map.range, second_entry.map.range) is not None + assert _fused_range(first_entry.map.range, second_entry.map.range) is not None en, ex = graph.add_map(sdfg._find_new_name('map_fusion_wrapper'), ndrange={k: v for k, v in zip(first_entry.map.params, - fused_range(first_entry.map.range, - second_entry.map.range))}, + _fused_range(first_entry.map.range, + second_entry.map.range))}, schedule=first_entry.map.schedule) if first_entry.map.range == second_entry.map.range: for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: - consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) + _consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) elif self.use_grid_strided_loops: assert ScheduleType.Sequential not in [first_entry.map.schedule, second_entry.map.schedule] for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: if en.map.range == cur_en.map.range: - consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) + _consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) else: - consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex)) + _consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex)) # Cleanup: remove duplicate empty dependencies. seen = set() @@ -577,13 +577,13 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, # Moreover, the states together must contain a consistent constant assignment map. assignments = {} for st in [st0, st1]: - en_ex = unique_top_level_map_node(st) + en_ex = _unique_top_level_map_node(st) if not en_ex: return False en, ex = en_ex if any(not e.data.is_empty for e in st.in_edges(en)): return False - is_const_assignment, further_assignments = consistent_const_assignment_table(st, en, ex) + is_const_assignment, further_assignments = _consistent_const_assignment_table(st, en, ex) if not is_const_assignment: return False for k, v in further_assignments.items(): @@ -592,8 +592,8 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, assignments[k] = v # Moreover, both states' ranges must be compatible. - if not maps_have_compatible_ranges(unique_top_level_map_node(st0)[0], unique_top_level_map_node(st1)[0], - use_grid_strided_loops=self.use_grid_strided_loops): + if not _maps_have_compatible_ranges(_unique_top_level_map_node(st0)[0], _unique_top_level_map_node(st1)[0], + use_grid_strided_loops=self.use_grid_strided_loops): return False return True From 23f268974a1cee5a61d7632579f87b2471de1eee Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Tue, 29 Oct 2024 17:18:17 +0100 Subject: [PATCH 23/29] Cover even more negative test cases. --- .../const_assignment_fusion_test.py | 118 ++++++++++++++++-- 1 file changed, 108 insertions(+), 10 deletions(-) diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index f435050a82..217efd3085 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -225,8 +225,8 @@ def test_fusion_with_multiple_indices(): our_B = deepcopy(B) g(A=our_A, B=our_B, K=3, M=4, N=5) + # Here, the state fusion can apply only with GSLs. assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 - g.simplify() assert g.apply_transformations_repeated(ConstAssignmentStateFusion, options={'use_grid_strided_loops': True}) == 2 g.save(os.path.join('_dacegraphs', '3d-2.sdfg')) g.validate() @@ -294,11 +294,6 @@ def test_does_not_permute_to_fuse(): g(A=actual_A, B=actual_B, K=3, M=4, N=5) assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 - g.save(os.path.join('_dacegraphs', '3d-flip-1.sdfg')) - g.validate() - our_A = deepcopy(A) - our_B = deepcopy(B) - g(A=our_A, B=our_B, K=3, M=4, N=5) @dace.program @@ -322,11 +317,114 @@ def test_does_not_extend_to_fuse(): g(A=actual_A, B=actual_B, K=3, M=4, N=5) assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 - g.save(os.path.join('_dacegraphs', '3d-mixed-1.sdfg')) + + +@dace.program +def assign_bottom_face_42(A: dace.float32[K, M, N]): + for t1, t2 in dace.map[0:M, 0:N]: + A[K - 1, t1, t2] = 42 + + +@dace.program +def assign_bottom_face_index_sum(A: dace.float32[K, M, N]): + for t1, t2 in dace.map[0:M, 0:N]: + A[K - 1, t1, t2] = t1 + t2 + + +@dace.program +def assign_inconsistent_values_1(A: dace.float32[K, M, N]): + assign_top_face(A) + assign_bottom_face_42(A) + + +@dace.program +def assign_inconsistent_values_2(A: dace.float32[K, M, N]): + assign_top_face(A) + assign_bottom_face_index_sum(A) + + +def test_does_not_fuse_with_inconsistent_assignments(): + """ Negative test """ + A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = assign_inconsistent_values_1.to_sdfg(simplify=True, validate=True, use_cache=False) + g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g.save(os.path.join('_dacegraphs', '3d-inconsistent-0.sdfg')) g.validate() - our_A = deepcopy(A) - our_B = deepcopy(B) - g(A=our_A, B=our_B, K=3, M=4, N=5) + actual_A = deepcopy(A) + g(A=actual_A, K=3, M=4, N=5) + + assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 + + # Try another case: Construct SDFG with the maps on separate states. + g = assign_inconsistent_values_2.to_sdfg(simplify=True, validate=True, use_cache=False) + g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g.save(os.path.join('_dacegraphs', '3d-inconsistent-1.sdfg')) + g.validate() + actual_A = deepcopy(A) + g(A=actual_A, K=3, M=4, N=5) + + assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 + + +@dace.program +def tasklet_between_maps(A: dace.float32[K, M, N]): + assign_top_face(A) + A[0, 0, 0] = 1 + assign_bottom_face(A) + + +def test_does_not_fuse_with_unsuitable_dependencies(): + """ Negative test """ + A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = tasklet_between_maps.to_sdfg(simplify=True, validate=True, use_cache=False) + g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g.save(os.path.join('_dacegraphs', '3d-baddeps-0.sdfg')) + g.validate() + actual_A = deepcopy(A) + g(A=actual_A, K=3, M=4, N=5) + + assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 + + +@dace.program +def assign_top_face_self_copy(A: dace.float32[K, M, N]): + for t1, t2 in dace.map[0:M, 0:N]: + A[0, t1, t2] = A[0, t1, t2] + + +@dace.program +def first_map_reads_data(A: dace.float32[K, M, N]): + assign_top_face_self_copy(A) + assign_bottom_face(A) + + +def test_does_not_fuse_when_the_first_map_reads_anything_at_all(): + """ Negative test """ + A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = first_map_reads_data.to_sdfg(simplify=True, validate=True, use_cache=False) + g.save(os.path.join('_dacegraphs', '3d-map1-reads-0.sdfg')) + g.validate() + actual_A = deepcopy(A) + g(A=actual_A, K=3, M=4, N=5) + + # The state fusion won't work. + assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 + + # Fuse the states explicitly anyway. + g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g.save(os.path.join('_dacegraphs', '3d-map1-reads-1.sdfg')) + g.validate() + actual_A = deepcopy(A) + g(A=actual_A, K=3, M=4, N=5) + + # The map fusion won't work. + assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 if __name__ == '__main__': From 116738049313372636cedba31a8e272307bd43b6 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 30 Oct 2024 13:07:29 +0100 Subject: [PATCH 24/29] Construct a subgraph view to do the replacement correctly. --- .../transformation/dataflow/const_assignment_fusion.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 027a8cec91..7c490d7071 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -8,7 +8,7 @@ from dace.properties import make_properties, Property from dace.sdfg.graph import OrderedDiGraph from dace.sdfg.nodes import Tasklet, ExitNode, MapEntry, MapExit, NestedSDFG, Node, EntryNode, AccessNode -from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, StateSubgraphView from dace.subsets import Range from dace.transformation.dataflow import MapFusion from dace.transformation.interstate import StateFusionExtended @@ -512,11 +512,9 @@ def apply(self, graph: SDFGState, sdfg: SDFG): assert is_const_assignment # Rename in case loop variables are named differently. - param_map = {p2: p1 for p1, p2 in zip(first_entry.map.params, second_entry.map.params)} - for t in graph.all_nodes_between(second_entry, second_exit): - for e in graph.out_edges(t): - e.data.subset.replace(param_map) - second_entry.map.params = first_entry.map.params + nodes_to_update = {n for n in graph.all_nodes_between(second_entry, second_exit)} | {second_entry, second_exit} + view = StateSubgraphView(graph, list(nodes_to_update)) + view.replace_dict({p2: p1 for p1, p2 in zip(first_entry.map.params, second_entry.map.params)}) # Consolidate the incoming dependencies of the two maps. _consolidate_empty_dependencies(graph, first_entry, second_entry) From 4d6b4c6160338af86adce48111e16d6108059a5f Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 30 Oct 2024 13:17:36 +0100 Subject: [PATCH 25/29] Replace all the convenient uses of `@dace.program` and `simply()` with hand-crafted graphs. --- .../const_assignment_fusion_test.py | 436 +++++++++--------- 1 file changed, 223 insertions(+), 213 deletions(-) diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 217efd3085..7eb82587f0 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -1,65 +1,69 @@ import os +from collections.abc import Collection from copy import deepcopy +from itertools import chain +from typing import Tuple, Sequence import numpy as np import dace +from dace import SDFG, Memlet +from dace.properties import CodeBlock +from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.state import SDFGState +from dace.subsets import Range from dace.transformation.dataflow.const_assignment_fusion import ConstAssignmentMapFusion, ConstAssignmentStateFusion from dace.transformation.interstate import StateFusionExtended -K = dace.symbol('K') -M = dace.symbol('M') -N = dace.symbol('N') - - -@dace.program -def assign_top_row(A: dace.float32[M, N]): - for t in dace.map[0:N]: - A[0, t] = 1 - - -@dace.program -def assign_top_row_branched(A: dace.float32[M, N]): - for t, in dace.map[0:N]: - if t % 2 == 0: - A[0, t] = 1 - - -@dace.program -def assign_bottom_row(A: dace.float32[M, N]): - for b in dace.map[0:N]: - A[M - 1, b] = 1 - - -@dace.program -def assign_left_col(A: dace.float32[M, N]): - for l in dace.map[0:M]: - A[l, 0] = 1 - - -@dace.program -def assign_right_col(A: dace.float32[M, N]): - for r in dace.map[0:M]: - A[r, N - 1] = 1 +K, M, N = dace.symbol('K'), dace.symbol('M'), dace.symbol('N') + + +def _add_face_assignment_map(g: SDFGState, name: str, lims: Sequence[Tuple[str, dace.symbol]], + fixed_dims: Collection[Tuple[int, int]], assigned_val: int, array: str): + idx = [k for k, v in lims] + for fd, at in fixed_dims: + idx.insert(fd, str(at)) + t, en, ex = g.add_mapped_tasklet(name, [(k, Range([(0, v - 1, 1)])) for k, v in lims], + {}, f"__out = {assigned_val}", {'__out': Memlet(expr=f"{array}[{','.join(idx)}]")}, + external_edges=True) + return en, ex, t + + +def _simple_if_block(name: str, cond: str, val: int): + subg = SDFG(name) + subg.add_array('tmp', (1,), dace.float32) + # Outer structure. + head = subg.add_state('if_head') + branch = subg.add_state('if_b1') + tail = subg.add_state('if_tail') + subg.add_edge(head, branch, InterstateEdge(condition=f"({cond})")) + subg.add_edge(head, tail, InterstateEdge(condition=f"(not ({cond}))")) + subg.add_edge(branch, tail, InterstateEdge()) + # Inner structure. + t = branch.add_tasklet('top', inputs={}, outputs={'__out'}, code=f"__out = {val}") + tmp = branch.add_access('tmp') + branch.add_edge(t, '__out', tmp, None, Memlet(expr='tmp[0]')) + return subg def assign_bounary_sdfg(): - st0 = assign_top_row.to_sdfg(simplify=True, validate=True, use_cache=False) - st0.start_block.label = 'st0' - - st1 = assign_bottom_row.to_sdfg(simplify=True, validate=True, use_cache=False) - st1.start_block.label = 'st1' - st0.add_edge(st0.start_state, st1.start_state, dace.InterstateEdge()) + g = SDFG('prog') + g.add_array('A', (M, N), dace.float32) - st2 = assign_left_col.to_sdfg(simplify=True, validate=True, use_cache=False) - st2.start_block.label = 'st2' - st0.add_edge(st1.start_state, st2.start_state, dace.InterstateEdge()) + st0 = g.add_state('top') + _add_face_assignment_map(st0, 'top', [('j', N)], [(0, 0)], 1, 'A') + st1 = g.add_state('bottom') + _add_face_assignment_map(st1, 'bottom', [('j', N)], [(0, M - 1)], 1, 'A') + st2 = g.add_state('left') + _add_face_assignment_map(st2, 'left', [('i', M)], [(1, 0)], 1, 'A') + st3 = g.add_state('right') + _add_face_assignment_map(st3, 'right', [('i', M)], [(1, N - 1)], 1, 'A') - st3 = assign_right_col.to_sdfg(simplify=True, validate=True, use_cache=False) - st3.start_block.label = 'st3' - st0.add_edge(st2.start_state, st3.start_state, dace.InterstateEdge()) + g.add_edge(st0, st1, dace.InterstateEdge()) + g.add_edge(st1, st2, dace.InterstateEdge()) + g.add_edge(st2, st3, dace.InterstateEdge()) - return st0 + return g def test_within_state_fusion(): @@ -67,35 +71,26 @@ def test_within_state_fusion(): # Construct SDFG with the maps on separate states. g = assign_bounary_sdfg() + # Fuse the two states so that the const-assignment-fusion is applicable. + g.apply_transformations_repeated(StateFusionExtended, validate_all=True) g.save(os.path.join('_dacegraphs', 'simple-0.sdfg')) g.validate() + g.compile() + + # Get the reference data. actual_A = deepcopy(A) g(A=actual_A, M=4, N=5) - # Fuse the two states so that the const-assignment-fusion is applicable. - g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + assert g.apply_transformations_repeated(ConstAssignmentMapFusion, options={'use_grid_strided_loops': True}) == 3 g.save(os.path.join('_dacegraphs', 'simple-1.sdfg')) g.validate() + g.compile() - assert g.apply_transformations(ConstAssignmentMapFusion, - options={'use_grid_strided_loops': True}) == 1 - g.save(os.path.join('_dacegraphs', 'simple-2.sdfg')) - g.validate() - - assert g.apply_transformations(ConstAssignmentMapFusion, - options={'use_grid_strided_loops': True}) == 1 - g.save(os.path.join('_dacegraphs', 'simple-3.sdfg')) - g.validate() - - assert g.apply_transformations(ConstAssignmentMapFusion) == 0 - assert g.apply_transformations(ConstAssignmentMapFusion, - options={'use_grid_strided_loops': True}) == 1 - g.save(os.path.join('_dacegraphs', 'simple-4.sdfg')) - g.validate() + # Get our data. our_A = deepcopy(A) g(A=our_A, M=4, N=5) - # print(our_A) + # Verify numerically. assert np.allclose(our_A, actual_A) @@ -106,35 +101,35 @@ def test_interstate_fusion(): g = assign_bounary_sdfg() g.save(os.path.join('_dacegraphs', 'interstate-0.sdfg')) g.validate() + g.compile() + + # Get the reference data. actual_A = deepcopy(A) g(A=actual_A, M=4, N=5) - assert g.apply_transformations(ConstAssignmentStateFusion, - options={'use_grid_strided_loops': True}) == 1 + assert g.apply_transformations_repeated(ConstAssignmentStateFusion, options={'use_grid_strided_loops': True}) == 3 g.save(os.path.join('_dacegraphs', 'interstate-1.sdfg')) g.validate() + g.compile() - assert g.apply_transformations(ConstAssignmentStateFusion, - options={'use_grid_strided_loops': True}) == 1 - g.save(os.path.join('_dacegraphs', 'interstate-2.sdfg')) - g.validate() - - assert g.apply_transformations(ConstAssignmentStateFusion) == 0 - assert g.apply_transformations(ConstAssignmentStateFusion, - options={'use_grid_strided_loops': True}) == 1 - g.save(os.path.join('_dacegraphs', 'interstate-3.sdfg')) - g.validate() + # Get our data. our_A = deepcopy(A) g(A=our_A, M=4, N=5) - # print(our_A) + # Verify numerically. assert np.allclose(our_A, actual_A) -@dace.program -def assign_bounary_free_floating(A: dace.float32[M, N], B: dace.float32[M, N]): - assign_top_row(A) - assign_bottom_row(B) +def assign_bounary_free_floating_sdfg(): + g = SDFG('prog') + g.add_array('A', (M, N), dace.float32) + g.add_array('B', (M, N), dace.float32) + + st0 = g.add_state('st0') + _add_face_assignment_map(st0, 'top', [('j', N)], [(0, 0)], 1, 'A') + _add_face_assignment_map(st0, 'bottom', [('j', N)], [(0, M - 1)], 2, 'B') + + return g def test_free_floating_fusion(): @@ -142,9 +137,13 @@ def test_free_floating_fusion(): B = np.random.uniform(size=(4, 5)).astype(np.float32) # Construct SDFG with the maps on separate states. - g = assign_bounary_free_floating.to_sdfg(simplify=True, validate=True, use_cache=False) + g = assign_bounary_free_floating_sdfg() + # g = assign_bounary_free_floating.to_sdfg(simplify=True, validate=True, use_cache=False) g.save(os.path.join('_dacegraphs', 'floating-0.sdfg')) g.validate() + g.compile() + + # Get the reference data. actual_A = deepcopy(A) actual_B = deepcopy(B) g(A=actual_A, B=actual_B, M=4, N=5) @@ -152,58 +151,30 @@ def test_free_floating_fusion(): assert g.apply_transformations(ConstAssignmentMapFusion) == 1 g.save(os.path.join('_dacegraphs', 'floating-1.sdfg')) g.validate() + + # Get our data. our_A = deepcopy(A) our_B = deepcopy(B) g(A=our_A, B=our_B, M=4, N=5) - # print(our_A) + # Verify numerically. assert np.allclose(our_A, actual_A) -@dace.program -def assign_top_face(A: dace.float32[K, M, N]): - for t1, t2 in dace.map[0:M, 0:N]: - A[0, t1, t2] = 1 +def assign_boundary_3d_sdfg(): + g = SDFG('prog') + g.add_array('A', (K, M, N), dace.float32) + g.add_array('B', (K, M, N), dace.float32) + st0 = g.add_state('top') + _add_face_assignment_map(st0, 'top', [('m', M), ('n', N)], [(0, 0)], 1, 'A') + _add_face_assignment_map(st0, 'bottom', [('m', M), ('n', N)], [(0, K - 1)], 2, 'B') + _add_face_assignment_map(st0, 'front', [('k', K), ('n', N)], [(1, 0)], 1, 'A') + _add_face_assignment_map(st0, 'back', [('k', K), ('n', N)], [(1, M - 1)], 2, 'B') + _add_face_assignment_map(st0, 'left', [('k', K), ('m', M)], [(2, 0)], 1, 'A') + _add_face_assignment_map(st0, 'right', [('k', K), ('m', M)], [(2, N - 1)], 2, 'B') -@dace.program -def assign_bottom_face(A: dace.float32[K, M, N]): - for t1, t2 in dace.map[0:M, 0:N]: - A[K - 1, t1, t2] = 1 - - -@dace.program -def assign_front_face(A: dace.float32[K, M, N]): - for t1, t2 in dace.map[0:K, 0:N]: - A[t1, 0, t2] = 1 - - -@dace.program -def assign_back_face(A: dace.float32[K, M, N]): - for t1, t2 in dace.map[0:K, 0:N]: - A[t1, M - 1, t2] = 1 - - -@dace.program -def assign_left_face(A: dace.float32[K, M, N]): - for t1, t2 in dace.map[0:K, 0:M]: - A[t1, t2, 0] = 1 - - -@dace.program -def assign_right_face(A: dace.float32[K, M, N]): - for t1, t2 in dace.map[0:K, 0:M]: - A[t1, t2, N - 1] = 1 - - -@dace.program -def assign_bounary_3d(A: dace.float32[K, M, N], B: dace.float32[K, M, N]): - assign_top_face(A) - assign_bottom_face(B) - assign_front_face(A) - assign_back_face(B) - assign_left_face(A) - assign_right_face(B) + return g def test_fusion_with_multiple_indices(): @@ -211,37 +182,64 @@ def test_fusion_with_multiple_indices(): B = np.random.uniform(size=(3, 4, 5)).astype(np.float32) # Construct SDFG with the maps on separate states. - g = assign_bounary_3d.to_sdfg(simplify=True, validate=True, use_cache=False) + g = assign_boundary_3d_sdfg() + # g = assign_bounary_3d.to_sdfg(simplify=True, validate=True, use_cache=False) g.save(os.path.join('_dacegraphs', '3d-0.sdfg')) g.validate() + g.compile() + + # Get the reference data. actual_A = deepcopy(A) actual_B = deepcopy(B) g(A=actual_A, B=actual_B, K=3, M=4, N=5) - assert g.apply_transformations_repeated(ConstAssignmentMapFusion, options={'use_grid_strided_loops': True}) == 3 + assert g.apply_transformations_repeated(ConstAssignmentMapFusion, options={'use_grid_strided_loops': False}) == 3 g.save(os.path.join('_dacegraphs', '3d-1.sdfg')) g.validate() + g.compile() + + # Get our data. our_A = deepcopy(A) our_B = deepcopy(B) g(A=our_A, B=our_B, K=3, M=4, N=5) - # Here, the state fusion can apply only with GSLs. - assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 - assert g.apply_transformations_repeated(ConstAssignmentStateFusion, options={'use_grid_strided_loops': True}) == 2 + # Verify numerically. + assert np.allclose(our_A, actual_A) + + # Here, the map fusion can apply only with GSLs. + assert g.apply_transformations_repeated(ConstAssignmentMapFusion, options={'use_grid_strided_loops': False}) == 0 + assert g.apply_transformations_repeated(ConstAssignmentMapFusion, options={'use_grid_strided_loops': True}) == 2 g.save(os.path.join('_dacegraphs', '3d-2.sdfg')) g.validate() + g.compile() + + # Get our data. our_A = deepcopy(A) our_B = deepcopy(B) g(A=our_A, B=our_B, K=3, M=4, N=5) - # print(our_A) + # Verify numerically. assert np.allclose(our_A, actual_A) -@dace.program -def assign_bounary_with_branch(A: dace.float32[M, N], B: dace.float32[M, N]): - assign_top_row_branched(A) - assign_bottom_row(B) +def assign_bounary_with_branch_sdfg(): + g = SDFG('prog') + g.add_array('A', (M, N), dace.float32) + g.add_array('B', (M, N), dace.float32) + + st0 = g.add_state('st0') + en, ex, t = _add_face_assignment_map(st0, 'top', [('j', N)], [(0, 0)], 1, 'A') + new_t = _simple_if_block('if_block', 'j == 0', 1) + new_t = st0.add_nested_sdfg(new_t, None, {}, {'tmp'}, symbol_mapping={'j': 'j'}) + for e in list(chain(st0.in_edges(t), st0.out_edges(t))): + st0.remove_edge(e) + st0.add_nedge(en, new_t, Memlet()) + st0.add_edge(new_t, 'tmp', ex, 'IN_A', Memlet(expr='A[0, j]')) + st0.remove_node(t) + + _add_face_assignment_map(st0, 'bottom', [('j', N)], [(0, M - 1)], 1, 'A') + + return g def test_fusion_with_branch(): @@ -249,7 +247,8 @@ def test_fusion_with_branch(): B = np.random.uniform(size=(4, 5)).astype(np.float32) # Construct SDFG with the maps on separate states. - g = assign_bounary_with_branch.to_sdfg(simplify=True, validate=True, use_cache=False) + g = assign_bounary_with_branch_sdfg() + # g = assign_bounary_with_branch.to_sdfg(simplify=True, validate=True, use_cache=False) g.save(os.path.join('_dacegraphs', 'branched-0.sdfg')) g.validate() actual_A = deepcopy(A) @@ -267,112 +266,109 @@ def test_fusion_with_branch(): assert np.allclose(our_A, actual_A) -@dace.program -def assign_bottom_face_flipped(A: dace.float32[K, M, N]): - for t2, t1 in dace.map[0:N, 0:M]: - A[K - 1, t1, t2] = 1 +def assign_bounary_3d_with_flip_sdfg(): + g = SDFG('prog') + g.add_array('A', (K, M, N), dace.float32) + st0 = g.add_state('st0') + _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, 0)], 1, 'A') + _, _, t = _add_face_assignment_map(st0, 'face', [('k', N), ('j', M)], [(0, K - 1)], 1, 'A') + t.code = CodeBlock('A[0, j, k] = 1') -@dace.program -def assign_bounary_3d_with_flip(A: dace.float32[K, M, N], B: dace.float32[K, M, N]): - assign_top_face(A) - assign_bottom_face_flipped(B) + return g def test_does_not_permute_to_fuse(): """ Negative test """ - A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) - B = np.random.uniform(size=(3, 4, 5)).astype(np.float32) - # Construct SDFG with the maps on separate states. - g = assign_bounary_3d_with_flip.to_sdfg(simplify=True, validate=True, use_cache=False) - g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g = assign_bounary_3d_with_flip_sdfg() g.save(os.path.join('_dacegraphs', '3d-flip-0.sdfg')) g.validate() - actual_A = deepcopy(A) - actual_B = deepcopy(B) - g(A=actual_A, B=actual_B, K=3, M=4, N=5) + g.compile() assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 -@dace.program -def assign_mixed_dims(A: dace.float32[K, M, N], B: dace.float32[K, M, N]): - assign_top_face(A) - assign_left_col(B[0, :, :]) +def assign_mixed_dims_sdfg(): + g = SDFG('prog') + g.add_array('A', (K, M, N), dace.float32) + g.add_array('B', (K, M, N), dace.float32) + + st0 = g.add_state('st0') + _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, 0)], 1, 'A') + _add_face_assignment_map(st0, 'edge', [('k', N)], [(0, 0), (1, 0)], 2, 'B') + + return g def test_does_not_extend_to_fuse(): """ Negative test """ - A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) - B = np.random.uniform(size=(3, 4, 5)).astype(np.float32) - # Construct SDFG with the maps on separate states. - g = assign_mixed_dims.to_sdfg(simplify=True, validate=True, use_cache=False) - g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g = assign_mixed_dims_sdfg() g.save(os.path.join('_dacegraphs', '3d-mixed-0.sdfg')) g.validate() - actual_A = deepcopy(A) - actual_B = deepcopy(B) - g(A=actual_A, B=actual_B, K=3, M=4, N=5) + g.compile() + # Will not fuse if the number of dimensions are different. assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 -@dace.program -def assign_bottom_face_42(A: dace.float32[K, M, N]): - for t1, t2 in dace.map[0:M, 0:N]: - A[K - 1, t1, t2] = 42 +def assign_inconsistent_values_different_constants_sdfg(): + g = SDFG('prog') + g.add_array('A', (K, M, N), dace.float32) + st0 = g.add_state('st0') + _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, 0)], 1, 'A') + _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, K - 1)], 42, 'A') -@dace.program -def assign_bottom_face_index_sum(A: dace.float32[K, M, N]): - for t1, t2 in dace.map[0:M, 0:N]: - A[K - 1, t1, t2] = t1 + t2 + return g -@dace.program -def assign_inconsistent_values_1(A: dace.float32[K, M, N]): - assign_top_face(A) - assign_bottom_face_42(A) +def assign_inconsistent_values_non_constant_sdfg(): + g = SDFG('prog') + g.add_array('A', (K, M, N), dace.float32) + st0 = g.add_state('st0') + _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, 0)], 1, 'A') + _, _, t = _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, K - 1)], 1, 'A') + t.code = CodeBlock('__out = j + k') -@dace.program -def assign_inconsistent_values_2(A: dace.float32[K, M, N]): - assign_top_face(A) - assign_bottom_face_index_sum(A) + return g def test_does_not_fuse_with_inconsistent_assignments(): """ Negative test """ - A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) - # Construct SDFG with the maps on separate states. - g = assign_inconsistent_values_1.to_sdfg(simplify=True, validate=True, use_cache=False) - g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g = assign_inconsistent_values_different_constants_sdfg() g.save(os.path.join('_dacegraphs', '3d-inconsistent-0.sdfg')) g.validate() - actual_A = deepcopy(A) - g(A=actual_A, K=3, M=4, N=5) + g.compile() assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 - # Try another case: Construct SDFG with the maps on separate states. - g = assign_inconsistent_values_2.to_sdfg(simplify=True, validate=True, use_cache=False) - g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + # Try another case. + # Construct SDFG with the maps on separate states. + g = assign_inconsistent_values_non_constant_sdfg() g.save(os.path.join('_dacegraphs', '3d-inconsistent-1.sdfg')) g.validate() - actual_A = deepcopy(A) - g(A=actual_A, K=3, M=4, N=5) + g.compile() assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 -@dace.program -def tasklet_between_maps(A: dace.float32[K, M, N]): - assign_top_face(A) - A[0, 0, 0] = 1 - assign_bottom_face(A) +def sdfg_with_tasklet_between_maps(): + g = SDFG('prog') + g.add_array('A', (K, M, N), dace.float32) + + st0 = g.add_state('st0') + _, ex1, _ = _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, 0)], 1, 'A') + en2, _, _ = _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, K - 1)], 1, 'A') + t = st0.add_tasklet('noop', {}, {}, '') + st0.add_nedge(st0.out_edges(ex1)[0].dst, en2, Memlet()) + st0.add_nedge(st0.out_edges(ex1)[0].dst, t, Memlet()) + st0.add_nedge(t, en2, Memlet()) + + return g def test_does_not_fuse_with_unsuitable_dependencies(): @@ -380,26 +376,35 @@ def test_does_not_fuse_with_unsuitable_dependencies(): A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) # Construct SDFG with the maps on separate states. - g = tasklet_between_maps.to_sdfg(simplify=True, validate=True, use_cache=False) - g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g = sdfg_with_tasklet_between_maps() g.save(os.path.join('_dacegraphs', '3d-baddeps-0.sdfg')) g.validate() - actual_A = deepcopy(A) - g(A=actual_A, K=3, M=4, N=5) + g.compile() assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 -@dace.program -def assign_top_face_self_copy(A: dace.float32[K, M, N]): - for t1, t2 in dace.map[0:M, 0:N]: - A[0, t1, t2] = A[0, t1, t2] +def sdfg_where_first_map_reads_data(): + g = SDFG('prog') + g.add_array('A', (M, N), dace.float32) + + st0 = g.add_state('top') + en1, _, t = _add_face_assignment_map(st0, 'top', [('j', N)], [(0, 0)], 1, 'A') + en1.add_in_connector('IN_A') + en1.add_out_connector('OUT_A') + t.add_in_connector('__blank') + A = st0.add_access('A') + st0.add_edge(A, None, en1, 'IN_A', Memlet(expr='A[0, 0:N]')) + for e in st0.out_edges(en1): + st0.remove_edge(e) + st0.add_edge(en1, 'OUT_A', t, '__blank', Memlet(expr='A[0, j]')) + + st1 = g.add_state('bottom') + _add_face_assignment_map(st1, 'bottom', [('j', N)], [(0, M - 1)], 1, 'A') + g.add_edge(st0, st1, InterstateEdge()) -@dace.program -def first_map_reads_data(A: dace.float32[K, M, N]): - assign_top_face_self_copy(A) - assign_bottom_face(A) + return g def test_does_not_fuse_when_the_first_map_reads_anything_at_all(): @@ -407,14 +412,14 @@ def test_does_not_fuse_when_the_first_map_reads_anything_at_all(): A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) # Construct SDFG with the maps on separate states. - g = first_map_reads_data.to_sdfg(simplify=True, validate=True, use_cache=False) + g = sdfg_where_first_map_reads_data() g.save(os.path.join('_dacegraphs', '3d-map1-reads-0.sdfg')) g.validate() - actual_A = deepcopy(A) - g(A=actual_A, K=3, M=4, N=5) + g.compile() + # TODO:Fix. # The state fusion won't work. - assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 + # assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 # Fuse the states explicitly anyway. g.apply_transformations_repeated(StateFusionExtended, validate_all=True) @@ -433,3 +438,8 @@ def test_does_not_fuse_when_the_first_map_reads_anything_at_all(): test_free_floating_fusion() test_fusion_with_branch() test_fusion_with_multiple_indices() + test_does_not_extend_to_fuse() + test_does_not_permute_to_fuse() + test_does_not_fuse_with_inconsistent_assignments() + test_does_not_fuse_with_unsuitable_dependencies() + test_does_not_fuse_when_the_first_map_reads_anything_at_all() From fb01a2301e70bdc16dec36f802cb316fec2a6cf4 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 30 Oct 2024 13:24:52 +0100 Subject: [PATCH 26/29] Fix the "taking a function's reference instead of calling it" problem. --- dace/transformation/dataflow/const_assignment_fusion.py | 2 +- tests/transformations/const_assignment_fusion_test.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 7c490d7071..9216e6fb61 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -579,7 +579,7 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, if not en_ex: return False en, ex = en_ex - if any(not e.data.is_empty for e in st.in_edges(en)): + if any(not e.data.is_empty() for e in st.in_edges(en)): return False is_const_assignment, further_assignments = _consistent_const_assignment_table(st, en, ex) if not is_const_assignment: diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 7eb82587f0..1f4415e7f6 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -417,9 +417,8 @@ def test_does_not_fuse_when_the_first_map_reads_anything_at_all(): g.validate() g.compile() - # TODO:Fix. # The state fusion won't work. - # assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 + assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 # Fuse the states explicitly anyway. g.apply_transformations_repeated(StateFusionExtended, validate_all=True) From 78f0c2f9d80030b0b5702421faee7dedbed6c798 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 30 Oct 2024 13:33:08 +0100 Subject: [PATCH 27/29] Flipped the map incorrectly when constructing the graph. --- tests/transformations/const_assignment_fusion_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 1f4415e7f6..0053619e1c 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -272,8 +272,9 @@ def assign_bounary_3d_with_flip_sdfg(): st0 = g.add_state('st0') _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, 0)], 1, 'A') - _, _, t = _add_face_assignment_map(st0, 'face', [('k', N), ('j', M)], [(0, K - 1)], 1, 'A') - t.code = CodeBlock('A[0, j, k] = 1') + en, _, _ = _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, K - 1)], 1, 'A') + en.map.range = Range(reversed(en.map.range.ranges)) + en.map.params = list(reversed(en.map.params)) return g From 44e81ce5fd537254c8116808c933b2da490a8124 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 30 Oct 2024 14:29:57 +0100 Subject: [PATCH 28/29] Replace `add_memlet_path()` calls. --- .../dataflow/const_assignment_fusion.py | 48 ++++++------------- .../const_assignment_fusion_test.py | 3 +- 2 files changed, 15 insertions(+), 36 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 9216e6fb61..1a224a0396 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -242,7 +242,7 @@ def _consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, sec # Finally, these nodes should be depended on by _both_ maps. for en in [first_entry, second_entry]: for n in alt_table.values(): - graph.add_memlet_path(n, en, memlet=Memlet()) + graph.add_nedge(n, en, Memlet()) def _consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit: MapExit): @@ -275,15 +275,11 @@ def _consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exi for e in graph.in_edges(n): assert e.src in [first_exit, second_exit] assert e.dst_conn is None - graph.add_memlet_path(e.src, surviving_nodes[e.dst.data], - src_conn=e.src_conn, dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) + graph.add_edge(e.src, e.src_conn, surviving_nodes[e.dst.data], e.dst_conn, Memlet.from_memlet(e.data)) graph.remove_edge(e) for e in graph.out_edges(n): assert e.src_conn is None - graph.add_memlet_path(surviving_nodes[e.src.data], e.dst, - src_conn=e.src_conn, dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) + graph.add_edge(surviving_nodes[e.src.data], e.src_conn, e.dst, e.dst_conn, Memlet.from_memlet(e.data)) graph.remove_edge(e) # Finally, cleanup the orphan nodes. for n in all_written_nodes: @@ -301,26 +297,18 @@ def _consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: T assert all(e.data.is_empty() for e in graph.in_edges(src_en)) cmap = _add_equivalent_connectors(dst_en, src_en) for e in graph.in_edges(src_en): - graph.add_memlet_path(e.src, dst_en, - src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), - memlet=Memlet.from_memlet(e.data)) + graph.add_edge(e.src, e.src_conn, dst_en, cmap.get(e.dst_conn), Memlet.from_memlet(e.data)) graph.remove_edge(e) for e in graph.out_edges(src_en): - graph.add_memlet_path(dst_en, e.dst, - src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) + graph.add_edge(dst_en, cmap.get(e.src_conn), e.dst, e.dst_conn, Memlet.from_memlet(e.data)) graph.remove_edge(e) cmap = _add_equivalent_connectors(dst_ex, src_ex) for e in graph.in_edges(src_ex): - graph.add_memlet_path(e.src, dst_ex, - src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), - memlet=Memlet.from_memlet(e.data)) + graph.add_edge(e.src, e.src_conn, dst_ex, cmap.get(e.dst_conn), Memlet.from_memlet(e.data)) graph.remove_edge(e) for e in graph.out_edges(src_ex): - graph.add_memlet_path(dst_ex, e.dst, - src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) + graph.add_edge(dst_ex, cmap.get(e.src_conn), e.dst, e.dst_conn, Memlet.from_memlet(e.data)) graph.remove_edge(e) graph.remove_node(src_en) @@ -355,28 +343,20 @@ def range_for_grid_stride(r, val, bound): assert all(e.data.is_empty() for e in graph.in_edges(en)) cmap = _add_equivalent_connectors(dst_en, en) for e in graph.in_edges(en): - graph.add_memlet_path(e.src, dst_en, - src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), - memlet=Memlet.from_memlet(e.data)) - graph.add_memlet_path(dst_en, e.dst, - src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) + graph.add_edge(e.src, e.src_conn, dst_en, cmap.get(e.dst_conn), Memlet.from_memlet(e.data)) + graph.add_edge(dst_en, cmap.get(e.src_conn), e.dst, e.dst_conn, Memlet.from_memlet(e.data)) graph.remove_edge(e) cmap = _add_equivalent_connectors(dst_ex, ex) for e in graph.out_edges(ex): - graph.add_memlet_path(e.src, dst_ex, - src_conn=e.src_conn, - dst_conn=_connector_counterpart(cmap.get(e.src_conn)), - memlet=Memlet.from_memlet(e.data)) - graph.add_memlet_path(dst_ex, e.dst, - src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, - memlet=Memlet.from_memlet(e.data)) + graph.add_edge(e.src, e.src_conn, dst_ex, _connector_counterpart(cmap.get(e.src_conn)), + Memlet.from_memlet(e.data)) + graph.add_edge(dst_ex, cmap.get(e.src_conn), e.dst, e.dst_conn, Memlet.from_memlet(e.data)) graph.remove_edge(e) if len(graph.in_edges(en)) == 0: - graph.add_memlet_path(dst_en, en, memlet=Memlet()) + graph.add_nedge(dst_en, en, Memlet()) if len(graph.out_edges(ex)) == 0: - graph.add_memlet_path(ex, dst_ex, memlet=Memlet()) + graph.add_nedge(ex, dst_ex, Memlet()) def _fused_range(r1: Range, r2: Range) -> Optional[Range]: diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index 0053619e1c..c3d894422a 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -425,8 +425,7 @@ def test_does_not_fuse_when_the_first_map_reads_anything_at_all(): g.apply_transformations_repeated(StateFusionExtended, validate_all=True) g.save(os.path.join('_dacegraphs', '3d-map1-reads-1.sdfg')) g.validate() - actual_A = deepcopy(A) - g(A=actual_A, K=3, M=4, N=5) + g.compile() # The map fusion won't work. assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 From e525ce4537c8c7d1821acf652991008efce0a2c1 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Wed, 30 Oct 2024 15:12:46 +0100 Subject: [PATCH 29/29] Simplify and clarify the no-dependency-pattern check. Add more tests to cover more cases in the state fusion. --- .../dataflow/const_assignment_fusion.py | 32 +++---- .../const_assignment_fusion_test.py | 89 +++++++++++++++++-- 2 files changed, 99 insertions(+), 22 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 1a224a0396..5aa2bb1ff8 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -436,28 +436,28 @@ def _map_nodes(self, graph: SDFGState): def _no_dependency_pattern(self, graph: SDFGState) -> bool: """Decide if the two maps are independent of each other.""" first_entry, first_exit, second_entry, second_exit = self._map_nodes(graph) + all_in_edges = list(chain(graph.in_edges(first_entry), graph.in_edges(second_entry))) + all_out_edges = list(chain(graph.out_edges(first_exit), graph.out_edges(second_exit))) + + # The analysis is too difficult to continue (so just reject independence to err on the side of caution), when... if graph.scope_dict()[first_entry] != graph.scope_dict()[second_entry]: + # ... the two maps are not even on the same scope (so analysing the connectivity is difficult). return False - if not all(isinstance(n, AccessNode) for n in graph.all_nodes_between(first_exit, second_entry)): - return False - if not all(isinstance(n, AccessNode) for n in graph.all_nodes_between(second_exit, first_entry)): - return False - if any(not e.data.is_empty() - for e in chain(graph.in_edges(first_entry), graph.in_edges(second_entry))): + if not all(isinstance(n, AccessNode) for n in chain(graph.all_nodes_between(first_exit, second_entry), + graph.all_nodes_between(second_exit, first_entry))): + # ... there are non-AccessNodes between the two maps (also difficult to analyse). return False - if any(not isinstance(e.src, (MapEntry, AccessNode)) - for e in chain(graph.in_edges(first_entry), graph.in_edges(second_entry))): + if any(not isinstance(e.src, (MapExit, AccessNode)) for e in all_in_edges): + # ... either map has incoming edges from a node that is not an AccessNode or a MapExit (also difficult). return False - if not (all(isinstance(e.src, AccessNode) - for e in chain(graph.in_edges(first_entry), graph.in_edges(second_entry))) - or all(isinstance(e.src, MapEntry) - for e in chain(graph.in_edges(first_entry), graph.in_edges(second_entry)))): + if any(not isinstance(e.dst, (MapEntry, AccessNode)) for e in all_out_edges): + # ... either map has outgoing edges to a node that is not an AccessNode or a MapEntry (also difficult). return False - if not (all(isinstance(e.dst, AccessNode) - for e in chain(graph.out_edges(first_exit), graph.out_edges(second_exit))) - or all(isinstance(e.dst, MapExit) - for e in chain(graph.out_edges(first_exit), graph.out_edges(second_exit)))): + + if any(not e.data.is_empty() for e in all_in_edges): + # If any of the maps are reading anything, then it isn't independent. return False + return True def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: diff --git a/tests/transformations/const_assignment_fusion_test.py b/tests/transformations/const_assignment_fusion_test.py index c3d894422a..b025e4271e 100644 --- a/tests/transformations/const_assignment_fusion_test.py +++ b/tests/transformations/const_assignment_fusion_test.py @@ -297,7 +297,9 @@ def assign_mixed_dims_sdfg(): st0 = g.add_state('st0') _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, 0)], 1, 'A') - _add_face_assignment_map(st0, 'edge', [('k', N)], [(0, 0), (1, 0)], 2, 'B') + st1 = g.add_state('st1') + _add_face_assignment_map(st1, 'edge', [('k', N)], [(0, 0), (1, 0)], 2, 'B') + g.add_edge(st0, st1, InterstateEdge()) return g @@ -310,7 +312,12 @@ def test_does_not_extend_to_fuse(): g.validate() g.compile() - # Will not fuse if the number of dimensions are different. + # Has multiple states, but will not fuse them if the number of dimensions are different. + assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 + # We can fuse them manually. + assert g.apply_transformations_repeated(StateFusionExtended) == 1 + g.save(os.path.join('_dacegraphs', '3d-mixed-1.sdfg')) + # But still won't fuse them maps. assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 @@ -320,7 +327,9 @@ def assign_inconsistent_values_different_constants_sdfg(): st0 = g.add_state('st0') _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, 0)], 1, 'A') - _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, K - 1)], 42, 'A') + st1 = g.add_state('st1') + _add_face_assignment_map(st1, 'face', [('j', M), ('k', N)], [(0, K - 1)], 42, 'A') + g.add_edge(st0, st1, InterstateEdge()) return g @@ -331,8 +340,10 @@ def assign_inconsistent_values_non_constant_sdfg(): st0 = g.add_state('st0') _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, 0)], 1, 'A') - _, _, t = _add_face_assignment_map(st0, 'face', [('j', M), ('k', N)], [(0, K - 1)], 1, 'A') + st1 = g.add_state('st1') + _, _, t = _add_face_assignment_map(st1, 'face', [('j', M), ('k', N)], [(0, K - 1)], 1, 'A') t.code = CodeBlock('__out = j + k') + g.add_edge(st0, st1, InterstateEdge()) return g @@ -341,19 +352,31 @@ def test_does_not_fuse_with_inconsistent_assignments(): """ Negative test """ # Construct SDFG with the maps on separate states. g = assign_inconsistent_values_different_constants_sdfg() - g.save(os.path.join('_dacegraphs', '3d-inconsistent-0.sdfg')) + g.save(os.path.join('_dacegraphs', '3d-inconsistent-0a.sdfg')) g.validate() g.compile() + # Has multiple states, but won't fuse them. + assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 + # We can fuse them manually. + assert g.apply_transformations_repeated(StateFusionExtended) == 1 + g.save(os.path.join('_dacegraphs', '3d-inconsistent-1a.sdfg')) + # But still won't fuse them maps. assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 # Try another case. # Construct SDFG with the maps on separate states. g = assign_inconsistent_values_non_constant_sdfg() - g.save(os.path.join('_dacegraphs', '3d-inconsistent-1.sdfg')) + g.save(os.path.join('_dacegraphs', '3d-inconsistent-0b.sdfg')) g.validate() g.compile() + # Has multiple states, but won't fuse them. + assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 + # We can fuse them manually. + assert g.apply_transformations_repeated(StateFusionExtended) == 1 + g.save(os.path.join('_dacegraphs', '3d-inconsistent-1b.sdfg')) + # But still won't fuse them maps. assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 @@ -431,6 +454,59 @@ def test_does_not_fuse_when_the_first_map_reads_anything_at_all(): assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 0 +def sdfg_where_first_state_has_multiple_toplevel_maps(): + g = SDFG('prog') + g.add_array('A', (M, N), dace.float32) + + st0 = g.add_state('st0') + _add_face_assignment_map(st0, 'top', [('j', N)], [(0, 0)], 1, 'A') + _add_face_assignment_map(st0, 'bottom', [('j', N)], [(0, M - 1)], 1, 'A') + + st1 = g.add_state('st1') + _add_face_assignment_map(st1, 'left', [('i', M)], [(1, 0)], 1, 'A') + + g.add_edge(st0, st1, InterstateEdge()) + + return g + + +def test_does_not_fuse_when_the_first_state_has_multiple_toplevel_maps(): + """ Negative test """ + A = np.random.uniform(size=(3, 4, 5)).astype(np.float32) + + # Construct SDFG with the maps on separate states. + g = sdfg_where_first_state_has_multiple_toplevel_maps() + g.save(os.path.join('_dacegraphs', '3d-multimap-state-0.sdfg')) + g.validate() + g.compile() + + # Get the reference data. + actual_A = deepcopy(A) + g(A=actual_A, K=3, M=4, N=5) + + # The state fusion won't work. + assert g.apply_transformations_repeated(ConstAssignmentStateFusion) == 0 + + # Fuse the states explicitly anyway. + g.apply_transformations_repeated(StateFusionExtended, validate_all=True) + g.save(os.path.join('_dacegraphs', '3d-multimap-state-1.sdfg')) + g.validate() + g.compile() + + # But now, the fusion will work! + assert g.apply_transformations_repeated(ConstAssignmentMapFusion) == 1 + g.save(os.path.join('_dacegraphs', '3d-multimap-state-2.sdfg')) + g.validate() + g.compile() + + # Get our data. + our_A = deepcopy(A) + g(A=our_A, K=3, M=4, N=5) + + # Verify numerically. + assert np.allclose(our_A, actual_A) + + if __name__ == '__main__': test_within_state_fusion() test_interstate_fusion() @@ -442,3 +518,4 @@ def test_does_not_fuse_when_the_first_map_reads_anything_at_all(): test_does_not_fuse_with_inconsistent_assignments() test_does_not_fuse_with_unsuitable_dependencies() test_does_not_fuse_when_the_first_map_reads_anything_at_all() + test_does_not_fuse_when_the_first_state_has_multiple_toplevel_maps()