Skip to content

Commit

Permalink
Handle the AST API differences between python versions.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratyai committed Oct 24, 2024
1 parent 3342621 commit e1bf9a6
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions dace/transformation/dataflow/const_assignment_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import Optional, Union
from typing import Optional, Union, Tuple

from dace import transformation, SDFGState, SDFG, Memlet, ScheduleType, subsets
from dace.properties import make_properties, Property
Expand All @@ -14,7 +14,7 @@
from dace.transformation.interstate import StateFusionExtended


def unique_top_level_map_node(graph: SDFGState) -> Optional[tuple[MapEntry, MapExit]]:
def unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapExit]]:
all_top_nodes = [n for n, s in graph.scope_dict().items() if s is None]
if not all(isinstance(n, (MapEntry, AccessNode)) for n in all_top_nodes):
return None
Expand All @@ -32,7 +32,7 @@ def floating_nodes_graph(*args):
return g


def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]:
def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]:
"""
If the graph consists of only conditional consistent constant assignments, produces a table mapping data arrays
and memlets to their consistent constant assignments. See the class docstring for what is considered consistent.
Expand Down Expand Up @@ -86,10 +86,10 @@ def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]:
# ...must assign...
return False, table
op = n.code.code[0]
if not isinstance(op.value, ast.Constant) or len(op.targets) != 1:
if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1:
# ...a constant to a single target.
return False, table
const = op.value.value
const = value_of_constant_or_numerical_literal(op.value)
for oe in body.out_edges(n):
dst = oe.data
dst_arr = oe.data.data
Expand All @@ -101,7 +101,17 @@ def consistent_branch_const_assignment_table(graph: Node) -> tuple[bool, dict]:
return True, table


def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> tuple[bool, dict]:
def is_constant_or_numerical_literal(n: ast.Expr):
"""Work around the API differences between Python versions (e.g., 3.7 and 3.12)"""
return isinstance(n, (ast.Constant, ast.Num))


def value_of_constant_or_numerical_literal(n: ast.Expr):
"""Work around the API differences between Python versions (e.g., 3.7 and 3.12)"""
return n.value if isinstance(n, ast.Constant) else n.n


def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> Tuple[bool, dict]:
"""
If the graph consists of only (conditional or unconditional) consistent constant assignments, produces a table
mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is
Expand Down Expand Up @@ -141,10 +151,10 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi
# ...that assigns...
return False, table
op = n.code.code[0]
if not isinstance(op.value, ast.Constant) or len(op.targets) != 1:
if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1:
# ...a constant to a single target.
return False, table
const = op.value.value
const = value_of_constant_or_numerical_literal(op.value)
for oe in graph.out_edges(n):
dst = oe.data
dst_arr = oe.data.data
Expand All @@ -156,6 +166,13 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi
return True, table


def removeprefix(c: str, p: str):
"""Since `str.removeprefix()` wasn't added until Python 3.9"""
if not c.startswith(p):
return c
return c[len(p):]


def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]):
"""
Create the additional connectors in the first exit node that matches the second exit node (which will be removed
Expand All @@ -164,7 +181,7 @@ def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryN
conn_map = defaultdict()
for c, v in src.in_connectors.items():
assert c.startswith('IN_')
cbase = c.removeprefix('IN_')
cbase = removeprefix(c, 'IN_')
sc = dst.next_connector(cbase)
conn_map[f"IN_{cbase}"] = f"IN_{sc}"
conn_map[f"OUT_{cbase}"] = f"OUT_{sc}"
Expand All @@ -181,9 +198,9 @@ def connector_counterpart(c: Union[str, None]) -> Union[str, None]:
return None
assert isinstance(c, str)
if c.startswith('IN_'):
return f"OUT_{c.removeprefix('IN_')}"
return f"OUT_{removeprefix(c, 'IN_')}"
elif c.startswith('OUT_'):
return f"IN_{c.removeprefix('OUT_')}"
return f"IN_{removeprefix(c, 'OUT_')}"
return None


Expand Down Expand Up @@ -274,7 +291,7 @@ def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit
graph.remove_node(n)


def consume_map_exactly(graph: SDFGState, dst: tuple[MapEntry, MapExit], src: tuple[MapEntry, MapExit]):
def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tuple[MapEntry, MapExit]):
"""
Transfer the entirety of `src` map's body into `dst` map. Only possible when the two maps' ranges are identical.
"""
Expand Down Expand Up @@ -310,8 +327,8 @@ def consume_map_exactly(graph: SDFGState, dst: tuple[MapEntry, MapExit], src: tu
graph.remove_node(src_ex)


def consume_map_with_grid_strided_loop(graph: SDFGState, dst: tuple[MapEntry, MapExit],
src: tuple[MapEntry, MapExit]):
def consume_map_with_grid_strided_loop(graph: SDFGState, dst: Tuple[MapEntry, MapExit],
src: Tuple[MapEntry, MapExit]):
"""
Transfer the entirety of `src` map's body into `dst` map, guarded behind a _grid-strided_ loop.
Prerequisite: `dst` map's range must cover `src` map's range in entirety. Statically checking this may not
Expand Down

0 comments on commit e1bf9a6

Please sign in to comment.