diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index 2c08754891..ab5027a4f4 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -2,7 +2,7 @@ from collections import defaultdict from copy import deepcopy from itertools import chain -from typing import Optional, Union +from typing import Optional, Union, Tuple from dace import transformation, SDFGState, SDFG, Memlet, ScheduleType, subsets from dace.properties import make_properties, Property @@ -14,7 +14,7 @@ from dace.transformation.interstate import StateFusionExtended -def unique_top_level_map_node(graph: SDFGState) -> Optional[tuple[MapEntry, MapExit]]: +def unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapExit]]: all_top_nodes = [n for n, s in graph.scope_dict().items() if s is None] if not all(isinstance(n, (MapEntry, AccessNode)) for n in all_top_nodes): return None @@ -32,7 +32,7 @@ def floating_nodes_graph(*args): return g -def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: +def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]: """ If the graph consists of only conditional consistent constant assignments, produces a table mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is considered consistent. @@ -86,10 +86,10 @@ def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: # ...must assign... return False, table op = n.code.code[0] - if not isinstance(op.value, ast.Constant) or len(op.targets) != 1: + if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: # ...a constant to a single target. return False, table - const = op.value.value + const = value_of_constant_or_numerical_literal(op.value) for oe in body.out_edges(n): dst = oe.data dst_arr = oe.data.data @@ -101,7 +101,17 @@ def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]: return True, table -def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> tuple[bool, dict]: +def is_constant_or_numerical_literal(n: ast.Expr): + """Work around the API differences between Python versions (e.g., 3.7 and 3.12)""" + return isinstance(n, (ast.Constant, ast.Num)) + + +def value_of_constant_or_numerical_literal(n: ast.Expr): + """Work around the API differences between Python versions (e.g., 3.7 and 3.12)""" + return n.value if isinstance(n, ast.Constant) else n.n + + +def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> Tuple[bool, dict]: """ If the graph consists of only (conditional or unconditional) consistent constant assignments, produces a table mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is @@ -141,10 +151,10 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi # ...that assigns... return False, table op = n.code.code[0] - if not isinstance(op.value, ast.Constant) or len(op.targets) != 1: + if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: # ...a constant to a single target. return False, table - const = op.value.value + const = value_of_constant_or_numerical_literal(op.value) for oe in graph.out_edges(n): dst = oe.data dst_arr = oe.data.data @@ -156,6 +166,13 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi return True, table +def removeprefix(c: str, p: str): + """Since `str.removeprefix()` wasn't added until Python 3.9""" + if not c.startswith(p): + return c + return c[len(p):] + + def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]): """ Create the additional connectors in the first exit node that matches the second exit node (which will be removed @@ -164,7 +181,7 @@ def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryN conn_map = defaultdict() for c, v in src.in_connectors.items(): assert c.startswith('IN_') - cbase = c.removeprefix('IN_') + cbase = removeprefix(c, 'IN_') sc = dst.next_connector(cbase) conn_map[f"IN_{cbase}"] = f"IN_{sc}" conn_map[f"OUT_{cbase}"] = f"OUT_{sc}" @@ -181,9 +198,9 @@ def connector_counterpart(c: Union[str, None]) -> Union[str, None]: return None assert isinstance(c, str) if c.startswith('IN_'): - return f"OUT_{c.removeprefix('IN_')}" + return f"OUT_{removeprefix(c, 'IN_')}" elif c.startswith('OUT_'): - return f"IN_{c.removeprefix('OUT_')}" + return f"IN_{removeprefix(c, 'OUT_')}" return None @@ -274,7 +291,7 @@ def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit graph.remove_node(n) -def consume_map_exactly(graph: SDFGState, dst: tuple[MapEntry, MapExit], src: tuple[MapEntry, MapExit]): +def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tuple[MapEntry, MapExit]): """ Transfer the entirety of `src` map's body into `dst` map. Only possible when the two maps' ranges are identical. """ @@ -310,8 +327,8 @@ def consume_map_exactly(graph: SDFGState, dst: tuple[MapEntry, MapExit], src: tu graph.remove_node(src_ex) -def consume_map_with_grid_strided_loop(graph: SDFGState, dst: tuple[MapEntry, MapExit], - src: tuple[MapEntry, MapExit]): +def consume_map_with_grid_strided_loop(graph: SDFGState, dst: Tuple[MapEntry, MapExit], + src: Tuple[MapEntry, MapExit]): """ Transfer the entirety of `src` map's body into `dst` map, guarded behind a _grid-strided_ loop. Prerequisite: `dst` map's range must cover `src` map's range in entirety. Statically checking this may not