From 9e01539bc6d7ee555ce3dd3db17e056237b3e340 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Sun, 30 Jun 2024 15:41:35 +0200 Subject: [PATCH 01/38] Create ConditionalRegion and implement its serialization --- dace/sdfg/sdfg.py | 2 +- dace/sdfg/state.py | 80 +++++++++++++++++++++++++-- tests/sdfg/conditional_region_test.py | 32 +++++++++++ 3 files changed, 109 insertions(+), 5 deletions(-) create mode 100644 tests/sdfg/conditional_region_test.py diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 82d98c1e18..1091ad3e94 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -610,7 +610,7 @@ def from_json(cls, json_obj, context_info=None): nci = copy.copy(context_info) nci['sdfg'] = ret - state = SDFGState.from_json(n, context=nci) + state = dace.serialize.from_json(n, context=nci) ret.add_node(state) nodelist.append(state) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 736a4799df..d1b91a0416 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -22,7 +22,7 @@ from dace.properties import (CodeBlock, DictProperty, EnumProperty, Property, SubsetProperty, SymbolicProperty, CodeProperty, make_properties) from dace.sdfg import nodes as nd -from dace.sdfg.graph import MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge +from dace.sdfg.graph import MultiConnectorEdge, NodeNotFoundError, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge from dace.sdfg.propagation import propagate_memlet from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset @@ -2740,11 +2740,10 @@ def from_json(cls, json_obj, context_info=None): if _type != cls.__name__: raise TypeError("Class type mismatch") - attrs = json_obj['attributes'] nodes = json_obj['nodes'] edges = json_obj['edges'] - ret = ControlFlowRegion(label=attrs['label'], sdfg=context_info['sdfg']) + ret = ControlFlowRegion(label=json_obj['label'], sdfg=context_info['sdfg']) dace.serialize.set_properties_from_json(ret, json_obj) @@ -2753,7 +2752,7 @@ def from_json(cls, json_obj, context_info=None): nci = copy.copy(context_info) nci['parent_graph'] = ret - state = SDFGState.from_json(n, context=nci) + state = dace.serialize.from_json(n, nci) ret.add_node(state) nodelist.append(state) @@ -3119,3 +3118,76 @@ def has_return(self) -> bool: if isinstance(node, ReturnBlock): return True return False + +@dace.serialize.serializable +class ConditionalRegion(ControlFlowBlock): + def __init__(self, label: str): + super().__init__(label) + self.branches: List[Tuple[CodeBlock, ControlFlowRegion]] = [] + + def nodes(self) -> List['ControlFlowBlock']: + return [node for branch in self.branches for node in branch[1].nodes()] + + def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: + return [edge for branch in self.branches for edge in branch[1].edges()] + + def node_id(self, node: ControlFlowBlock) -> int: + try: + return self.nodes().index(node) + except ValueError: + raise NodeNotFoundError(node) + + def _used_symbols_internal(self, + all_symbols: bool, + defined_syms: Optional[Set] = None, + free_syms: Optional[Set] = None, + used_before_assignment: Optional[Set] = None, + keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms = set() if defined_syms is None else defined_syms + free_syms = set() if free_syms is None else free_syms + used_before_assignment = set() if used_before_assignment is None else used_before_assignment + + b_free_symbols, b_defined_symbols, b_used_before_assignment = super()._used_symbols_internal( + all_symbols, keep_defined_in_mapping=keep_defined_in_mapping) + free_syms |= b_free_symbols + defined_syms |= b_defined_symbols + used_before_assignment |= b_used_before_assignment + + for condition, cfg in self.branches: + free_syms |= condition.get_free_symbols() + b_free_symbols, b_defined_symbols, b_used_before_assignment = cfg._used_symbols_internal( + all_symbols, keep_defined_in_mapping=keep_defined_in_mapping) + free_syms |= b_free_symbols + defined_syms |= b_defined_symbols + used_before_assignment |= b_used_before_assignment + + defined_syms -= used_before_assignment + free_syms -= defined_syms + + return free_syms, defined_syms, used_before_assignment + + def replace_dict(self, + repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, + replace_in_graph: bool = True, + replace_keys: bool = True): + if replace_keys: + from dace.sdfg.replace import replace_properties_dict + replace_properties_dict(self, repl, symrepl) + + super().replace_dict(repl, symrepl, replace_in_graph) + for _, cfg in self.branches: + cfg.replace_dict(repl, symrepl, replace_in_graph) + + def to_json(self, parent=None): + json = super().to_json(parent) + json["branches"] = [(condition.to_json(), cfg.to_json()) for condition, cfg in self.branches] + return json + + @classmethod + def from_json(cls, json_obj, context=None): + cond_region = ConditionalRegion(json_obj["label"]) + cond_region.is_collapsed = json_obj["collapsed"] + cond_region.branches = [(CodeBlock.from_json(condition), ControlFlowRegion.from_json(cfg, context_info=context)) + for condition, cfg in json_obj["branches"]] + return cond_region diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py new file mode 100644 index 0000000000..b9aaab8ecb --- /dev/null +++ b/tests/sdfg/conditional_region_test.py @@ -0,0 +1,32 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +import dace +from dace.properties import CodeBlock +from dace.sdfg.sdfg import SDFG, InterstateEdge +from dace.sdfg.state import ConditionalRegion, ControlFlowRegion +import dace.serialize + + +def test_cond_region_if(): + sdfg = dace.SDFG('regular_if') + sdfg.add_symbol("i", dace.int32) + state0 = sdfg.add_state('state0', is_start_block=True) + + if1 = ConditionalRegion("if1") + if_body = ControlFlowRegion("if_body") + state1 = if_body.add_state("state1", is_start_block=True) + state2 = if_body.add_state("state2") + if_body.add_edge(state1, state2, InterstateEdge(assignments={"i": "100"})) + if1.branches.append((CodeBlock("i == 1"), if_body)) + + sdfg.add_edge(state0, if1, InterstateEdge()) + + assert sdfg.is_valid() + + json = sdfg.to_json() + new_sdfg = SDFG.from_json(json) + + assert new_sdfg.is_valid() + +if __name__ == '__main__': + test_cond_region_if() From 4430c10c4994af4d1a63a4be1c5c8294cebe6d76 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Mon, 1 Jul 2024 18:52:22 +0200 Subject: [PATCH 02/38] Inline conditional regions and add test for serialization --- dace/sdfg/sdfg.py | 1 + dace/sdfg/state.py | 47 ++++++++++++++++++++++----- dace/sdfg/utils.py | 14 +++++++- dace/sdfg/validation.py | 5 ++- tests/sdfg/conditional_region_test.py | 47 +++++++++++++++++++++++---- 5 files changed, 98 insertions(+), 16 deletions(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 1091ad3e94..3adff66834 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -2226,6 +2226,7 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG': # TODO (later): Adapt codegen to deal with hierarchical CFGs instead. sdutils.inline_loop_blocks(sdfg) sdutils.inline_control_flow_regions(sdfg) + sdutils.inline_conditional_regions(sdfg) # Rename SDFG to avoid runtime issues with clashing names index = 0 diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index d1b91a0416..b677d6a658 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3126,16 +3126,10 @@ def __init__(self, label: str): self.branches: List[Tuple[CodeBlock, ControlFlowRegion]] = [] def nodes(self) -> List['ControlFlowBlock']: - return [node for branch in self.branches for node in branch[1].nodes()] + return [node for _, node in self.branches] def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: - return [edge for branch in self.branches for edge in branch[1].edges()] - - def node_id(self, node: ControlFlowBlock) -> int: - try: - return self.nodes().index(node) - except ValueError: - raise NodeNotFoundError(node) + return [] def _used_symbols_internal(self, all_symbols: bool, @@ -3191,3 +3185,40 @@ def from_json(cls, json_obj, context=None): cond_region.branches = [(CodeBlock.from_json(condition), ControlFlowRegion.from_json(cfg, context_info=context)) for condition, cfg in json_obj["branches"]] return cond_region + + def inline(self) -> Tuple[bool, Any]: + """ + Inlines the conditional region into its parent control flow region. + + :return: True if the inlining succeeded, false otherwise. + """ + parent = self.parent_graph + if not parent: + raise RuntimeError('No top-level SDFG present to inline into') + + # Add all boilerplate states necessary for the structure. + guard_state = parent.add_state(self.label + '_guard') + end_state = parent.add_state(self.label + '_end') + + # Redirect all edges to the region to the init state. + for b_edge in parent.in_edges(self): + parent.add_edge(b_edge.src, guard_state, b_edge.data) + parent.remove_edge(b_edge) + # Redirect all edges exiting the region to instead exit the end state. + for a_edge in parent.out_edges(self): + parent.add_edge(end_state, a_edge.dst, a_edge.data) + parent.remove_edge(a_edge) + + from dace.sdfg.sdfg import InterstateEdge + for condition, cfg in self.branches: + parent.add_node(cfg) + parent.add_edge(guard_state, cfg, InterstateEdge(condition=condition)) + parent.add_edge(cfg, end_state, InterstateEdge()) + + parent.remove_node(self) + + sdfg = parent if isinstance(parent, dace.SDFG) else parent.sdfg + sdfg.reset_cfg_list() + + return True, (guard_state, end_state) + diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 12f66db85f..5567081b0f 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -13,7 +13,7 @@ from dace.sdfg.graph import MultiConnectorEdge from dace.sdfg.sdfg import SDFG from dace.sdfg.nodes import Node, NestedSDFG -from dace.sdfg.state import SDFGState, StateSubgraphView, LoopRegion, ControlFlowRegion +from dace.sdfg.state import ConditionalRegion, SDFGState, StateSubgraphView, LoopRegion, ControlFlowRegion from dace.sdfg.scope import ScopeSubgraphView from dace.sdfg import nodes as nd, graph as gr, propagation from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs @@ -1274,6 +1274,18 @@ def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress: return count +def inline_conditional_regions(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: + blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ConditionalRegion)] + count = 0 + + for _block in optional_progressbar(reversed(blocks), title='Inlining conditional regions', + n=len(blocks), progress=progress): + block: ControlFlowRegion = _block + if block.inline()[0]: + count += 1 + + return count + def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, multistate: bool = True) -> int: """ diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 480fb9c262..ffb0483da8 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -35,7 +35,7 @@ def validate_control_flow_region(sdfg: 'SDFG', symbols: dict, references: Set[int] = None, **context: bool): - from dace.sdfg.state import SDFGState, ControlFlowRegion + from dace.sdfg.state import SDFGState, ControlFlowRegion, ConditionalRegion from dace.sdfg.scope import is_in_scope if len(region.source_nodes()) > 1 and region.start_block is None: @@ -119,6 +119,9 @@ def validate_control_flow_region(sdfg: 'SDFG', if isinstance(edge.dst, SDFGState): validate_state(edge.dst, region.node_id(edge.dst), sdfg, symbols, initialized_transients, references, **context) + elif isinstance(edge.dst, ConditionalRegion): + for _, cfg in edge.dst.branches: + validate_control_flow_region(sdfg, cfg, initialized_transients, symbols, references, **context) elif isinstance(edge.dst, ControlFlowRegion): validate_control_flow_region(sdfg, edge.dst, initialized_transients, symbols, references, **context) # End of block DFS diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py index b9aaab8ecb..b22b378da3 100644 --- a/tests/sdfg/conditional_region_test.py +++ b/tests/sdfg/conditional_region_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import numpy as np import dace from dace.properties import CodeBlock from dace.sdfg.sdfg import SDFG, InterstateEdge @@ -9,24 +10,58 @@ def test_cond_region_if(): sdfg = dace.SDFG('regular_if') + sdfg.add_array("A", (1,), dace.float32) sdfg.add_symbol("i", dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) if1 = ConditionalRegion("if1") - if_body = ControlFlowRegion("if_body") - state1 = if_body.add_state("state1", is_start_block=True) - state2 = if_body.add_state("state2") - if_body.add_edge(state1, state2, InterstateEdge(assignments={"i": "100"})) + sdfg.add_node(if1) + sdfg.add_edge(state0, if1, InterstateEdge()) + + if_body = ControlFlowRegion("if_body", sdfg=sdfg) if1.branches.append((CodeBlock("i == 1"), if_body)) - sdfg.add_edge(state0, if1, InterstateEdge()) + state1 = if_body.add_state("state1", is_start_block=True) + acc_a = state1.add_access('A') + t1 = state1.add_tasklet("t1", None, {"a"}, "a = 100") + state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[0]')) - assert sdfg.is_valid() + def assertions(sdfg): + assert sdfg.is_valid() + A = np.ones((1,), dtype=np.float32) + sdfg(i=1, A=A) + assert A[0] == 100 + + A = np.ones((1,), dtype=np.float32) + sdfg(i=0, A=A) + assert A[0] == 1 + + assertions(sdfg) json = sdfg.to_json() new_sdfg = SDFG.from_json(json) + assertions(new_sdfg) + +def test_serialization(): + sdfg = SDFG("test_serialization") + cond_region = ConditionalRegion("cond_region") + sdfg.add_node(cond_region, is_start_block=True) + sdfg.add_symbol("i", dace.int32) + + for j in range(10): + cfg = ControlFlowRegion(f"cfg_{j}", sdfg) + cond_region.branches.append((CodeBlock(f"i == {j}"), cfg)) + + assert sdfg.is_valid() + new_sdfg = SDFG.from_json(sdfg.to_json()) assert new_sdfg.is_valid() + new_cond_region: ConditionalRegion = new_sdfg.nodes()[0] + for j in range(10): + condition, cfg = new_cond_region.branches[j] + assert condition == CodeBlock(f"i == {j}") + assert cfg.label == f"cfg_{j}" if __name__ == '__main__': test_cond_region_if() + test_serialization() From c6f8cc8eed5f5be89d320ca5a77b966440bbd2b9 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Mon, 1 Jul 2024 18:53:10 +0200 Subject: [PATCH 03/38] Clean up test --- tests/sdfg/conditional_region_test.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py index b22b378da3..de4b3f2514 100644 --- a/tests/sdfg/conditional_region_test.py +++ b/tests/sdfg/conditional_region_test.py @@ -26,21 +26,14 @@ def test_cond_region_if(): t1 = state1.add_tasklet("t1", None, {"a"}, "a = 100") state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[0]')) - def assertions(sdfg): - assert sdfg.is_valid() - A = np.ones((1,), dtype=np.float32) - sdfg(i=1, A=A) - assert A[0] == 100 - - A = np.ones((1,), dtype=np.float32) - sdfg(i=0, A=A) - assert A[0] == 1 - - assertions(sdfg) + assert sdfg.is_valid() + A = np.ones((1,), dtype=np.float32) + sdfg(i=1, A=A) + assert A[0] == 100 - json = sdfg.to_json() - new_sdfg = SDFG.from_json(json) - assertions(new_sdfg) + A = np.ones((1,), dtype=np.float32) + sdfg(i=0, A=A) + assert A[0] == 1 def test_serialization(): sdfg = SDFG("test_serialization") From e49beac5de2c2cc55a2a87bc2c04486f0f6ac7d2 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Mon, 1 Jul 2024 18:59:09 +0200 Subject: [PATCH 04/38] Test if else --- tests/sdfg/conditional_region_test.py | 34 +++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py index de4b3f2514..08e14cbb72 100644 --- a/tests/sdfg/conditional_region_test.py +++ b/tests/sdfg/conditional_region_test.py @@ -55,6 +55,40 @@ def test_serialization(): assert condition == CodeBlock(f"i == {j}") assert cfg.label == f"cfg_{j}" +def test_if_else(): + sdfg = dace.SDFG('regular_if_else') + sdfg.add_array("A", (1,), dace.float32) + sdfg.add_symbol("i", dace.int32) + state0 = sdfg.add_state('state0', is_start_block=True) + + if1 = ConditionalRegion("if1") + sdfg.add_node(if1) + sdfg.add_edge(state0, if1, InterstateEdge()) + + if_body = ControlFlowRegion("if_body", sdfg=sdfg) + state1 = if_body.add_state("state1", is_start_block=True) + acc_a = state1.add_access('A') + t1 = state1.add_tasklet("t1", None, {"a"}, "a = 100") + state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[0]')) + if1.branches.append((CodeBlock("i == 1"), if_body)) + + else_body = ControlFlowRegion("else_body", sdfg=sdfg) + state2 = else_body.add_state("state1", is_start_block=True) + acc_a2 = state2.add_access('A') + t2 = state2.add_tasklet("t1", None, {"a"}, "a = 200") + state2.add_edge(t2, 'a', acc_a2, None, dace.Memlet('A[0]')) + if1.branches.append((CodeBlock("i == 0"), else_body)) + + assert sdfg.is_valid() + A = np.ones((1,), dtype=np.float32) + sdfg(i=1, A=A) + assert A[0] == 100 + + A = np.ones((1,), dtype=np.float32) + sdfg(i=0, A=A) + assert A[0] == 200 + if __name__ == '__main__': test_cond_region_if() test_serialization() + test_if_else() From e258d64a8f10e7061e6c9d7fbfe1e06aa78b029a Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Mon, 1 Jul 2024 18:59:55 +0200 Subject: [PATCH 05/38] Fix typo --- tests/sdfg/conditional_region_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py index 08e14cbb72..767f81983b 100644 --- a/tests/sdfg/conditional_region_test.py +++ b/tests/sdfg/conditional_region_test.py @@ -75,7 +75,7 @@ def test_if_else(): else_body = ControlFlowRegion("else_body", sdfg=sdfg) state2 = else_body.add_state("state1", is_start_block=True) acc_a2 = state2.add_access('A') - t2 = state2.add_tasklet("t1", None, {"a"}, "a = 200") + t2 = state2.add_tasklet("t2", None, {"a"}, "a = 200") state2.add_edge(t2, 'a', acc_a2, None, dace.Memlet('A[0]')) if1.branches.append((CodeBlock("i == 0"), else_body)) From 281a6203e607b0d81bded25e305b8e9e5f12372c Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Tue, 2 Jul 2024 14:17:23 +0200 Subject: [PATCH 06/38] Inline conditional regions --- dace/codegen/codegen.py | 1 + dace/frontend/python/parser.py | 1 + dace/sdfg/region_inline.py | 208 +++++++++++++++++++++++++++++++++ dace/sdfg/sdfg.py | 2 +- 4 files changed, 211 insertions(+), 1 deletion(-) create mode 100644 dace/sdfg/region_inline.py diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py index f73e3f8d11..5809bf5f4d 100644 --- a/dace/codegen/codegen.py +++ b/dace/codegen/codegen.py @@ -188,6 +188,7 @@ def generate_code(sdfg, validate=True) -> List[CodeObject]: # Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops. # TODO (later): Adapt codegen to deal with hierarchical CFGs instead. + sdutils.inline_conditional_regions(sdfg) sdutils.inline_loop_blocks(sdfg) sdutils.inline_control_flow_regions(sdfg) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index e55829933c..72a41e0475 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -494,6 +494,7 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify) if not self.use_experimental_cfg_blocks: + sdutils.inline_conditional_regions(sdfg) sdutils.inline_loop_blocks(sdfg) sdutils.inline_control_flow_regions(sdfg) sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks diff --git a/dace/sdfg/region_inline.py b/dace/sdfg/region_inline.py new file mode 100644 index 0000000000..724bca006e --- /dev/null +++ b/dace/sdfg/region_inline.py @@ -0,0 +1,208 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from typing import Tuple, Set +from dace.frontend.python import astutils +from dace.properties import CodeBlock +from dace.sdfg.state import ConditionalRegion, ControlFlowBlock, ControlFlowRegion, LoopRegion, ReturnState, SDFGState + + +def inline(block: ControlFlowBlock) \ + -> Tuple[Set[LoopRegion.BreakState], Set[LoopRegion.ContinueState], Set[ReturnState]]: + """ + Inline all ControlFlowRegions inside this region recursively. + Returns three sets containing the Break, Continue and Return states which have to be handled by + the caller. + """ + + break_states: set[LoopRegion.BreakState] = set() + continue_states: set[LoopRegion.ContinueState] = set() + return_states: set[ReturnState] = set() + + for node in block.nodes(): + bs, cs, rs = set(), set(), set() + if isinstance(node, ConditionalRegion): + bs, cs, rs = inline_conditional_region(node, block) + elif isinstance(node, LoopRegion): + bs, cs, rs = inline_loop_region(node, block) + elif isinstance(node, LoopRegion.BreakState): + break_states.add(node) + elif isinstance(node, LoopRegion.ContinueState): + continue_states.add(node) + elif isinstance(node, ReturnState): + return_states.add(node) + elif isinstance(node, ControlFlowRegion): + bs, cs, rs = inline_control_flow_region(node, block) + break_states.update(bs) + continue_states.update(cs) + return_states.update(rs) + + if isinstance(block, ControlFlowRegion): + block.reset_cfg_list() + + return break_states, continue_states, return_states + +def inline_control_flow_region(region: ControlFlowRegion, parent: ControlFlowRegion): + from dace.sdfg.sdfg import InterstateEdge + + break_states, continue_states, return_states = inline(region) + + # Add all region states and make sure to keep track of all the ones that need to be connected in the end. + to_connect: Set[ControlFlowBlock] = set() + for node in region.nodes(): + parent.add_node(node, ensure_unique_name=True) + if region.out_degree(node) == 0 and not isinstance(node, (LoopRegion.BreakState, LoopRegion.ContinueState, ReturnState)): + to_connect.add(node) + + end_state = parent.add_state(region.label + '_end') + if len(region.nodes()) > 0: + internal_start = region.start_block + else: + internal_start = end_state + + # Add all region edges. + for edge in region.edges(): + parent.add_edge(edge.src, edge.dst, edge.data) + + # Redirect all edges to the region to the internal start state. + for b_edge in parent.in_edges(region): + parent.add_edge(b_edge.src, internal_start, b_edge.data) + parent.remove_edge(b_edge) + # Redirect all edges exiting the region to instead exit the end state. + for a_edge in parent.out_edges(region): + parent.add_edge(end_state, a_edge.dst, a_edge.data) + parent.remove_edge(a_edge) + + for node in to_connect: + parent.add_edge(node, end_state, InterstateEdge()) + + # Remove the original loop. + parent.remove_node(region) + + if parent.in_degree(end_state) == 0: + parent.remove_node(end_state) + return break_states, continue_states, return_states + + +def inline_loop_region(loop: LoopRegion, parent: ControlFlowRegion): + from dace.sdfg.sdfg import InterstateEdge + + break_states, continue_states, return_states = inline(loop) + + internal_start = loop.start_block + + # Add all boilerplate loop states necessary for the structure. + init_state = parent.add_state(loop.label + '_init') + guard_state = parent.add_state(loop.label + '_guard') + end_state = parent.add_state(loop.label + '_end') + loop_tail_state = parent.add_state(loop.label + '_tail') + + # Add all loop states and make sure to keep track of all the ones that need to be connected in the end. + connect_to_tail: Set[SDFGState] = set() + for node in loop.nodes(): + node.label = loop.label + '_' + node.label + parent.add_node(node, ensure_unique_name=True) + if loop.out_degree(node) == 0 and not isinstance(node, (LoopRegion.BreakState, LoopRegion.ContinueState, ReturnState)): + connect_to_tail.add(node) + + # Add all internal loop edges. + for edge in loop.edges(): + parent.add_edge(edge.src, edge.dst, edge.data) + + # Redirect all edges to the loop to the init state. + for b_edge in parent.in_edges(loop): + parent.add_edge(b_edge.src, init_state, b_edge.data) + parent.remove_edge(b_edge) + # Redirect all edges exiting the loop to instead exit the end state. + for a_edge in parent.out_edges(loop): + parent.add_edge(end_state, a_edge.dst, a_edge.data) + parent.remove_edge(a_edge) + + # Add an initialization edge that initializes the loop variable if applicable. + init_edge = InterstateEdge() + if loop.init_statement is not None: + init_edge.assignments = {} + for stmt in loop.init_statement.code: + assign: astutils.ast.Assign = stmt + init_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value) + if loop.inverted: + parent.add_edge(init_state, internal_start, init_edge) + else: + parent.add_edge(init_state, guard_state, init_edge) + + # Connect the loop tail. + update_edge = InterstateEdge() + if loop.update_statement is not None: + update_edge.assignments = {} + for stmt in loop.update_statement.code: + assign: astutils.ast.Assign = stmt + update_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value) + parent.add_edge(loop_tail_state, guard_state, update_edge) + + # Add condition checking edges and connect the guard state. + cond_expr = loop.loop_condition.code + parent.add_edge(guard_state, end_state, + InterstateEdge(CodeBlock(astutils.negate_expr(cond_expr)).code)) + parent.add_edge(guard_state, internal_start, InterstateEdge(CodeBlock(cond_expr).code)) + + # Connect any end states from the loop's internal state machine to the tail state so they end a + # loop iteration. Do the same for any continue states, and connect any break states to the end of the loop. + for node in continue_states | connect_to_tail: + parent.add_edge(node, loop_tail_state, InterstateEdge()) + for node in break_states: + parent.add_edge(node, end_state, InterstateEdge(assignments={f'did_break_{loop.label}': '1'})) + + # Remove the original loop. + parent.remove_node(loop) + if parent.in_degree(end_state) == 0: + parent.remove_node(end_state) + return set(), set(), return_states + +def inline_conditional_region(conditional: ConditionalRegion, parent: ControlFlowRegion): + from dace.sdfg.sdfg import InterstateEdge + + break_states, continue_states, return_states = inline(conditional) + + # Add all boilerplate states necessary for the structure. + guard_state = parent.add_state(conditional.label + '_guard') + endif_state = parent.add_state(conditional.label + '_endinf') + + connect_to_end : Set[ControlFlowBlock] = set() + # Add all states and make sure to keep track of all the ones that need to be connected in the end. + for node in conditional.nodes(): + node.label = conditional.label + '_' + node.label + parent.add_node(node, ensure_unique_name=True) + if conditional.out_degree(node) == 0 and not isinstance(node, (LoopRegion.BreakState, LoopRegion.ContinueState, ReturnState)): + connect_to_end.add(node) + + # Add all internal region edges. + for edge in conditional.edges(): + parent.add_edge(edge.src, edge.dst, edge.data) + + # Redirect all edges entering the region to the init state. + for b_edge in parent.in_edges(conditional): + parent.add_edge(b_edge.src, guard_state, b_edge.data) + parent.remove_edge(b_edge) + # Redirect all edges exiting the region to instead exit the end state. + for a_edge in parent.out_edges(conditional): + parent.add_edge(endif_state, a_edge.dst, a_edge.data) + parent.remove_edge(a_edge) + + # Add condition checking edges and connect the guard state. + parent.add_edge(guard_state, conditional.start_block, InterstateEdge(conditional.condition_expr)) + parent.add_edge(guard_state, conditional.else_branch, InterstateEdge(conditional.condition_else_expr)) + + for node in connect_to_end: + parent.add_edge(node, endif_state, InterstateEdge()) + for node in return_states: + parent.add_edge(node, endif_state, InterstateEdge(condition="False")) + parent.add_edge(conditional.else_branch, endif_state, InterstateEdge()) + bs, cs, rs = inline_control_flow_region(conditional.else_branch, parent) + break_states.update(bs) + continue_states.update(cs) + return_states.update(rs) + + parent.remove_node(conditional) + if parent.in_degree(endif_state) == 0: + parent.remove_node(endif_state) + return break_states, continue_states, return_states + \ No newline at end of file diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 3adff66834..09bea5aeda 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -2224,9 +2224,9 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG': # Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops. # TODO (later): Adapt codegen to deal with hierarchical CFGs instead. + sdutils.inline_conditional_regions(sdfg) sdutils.inline_loop_blocks(sdfg) sdutils.inline_control_flow_regions(sdfg) - sdutils.inline_conditional_regions(sdfg) # Rename SDFG to avoid runtime issues with clashing names index = 0 From c310a82725d008925189845a68ac764548439bca Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Tue, 2 Jul 2024 14:18:06 +0200 Subject: [PATCH 07/38] Parse complex tests region before adding the conditional region --- dace/frontend/python/newast.py | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 5269f1cf83..9f0e4857fd 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -32,7 +32,7 @@ from dace.memlet import Memlet from dace.properties import LambdaProperty, CodeBlock from dace.sdfg import SDFG, SDFGState -from dace.sdfg.state import BreakBlock, ContinueBlock, ControlFlowBlock, LoopRegion, ControlFlowRegion +from dace.sdfg.state import BreakBlock, ConditionalRegion, ContinueBlock, ControlFlowBlock, LoopRegion, ControlFlowRegion from dace.sdfg.replace import replace_datadesc_names from dace.symbolic import pystr_to_symbolic, inequal_symbols @@ -2552,34 +2552,28 @@ def visit_Continue(self, node: ast.Continue): raise DaceSyntaxError(self, node, error_msg) def visit_If(self, node: ast.If): - # Add a guard state - self._add_state('if_guard') - self.last_block.debuginfo = self.current_lineinfo - # Generate conditions cond, cond_else, _ = self._visit_test(node.test) + # Add conditional region + cond_region = ConditionalRegion(f"if_{node.lineno}") + self.cfg_target.add_node(cond_region) + self._on_block_added(cond_region) + + if_body = ControlFlowRegion(cond_region.label + "_body", sdfg=self.sdfg) + cond_region.branches.append((CodeBlock(cond), if_body)) + # Visit recursively - laststate, first_if_state, last_if_state, return_stmt = \ - self._recursive_visit(node.body, 'if', node.lineno, self.cfg_target, True) - end_if_state = self.last_block + self._recursive_visit(node.body, 'if', node.lineno, if_body, False) - # Connect the states - self.cfg_target.add_edge(laststate, first_if_state, dace.InterstateEdge(cond)) - self.cfg_target.add_edge(last_if_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) + else_body = ControlFlowRegion("", sdfg=self.sdfg) + cond_region.branches.append((CodeBlock(cond_else), else_body)) # Process 'else'/'elif' statements if len(node.orelse) > 0: + else_body.label = f"{cond_region.label}_else_{node.orelse[0].lineno}" # Visit recursively - _, first_else_state, last_else_state, return_stmt = \ - self._recursive_visit(node.orelse, 'else', node.lineno, self.cfg_target, False) - - # Connect the states - self.cfg_target.add_edge(laststate, first_else_state, dace.InterstateEdge(cond_else)) - self.cfg_target.add_edge(last_else_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) - else: - self.cfg_target.add_edge(laststate, end_if_state, dace.InterstateEdge(cond_else)) - self.last_block = end_if_state + self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False) def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): From 58451fa3ba798855015cd75b5db522e989bdad17 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Wed, 3 Jul 2024 11:17:50 +0200 Subject: [PATCH 08/38] Allow break and continue states inside cfgs nested in loops and remove end state if not used in inline --- dace/frontend/python/newast.py | 22 ++++++---------------- dace/sdfg/state.py | 8 +++++--- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 9f0e4857fd..f791bb1260 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2532,24 +2532,14 @@ def _generate_orelse(self, loop_region: LoopRegion, postloop_block: ControlFlowB self.cfg_target.add_edge(loop_region, postloop_block, dace.InterstateEdge(condition=f"{did_break_symbol} == 1")) def visit_Break(self, node: ast.Break): - if isinstance(self.cfg_target, LoopRegion): - self._on_block_added(self.cfg_target.add_break(f'break_{self.cfg_target.label}_{node.lineno}')) - else: - error_msg = "'break' is only supported inside loops " - if self.nested: - error_msg += ("('break' is not supported in Maps and cannot be used in nested DaCe program calls to " - " break out of loops of outer scopes)") - raise DaceSyntaxError(self, node, error_msg) + break_block = BreakBlock(f'break_{node.lineno}') + self.cfg_target.add_node(break_block, ensure_unique_name=True) + self._on_block_added(break_block) def visit_Continue(self, node: ast.Continue): - if isinstance(self.cfg_target, LoopRegion): - self._on_block_added(self.cfg_target.add_continue(f'continue_{self.cfg_target.label}_{node.lineno}')) - else: - error_msg = ("'continue' is only supported inside loops ") - if self.nested: - error_msg += ("('continue' is not supported in Maps and cannot be used in nested DaCe program calls to " - " continue loops of outer scopes)") - raise DaceSyntaxError(self, node, error_msg) + continue_block = BreakBlock(f'continue_{node.lineno}') + self.cfg_target.add_node(continue_block, ensure_unique_name=True) + self._on_block_added(continue_block) def visit_If(self, node: ast.If): # Generate conditions diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index b677d6a658..038e7d8acf 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2513,7 +2513,7 @@ def inline(self) -> Tuple[bool, Any]: # remains as-is. newnode = parent.add_state(node.label) block_to_state_map[node] = newnode - elif self.out_degree(node) == 0: + elif self.out_degree(node) == 0 and not isinstance(node, (BreakBlock, ContinueBlock)): to_connect.add(node) # Add all region edges. @@ -2533,6 +2533,9 @@ def inline(self) -> Tuple[bool, Any]: for node in to_connect: parent.add_edge(node, end_state, dace.InterstateEdge()) + + if parent.in_degree(end_state) == 0: + parent.remove_node(end_state) # Remove the original control flow region (self) from the parent graph. parent.remove_node(self) @@ -2948,7 +2951,7 @@ def inline(self) -> Tuple[bool, Any]: # and return are inlined correctly. def recursive_inline_cf_regions(region: ControlFlowRegion) -> None: for block in region.nodes(): - if isinstance(block, ControlFlowRegion) and not isinstance(block, LoopRegion): + if (isinstance(block, ControlFlowRegion) or isinstance(block, ConditionalRegion)) and not isinstance(block, LoopRegion): recursive_inline_cf_regions(block) block.inline() recursive_inline_cf_regions(self) @@ -3169,7 +3172,6 @@ def replace_dict(self, from dace.sdfg.replace import replace_properties_dict replace_properties_dict(self, repl, symrepl) - super().replace_dict(repl, symrepl, replace_in_graph) for _, cfg in self.branches: cfg.replace_dict(repl, symrepl, replace_in_graph) From ebb69becc2bbe566a3409bce3ed551a5ca0f2ba5 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Wed, 3 Jul 2024 15:02:08 +0200 Subject: [PATCH 09/38] Call start_block before removing the current start block --- dace/transformation/interstate/state_fusion.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index 3abbe085f5..b99eb276b1 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -471,18 +471,24 @@ def apply(self, _, sdfg): # Special case 1: first state is empty if first_state.is_empty(): + new_start_block = False + if graph.start_block == first_state: + new_start_block = True sdutil.change_edge_dest(graph, first_state, second_state) graph.remove_node(first_state) - if graph.start_block == first_state: + if new_start_block: graph.start_block = graph.node_id(second_state) return # Special case 2: second state is empty if second_state.is_empty(): + new_start_block = False + if graph.start_block == second_state: + new_start_block = True sdutil.change_edge_src(graph, second_state, first_state) sdutil.change_edge_dest(graph, second_state, first_state) graph.remove_node(second_state) - if graph.start_block == second_state: + if new_start_block: graph.start_block = graph.node_id(first_state) return From 7f429ea63ee2a1f2afd245f0cd24989d5c8249b9 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Wed, 3 Jul 2024 15:11:37 +0200 Subject: [PATCH 10/38] Avoid double additions of return blocks --- dace/sdfg/state.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 038e7d8acf..aa39629e46 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2507,14 +2507,15 @@ def inline(self) -> Tuple[bool, Any]: block_to_state_map: Dict[ControlFlowBlock, SDFGState] = dict() for node in self.nodes(): node.label = self.label + '_' + node.label - parent.add_node(node, ensure_unique_name=True) if isinstance(node, ReturnBlock) and isinstance(parent, dace.SDFG): # If a return block is being inlined into an SDFG, convert it into a regular state. Otherwise it # remains as-is. newnode = parent.add_state(node.label) block_to_state_map[node] = newnode - elif self.out_degree(node) == 0 and not isinstance(node, (BreakBlock, ContinueBlock)): - to_connect.add(node) + else: + parent.add_node(node, ensure_unique_name=True) + if self.out_degree(node) == 0 and not isinstance(node, (BreakBlock, ContinueBlock, ReturnBlock)): + to_connect.add(node) # Add all region edges. for edge in self.edges(): @@ -2534,6 +2535,7 @@ def inline(self) -> Tuple[bool, Any]: for node in to_connect: parent.add_edge(node, end_state, dace.InterstateEdge()) + # NOTE: this should be unnecessesary if parent.in_degree(end_state) == 0: parent.remove_node(end_state) @@ -3213,9 +3215,12 @@ def inline(self) -> Tuple[bool, Any]: from dace.sdfg.sdfg import InterstateEdge for condition, cfg in self.branches: - parent.add_node(cfg) - parent.add_edge(guard_state, cfg, InterstateEdge(condition=condition)) - parent.add_edge(cfg, end_state, InterstateEdge()) + if cfg.number_of_nodes() > 0: + parent.add_node(cfg) + parent.add_edge(guard_state, cfg, InterstateEdge(condition=condition)) + parent.add_edge(cfg, end_state, InterstateEdge()) + else: + parent.add_edge(guard_state, end_state, InterstateEdge(condition=condition)) parent.remove_node(self) From 9fe25326641d148b8f2968b61a599278c8155151 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Thu, 4 Jul 2024 11:34:34 +0200 Subject: [PATCH 11/38] Do not invalidate start_block cache everytime a new node is added --- dace/sdfg/state.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index aa39629e46..dc453f48b7 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2591,7 +2591,6 @@ def add_node(self, node.label = self._ensure_unique_block_name(node.label) super().add_node(node) - self._cached_start_block = None node.parent_graph = self if isinstance(self, dace.SDFG): node.sdfg = self From d91ef62fe39c262f575d0bd0701d0f5a33dea884 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Thu, 4 Jul 2024 11:59:03 +0200 Subject: [PATCH 12/38] Fix typo --- dace/frontend/python/newast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index f791bb1260..09380ddecc 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2537,7 +2537,7 @@ def visit_Break(self, node: ast.Break): self._on_block_added(break_block) def visit_Continue(self, node: ast.Continue): - continue_block = BreakBlock(f'continue_{node.lineno}') + continue_block = ContinueBlock(f'continue_{node.lineno}') self.cfg_target.add_node(continue_block, ensure_unique_name=True) self._on_block_added(continue_block) From c8b1beea8f5e3c2f7687fcf8fcf283dc7fdd9c1b Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Thu, 4 Jul 2024 16:34:24 +0200 Subject: [PATCH 13/38] Inline all cfg for each nested sdfg and remove dead states after all the inlining --- dace/frontend/python/parser.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 72a41e0475..8455843278 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -15,6 +15,7 @@ from dace.frontend.python import (newast, common as pycommon, cached_program, preprocessing) from dace.sdfg import SDFG, utils as sdutils from dace.data import create_datadescriptor, Data +from dace.sdfg.state import BreakBlock, ContinueBlock try: from typing import get_origin, get_args @@ -494,9 +495,15 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify) if not self.use_experimental_cfg_blocks: - sdutils.inline_conditional_regions(sdfg) - sdutils.inline_loop_blocks(sdfg) - sdutils.inline_control_flow_regions(sdfg) + for nsdfg in sdfg.all_sdfgs_recursive(): + sdutils.inline_conditional_regions(nsdfg) + sdutils.inline_loop_blocks(nsdfg) + sdutils.inline_control_flow_regions(nsdfg) + for node in nsdfg.nodes(): + if isinstance(node, (BreakBlock, ContinueBlock)): + raise pycommon.DaceSyntaxError(None, None, "Break or continue blocks were not handled") + from dace.transformation.passes.dead_state_elimination import DeadStateElimination + DeadStateElimination().apply_pass(nsdfg, {}) sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks # Apply simplification pass automatically From ad5e57469571a50280e04b1669bf32714255d7bf Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Fri, 5 Jul 2024 15:25:17 +0200 Subject: [PATCH 14/38] Set did_break_symbol to 1 in loop region inline method --- dace/frontend/python/newast.py | 4 ---- dace/sdfg/state.py | 2 ++ 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 09380ddecc..2a5a6eeee7 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2513,10 +2513,6 @@ def visit_While(self, node: ast.While): def _generate_orelse(self, loop_region: LoopRegion, postloop_block: ControlFlowBlock): did_break_symbol = 'did_break_' + loop_region.label self.sdfg.add_symbol(did_break_symbol, dace.int32) - for n in loop_region.nodes(): - if isinstance(n, BreakBlock): - for iedge in loop_region.in_edges(n): - iedge.data.assignments[did_break_symbol] = '1' for iedge in self.cfg_target.in_edges(loop_region): iedge.data.assignments[did_break_symbol] = '0' oedges = self.cfg_target.out_edges(loop_region) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index dc453f48b7..d8c764cbdc 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3035,6 +3035,8 @@ def recursive_inline_cf_regions(region: ControlFlowRegion) -> None: parent.add_edge(node, loop_latch_state, dace.InterstateEdge()) for node in connect_to_end: parent.add_edge(node, end_state, dace.InterstateEdge()) + for iedge in parent.in_edges(node): + iedge.data.assignments['did_break_' + self.label] = '1' parent.remove_node(self) From 25e52bb429d90d64b5e43cceb19f73213ba94bf8 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Tue, 9 Jul 2024 17:49:45 +0200 Subject: [PATCH 15/38] Use codegen generation for loop regions --- dace/codegen/codegen.py | 4 +++- dace/frontend/python/interface.py | 4 +++- dace/frontend/python/newast.py | 9 ++++++++- dace/sdfg/state.py | 4 +--- tests/python_frontend/loops_test.py | 2 +- 5 files changed, 16 insertions(+), 7 deletions(-) diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py index d1427bf037..864362fe88 100644 --- a/dace/codegen/codegen.py +++ b/dace/codegen/codegen.py @@ -9,7 +9,7 @@ from dace.codegen.targets import framecode from dace.codegen.codeobject import CodeObject from dace.config import Config -from dace.sdfg import infer_types +from dace.sdfg import infer_types, utils as sdutils # Import CPU code generator. TODO: Remove when refactored from dace.codegen.targets import cpp, cpu @@ -158,6 +158,8 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]: """ from dace.codegen.targets.target import TargetCodeGenerator # Avoid import loop + sdutils.inline_conditional_regions(sdfg) + sdutils.inline_control_flow_regions(sdfg) # Before compiling, validate SDFG correctness if validate: sdfg.validate() diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index ecd0b164d6..035e9472c3 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -44,6 +44,7 @@ def program(f: F, recompile: bool = True, distributed_compilation: bool = False, constant_functions=False, + use_experimental_cfg_blocks=False, **kwargs) -> Callable[..., parser.DaceProgram]: """ Entry point to a data-centric program. For methods and ``classmethod``s, use @@ -83,7 +84,8 @@ def program(f: F, recreate_sdfg=recreate_sdfg, regenerate_code=regenerate_code, recompile=recompile, - distributed_compilation=distributed_compilation) + distributed_compilation=distributed_compilation, + use_experimental_cfg_blocks=use_experimental_cfg_blocks) function = program diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 2a5a6eeee7..43b4b095e7 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2371,7 +2371,7 @@ def visit_For(self, node: ast.For): extra_symbols=extra_syms, parent=loop_region, unconnected_last_block=False) loop_region.start_block = loop_region.node_id(first_subblock) - + self._connect_break_blocks(loop_region) # Handle else clause if node.orelse: # Continue visiting body @@ -2509,6 +2509,13 @@ def visit_While(self, node: ast.While): self._generate_orelse(loop_region, postloop_block) self.last_block = loop_region + self._connect_break_blocks(loop_region) + + def _connect_break_blocks(self, loop_region: LoopRegion): + for node, parent in loop_region.all_nodes_recursive(lambda n, _: not isinstance(n, (LoopRegion, SDFGState))): + if isinstance(node, BreakBlock): + for in_edge in parent.in_edges(node): + in_edge.data.assignments["did_break_" + loop_region.label] = "1" def _generate_orelse(self, loop_region: LoopRegion, postloop_block: ControlFlowBlock): did_break_symbol = 'did_break_' + loop_region.label diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 95d9d837e9..762abcde06 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3093,8 +3093,6 @@ def recursive_inline_cf_regions(region: ControlFlowRegion) -> None: parent.add_edge(node, loop_latch_state, dace.InterstateEdge()) for node in connect_to_end: parent.add_edge(node, end_state, dace.InterstateEdge()) - for iedge in parent.in_edges(node): - iedge.data.assignments['did_break_' + self.label] = '1' parent.remove_node(self) @@ -3182,7 +3180,7 @@ def has_return(self) -> bool: return False @dace.serialize.serializable -class ConditionalRegion(ControlFlowBlock): +class ConditionalRegion(ControlFlowBlock, ControlGraphView): def __init__(self, label: str): super().__init__(label) self.branches: List[Tuple[CodeBlock, ControlFlowRegion]] = [] diff --git a/tests/python_frontend/loops_test.py b/tests/python_frontend/loops_test.py index e0c869f20c..019e3addae 100644 --- a/tests/python_frontend/loops_test.py +++ b/tests/python_frontend/loops_test.py @@ -416,7 +416,7 @@ def test_nested_map_with_symbol(): reason='Control flow detection issues through extraneous states, needs control flow detection fix') def test_for_else(): - @dace.program + @dace.program(use_experimental_cfg_blocks=True) def for_else(A: dace.float64[20]): for i in range(1, 20): if A[i] >= 10: From e5aa40a7ef4301d82e413f5c710fc2092243f1bb Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Tue, 9 Jul 2024 22:24:22 +0200 Subject: [PATCH 16/38] Raise exception when creating a break or continue block outside a loop region --- dace/frontend/python/newast.py | 14 ++++++++++++++ dace/frontend/python/parser.py | 4 ---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 43b4b095e7..c28f02825c 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2534,12 +2534,24 @@ def _generate_orelse(self, loop_region: LoopRegion, postloop_block: ControlFlowB self.cfg_target.remove_edge(oedge) self.cfg_target.add_edge(loop_region, postloop_block, dace.InterstateEdge(condition=f"{did_break_symbol} == 1")) + def _has_loop_ancestor(self, node: ControlFlowBlock) -> bool: + while node is not None and node != self.sdfg: + if isinstance(node, LoopRegion): + return True + node = node.parent_graph + return False + + def visit_Break(self, node: ast.Break): + if not self._has_loop_ancestor(self.cfg_target): + raise DaceSyntaxError(self, node, "Break block outside loop region") break_block = BreakBlock(f'break_{node.lineno}') self.cfg_target.add_node(break_block, ensure_unique_name=True) self._on_block_added(break_block) def visit_Continue(self, node: ast.Continue): + if not self._has_loop_ancestor(self.cfg_target): + raise DaceSyntaxError(self, node, "Continue block outside loop region") continue_block = ContinueBlock(f'continue_{node.lineno}') self.cfg_target.add_node(continue_block, ensure_unique_name=True) self._on_block_added(continue_block) @@ -2555,12 +2567,14 @@ def visit_If(self, node: ast.If): if_body = ControlFlowRegion(cond_region.label + "_body", sdfg=self.sdfg) cond_region.branches.append((CodeBlock(cond), if_body)) + if_body.parent_graph = cond_region # Visit recursively self._recursive_visit(node.body, 'if', node.lineno, if_body, False) else_body = ControlFlowRegion("", sdfg=self.sdfg) cond_region.branches.append((CodeBlock(cond_else), else_body)) + else_body.parent_graph = cond_region # Process 'else'/'elif' statements if len(node.orelse) > 0: diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 8455843278..9a6f1416e9 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -15,7 +15,6 @@ from dace.frontend.python import (newast, common as pycommon, cached_program, preprocessing) from dace.sdfg import SDFG, utils as sdutils from dace.data import create_datadescriptor, Data -from dace.sdfg.state import BreakBlock, ContinueBlock try: from typing import get_origin, get_args @@ -499,9 +498,6 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF sdutils.inline_conditional_regions(nsdfg) sdutils.inline_loop_blocks(nsdfg) sdutils.inline_control_flow_regions(nsdfg) - for node in nsdfg.nodes(): - if isinstance(node, (BreakBlock, ContinueBlock)): - raise pycommon.DaceSyntaxError(None, None, "Break or continue blocks were not handled") from dace.transformation.passes.dead_state_elimination import DeadStateElimination DeadStateElimination().apply_pass(nsdfg, {}) sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks From e4d9a85870bf0e8e2a3c1ce4e1d19a30eebc84cf Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Wed, 10 Jul 2024 11:06:02 +0200 Subject: [PATCH 17/38] Fix _used_symbols_internal in conditional region and remove dead blocks when encountering breaks and contiunes --- dace/frontend/python/parser.py | 4 +--- dace/sdfg/state.py | 36 ++++++++++++++--------------- tests/python_frontend/loops_test.py | 3 ++- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 9a6f1416e9..1c3510c51f 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -496,10 +496,8 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF if not self.use_experimental_cfg_blocks: for nsdfg in sdfg.all_sdfgs_recursive(): sdutils.inline_conditional_regions(nsdfg) - sdutils.inline_loop_blocks(nsdfg) sdutils.inline_control_flow_regions(nsdfg) - from dace.transformation.passes.dead_state_elimination import DeadStateElimination - DeadStateElimination().apply_pass(nsdfg, {}) + sdutils.inline_loop_blocks(nsdfg) sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks # Apply simplification pass automatically diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 762abcde06..2ead72461a 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2547,7 +2547,6 @@ def inline(self) -> Tuple[bool, Any]: """ parent = self.parent_graph if parent: - end_state = parent.add_state(self.label + '_end') # Add all region states and make sure to keep track of all the ones that need to be connected in the end. to_connect: Set[SDFGState] = set() @@ -2574,18 +2573,25 @@ def inline(self) -> Tuple[bool, Any]: for b_edge in parent.in_edges(self): parent.add_edge(b_edge.src, self.start_block, b_edge.data) parent.remove_edge(b_edge) - # Redirect all edges exiting the region to instead exit the end state. - for a_edge in parent.out_edges(self): - parent.add_edge(end_state, a_edge.dst, a_edge.data) - parent.remove_edge(a_edge) - - for node in to_connect: - parent.add_edge(node, end_state, dace.InterstateEdge()) - # NOTE: this should be unnecessesary - if parent.in_degree(end_state) == 0: - parent.remove_node(end_state) - + end_state = None + if len(to_connect) > 0: + end_state = parent.add_state(self.label + '_end') + # Redirect all edges exiting the region to instead exit the end state. + for a_edge in parent.out_edges(self): + parent.add_edge(end_state, a_edge.dst, a_edge.data) + parent.remove_edge(a_edge) + + for node in to_connect: + parent.add_edge(node, end_state, dace.InterstateEdge()) + else: + dead_blocks = [succ for succ in parent.successors(self) if parent.in_degree(succ) == 1] + while dead_blocks: + layer = list(dead_blocks) + dead_blocks.clear() + for u in layer: + dead_blocks.extend([succ for succ in parent.successors(u) if parent.in_degree(succ) == 1]) + parent.remove_node(u) # Remove the original control flow region (self) from the parent graph. parent.remove_node(self) @@ -3201,12 +3207,6 @@ def _used_symbols_internal(self, free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment - b_free_symbols, b_defined_symbols, b_used_before_assignment = super()._used_symbols_internal( - all_symbols, keep_defined_in_mapping=keep_defined_in_mapping) - free_syms |= b_free_symbols - defined_syms |= b_defined_symbols - used_before_assignment |= b_used_before_assignment - for condition, cfg in self.branches: free_syms |= condition.get_free_symbols() b_free_symbols, b_defined_symbols, b_used_before_assignment = cfg._used_symbols_internal( diff --git a/tests/python_frontend/loops_test.py b/tests/python_frontend/loops_test.py index 019e3addae..2d8b2dc83c 100644 --- a/tests/python_frontend/loops_test.py +++ b/tests/python_frontend/loops_test.py @@ -496,10 +496,12 @@ def branch_in_while(cond: dace.int32): def test_branch_in_while(): sdfg = branch_in_while.to_sdfg(simplify=False) + sdfg.save("branch_in_while.sdfg") assert len(sdfg.source_nodes()) == 1 if __name__ == "__main__": + test_branch_in_while() test_for_loop() test_for_loop_with_break_continue() test_nested_for_loop() @@ -521,4 +523,3 @@ def test_branch_in_while(): test_for_else() test_while_else() test_branch_in_for() - test_branch_in_while() From e8deac149ca3795e02ca18d5ec97c6ed50dcb98f Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Wed, 10 Jul 2024 14:23:46 +0200 Subject: [PATCH 18/38] Fix use of start_block in state fusion transformation --- dace/transformation/interstate/state_fusion.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index b99eb276b1..5362695b49 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -461,6 +461,8 @@ def apply(self, _, sdfg): graph = first_state.parent_graph + start_block = graph.start_block + # Remove interstate edge(s) edges = graph.edges_between(first_state, second_state) for edge in edges: @@ -471,24 +473,18 @@ def apply(self, _, sdfg): # Special case 1: first state is empty if first_state.is_empty(): - new_start_block = False - if graph.start_block == first_state: - new_start_block = True sdutil.change_edge_dest(graph, first_state, second_state) graph.remove_node(first_state) - if new_start_block: + if start_block == first_state: graph.start_block = graph.node_id(second_state) return # Special case 2: second state is empty if second_state.is_empty(): - new_start_block = False - if graph.start_block == second_state: - new_start_block = True sdutil.change_edge_src(graph, second_state, first_state) sdutil.change_edge_dest(graph, second_state, first_state) graph.remove_node(second_state) - if new_start_block: + if start_block == second_state: graph.start_block = graph.node_id(first_state) return From 9c506bd187beac9f328fbc9e68b1c07ff7fd6fc7 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Wed, 10 Jul 2024 17:02:07 +0200 Subject: [PATCH 19/38] Fix symbols internal for conditional region --- dace/sdfg/state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 2ead72461a..a5f9b18ea9 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3208,9 +3208,9 @@ def _used_symbols_internal(self, used_before_assignment = set() if used_before_assignment is None else used_before_assignment for condition, cfg in self.branches: - free_syms |= condition.get_free_symbols() + free_syms |= condition.get_free_symbols(defined_syms) b_free_symbols, b_defined_symbols, b_used_before_assignment = cfg._used_symbols_internal( - all_symbols, keep_defined_in_mapping=keep_defined_in_mapping) + all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping) free_syms |= b_free_symbols defined_syms |= b_defined_symbols used_before_assignment |= b_used_before_assignment From c614bcbfa263f2a7bce36563588ebcb3089bbf5c Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Wed, 10 Jul 2024 17:14:27 +0200 Subject: [PATCH 20/38] Connect arrays to views in all states for each cfg --- dace/frontend/python/newast.py | 2 +- dace/sdfg/state.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index c28f02825c..411eadfbea 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1301,7 +1301,7 @@ def _views_to_data(state: SDFGState, nodes: List[dace.nodes.AccessNode]) -> List return new_nodes # Map view access nodes to their respective data - for state in self.sdfg.states(): + for state in self.sdfg.all_states(): # NOTE: We need to support views of views nodes = list(state.data_nodes()) while nodes: diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index a5f9b18ea9..bfd7c04280 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2751,6 +2751,9 @@ def all_states(self) -> Iterator[SDFGState]: yield block elif isinstance(block, ControlFlowRegion): yield from block.all_states() + elif isinstance(block, ConditionalRegion): + for _, cfr in block.branches: + yield from cfr.all_states() def all_control_flow_blocks(self, recursive=False) -> Iterator[ControlFlowBlock]: """ Iterate over all control flow blocks in this control flow graph. """ From 7a4584ecabf7fb5b1dd1c7abfda15e103a9245b5 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Wed, 10 Jul 2024 22:05:26 +0200 Subject: [PATCH 21/38] Fix from_json --- dace/sdfg/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index bfd7c04280..21419cb33c 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3244,7 +3244,7 @@ def to_json(self, parent=None): def from_json(cls, json_obj, context=None): cond_region = ConditionalRegion(json_obj["label"]) cond_region.is_collapsed = json_obj["collapsed"] - cond_region.branches = [(CodeBlock.from_json(condition), ControlFlowRegion.from_json(cfg, context_info=context)) + cond_region.branches = [(CodeBlock.from_json(condition), ControlFlowRegion.from_json(cfg, context)) for condition, cfg in json_obj["branches"]] return cond_region From feed8aa4898aef1c13805904f7ab52039fb8699b Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Thu, 11 Jul 2024 14:41:33 +0200 Subject: [PATCH 22/38] Fix _used_symbols_internal in ControlFlowRegion for handling correctly ConditionalRegions --- dace/sdfg/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 21419cb33c..1529a5a3d9 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2787,7 +2787,7 @@ def _used_symbols_internal(self, for block in ordered_blocks: state_symbols = set() - if isinstance(block, ControlFlowRegion): + if isinstance(block, (ControlFlowRegion, ConditionalRegion)): b_free_syms, b_defined_syms, b_used_before_syms = block._used_symbols_internal(all_symbols, defined_syms, free_syms, From faad41d328467186088575e064586667cd44cb35 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Thu, 11 Jul 2024 15:26:01 +0200 Subject: [PATCH 23/38] Remove unused file --- dace/sdfg/region_inline.py | 208 ------------------------------------- 1 file changed, 208 deletions(-) delete mode 100644 dace/sdfg/region_inline.py diff --git a/dace/sdfg/region_inline.py b/dace/sdfg/region_inline.py deleted file mode 100644 index 724bca006e..0000000000 --- a/dace/sdfg/region_inline.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. - -from typing import Tuple, Set -from dace.frontend.python import astutils -from dace.properties import CodeBlock -from dace.sdfg.state import ConditionalRegion, ControlFlowBlock, ControlFlowRegion, LoopRegion, ReturnState, SDFGState - - -def inline(block: ControlFlowBlock) \ - -> Tuple[Set[LoopRegion.BreakState], Set[LoopRegion.ContinueState], Set[ReturnState]]: - """ - Inline all ControlFlowRegions inside this region recursively. - Returns three sets containing the Break, Continue and Return states which have to be handled by - the caller. - """ - - break_states: set[LoopRegion.BreakState] = set() - continue_states: set[LoopRegion.ContinueState] = set() - return_states: set[ReturnState] = set() - - for node in block.nodes(): - bs, cs, rs = set(), set(), set() - if isinstance(node, ConditionalRegion): - bs, cs, rs = inline_conditional_region(node, block) - elif isinstance(node, LoopRegion): - bs, cs, rs = inline_loop_region(node, block) - elif isinstance(node, LoopRegion.BreakState): - break_states.add(node) - elif isinstance(node, LoopRegion.ContinueState): - continue_states.add(node) - elif isinstance(node, ReturnState): - return_states.add(node) - elif isinstance(node, ControlFlowRegion): - bs, cs, rs = inline_control_flow_region(node, block) - break_states.update(bs) - continue_states.update(cs) - return_states.update(rs) - - if isinstance(block, ControlFlowRegion): - block.reset_cfg_list() - - return break_states, continue_states, return_states - -def inline_control_flow_region(region: ControlFlowRegion, parent: ControlFlowRegion): - from dace.sdfg.sdfg import InterstateEdge - - break_states, continue_states, return_states = inline(region) - - # Add all region states and make sure to keep track of all the ones that need to be connected in the end. - to_connect: Set[ControlFlowBlock] = set() - for node in region.nodes(): - parent.add_node(node, ensure_unique_name=True) - if region.out_degree(node) == 0 and not isinstance(node, (LoopRegion.BreakState, LoopRegion.ContinueState, ReturnState)): - to_connect.add(node) - - end_state = parent.add_state(region.label + '_end') - if len(region.nodes()) > 0: - internal_start = region.start_block - else: - internal_start = end_state - - # Add all region edges. - for edge in region.edges(): - parent.add_edge(edge.src, edge.dst, edge.data) - - # Redirect all edges to the region to the internal start state. - for b_edge in parent.in_edges(region): - parent.add_edge(b_edge.src, internal_start, b_edge.data) - parent.remove_edge(b_edge) - # Redirect all edges exiting the region to instead exit the end state. - for a_edge in parent.out_edges(region): - parent.add_edge(end_state, a_edge.dst, a_edge.data) - parent.remove_edge(a_edge) - - for node in to_connect: - parent.add_edge(node, end_state, InterstateEdge()) - - # Remove the original loop. - parent.remove_node(region) - - if parent.in_degree(end_state) == 0: - parent.remove_node(end_state) - return break_states, continue_states, return_states - - -def inline_loop_region(loop: LoopRegion, parent: ControlFlowRegion): - from dace.sdfg.sdfg import InterstateEdge - - break_states, continue_states, return_states = inline(loop) - - internal_start = loop.start_block - - # Add all boilerplate loop states necessary for the structure. - init_state = parent.add_state(loop.label + '_init') - guard_state = parent.add_state(loop.label + '_guard') - end_state = parent.add_state(loop.label + '_end') - loop_tail_state = parent.add_state(loop.label + '_tail') - - # Add all loop states and make sure to keep track of all the ones that need to be connected in the end. - connect_to_tail: Set[SDFGState] = set() - for node in loop.nodes(): - node.label = loop.label + '_' + node.label - parent.add_node(node, ensure_unique_name=True) - if loop.out_degree(node) == 0 and not isinstance(node, (LoopRegion.BreakState, LoopRegion.ContinueState, ReturnState)): - connect_to_tail.add(node) - - # Add all internal loop edges. - for edge in loop.edges(): - parent.add_edge(edge.src, edge.dst, edge.data) - - # Redirect all edges to the loop to the init state. - for b_edge in parent.in_edges(loop): - parent.add_edge(b_edge.src, init_state, b_edge.data) - parent.remove_edge(b_edge) - # Redirect all edges exiting the loop to instead exit the end state. - for a_edge in parent.out_edges(loop): - parent.add_edge(end_state, a_edge.dst, a_edge.data) - parent.remove_edge(a_edge) - - # Add an initialization edge that initializes the loop variable if applicable. - init_edge = InterstateEdge() - if loop.init_statement is not None: - init_edge.assignments = {} - for stmt in loop.init_statement.code: - assign: astutils.ast.Assign = stmt - init_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value) - if loop.inverted: - parent.add_edge(init_state, internal_start, init_edge) - else: - parent.add_edge(init_state, guard_state, init_edge) - - # Connect the loop tail. - update_edge = InterstateEdge() - if loop.update_statement is not None: - update_edge.assignments = {} - for stmt in loop.update_statement.code: - assign: astutils.ast.Assign = stmt - update_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value) - parent.add_edge(loop_tail_state, guard_state, update_edge) - - # Add condition checking edges and connect the guard state. - cond_expr = loop.loop_condition.code - parent.add_edge(guard_state, end_state, - InterstateEdge(CodeBlock(astutils.negate_expr(cond_expr)).code)) - parent.add_edge(guard_state, internal_start, InterstateEdge(CodeBlock(cond_expr).code)) - - # Connect any end states from the loop's internal state machine to the tail state so they end a - # loop iteration. Do the same for any continue states, and connect any break states to the end of the loop. - for node in continue_states | connect_to_tail: - parent.add_edge(node, loop_tail_state, InterstateEdge()) - for node in break_states: - parent.add_edge(node, end_state, InterstateEdge(assignments={f'did_break_{loop.label}': '1'})) - - # Remove the original loop. - parent.remove_node(loop) - if parent.in_degree(end_state) == 0: - parent.remove_node(end_state) - return set(), set(), return_states - -def inline_conditional_region(conditional: ConditionalRegion, parent: ControlFlowRegion): - from dace.sdfg.sdfg import InterstateEdge - - break_states, continue_states, return_states = inline(conditional) - - # Add all boilerplate states necessary for the structure. - guard_state = parent.add_state(conditional.label + '_guard') - endif_state = parent.add_state(conditional.label + '_endinf') - - connect_to_end : Set[ControlFlowBlock] = set() - # Add all states and make sure to keep track of all the ones that need to be connected in the end. - for node in conditional.nodes(): - node.label = conditional.label + '_' + node.label - parent.add_node(node, ensure_unique_name=True) - if conditional.out_degree(node) == 0 and not isinstance(node, (LoopRegion.BreakState, LoopRegion.ContinueState, ReturnState)): - connect_to_end.add(node) - - # Add all internal region edges. - for edge in conditional.edges(): - parent.add_edge(edge.src, edge.dst, edge.data) - - # Redirect all edges entering the region to the init state. - for b_edge in parent.in_edges(conditional): - parent.add_edge(b_edge.src, guard_state, b_edge.data) - parent.remove_edge(b_edge) - # Redirect all edges exiting the region to instead exit the end state. - for a_edge in parent.out_edges(conditional): - parent.add_edge(endif_state, a_edge.dst, a_edge.data) - parent.remove_edge(a_edge) - - # Add condition checking edges and connect the guard state. - parent.add_edge(guard_state, conditional.start_block, InterstateEdge(conditional.condition_expr)) - parent.add_edge(guard_state, conditional.else_branch, InterstateEdge(conditional.condition_else_expr)) - - for node in connect_to_end: - parent.add_edge(node, endif_state, InterstateEdge()) - for node in return_states: - parent.add_edge(node, endif_state, InterstateEdge(condition="False")) - parent.add_edge(conditional.else_branch, endif_state, InterstateEdge()) - bs, cs, rs = inline_control_flow_region(conditional.else_branch, parent) - break_states.update(bs) - continue_states.update(cs) - return_states.update(rs) - - parent.remove_node(conditional) - if parent.in_degree(endif_state) == 0: - parent.remove_node(endif_state) - return break_states, continue_states, return_states - \ No newline at end of file From 39a183eaa4a56a608690bf85a9bf6e96dc34fc5e Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Thu, 11 Jul 2024 15:58:44 +0200 Subject: [PATCH 24/38] Revert ControlFlowRegion add_node change --- dace/sdfg/state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 1529a5a3d9..d1e787930b 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2647,6 +2647,7 @@ def add_node(self, node.label = self._ensure_unique_block_name(node.label) super().add_node(node) + self._cached_start_block = None node.parent_graph = self if isinstance(self, dace.SDFG): node.sdfg = self From cbc51cd8e01695561758b3d807d08ab79392a55d Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Thu, 11 Jul 2024 22:39:24 +0200 Subject: [PATCH 25/38] Fix test --- tests/sdfg/work_depth_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index e677cca752..b4a986aa1d 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -192,7 +192,7 @@ def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.f 'nested_if_else': (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), 'max_of_positive_symbols': (max_of_positive_symbol, (3 * N**2, 3 * N)), 'multiple_array_sizes': (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_2') * N, sp.Symbol('num_execs_0_2'))), + 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_3') * N, sp.Symbol('num_execs_0_3'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))), 'break_for_loop': (break_for_loop, (N**2, N)), @@ -317,6 +317,7 @@ def test_assumption_system_contradictions(assumptions): if __name__ == '__main__': + test_work_depth("unbounded_while_do") for test_name in work_depth_test_cases.keys(): test_work_depth(test_name) From 575f3518d2ed19367d023a9474ccd2ddb2f76a58 Mon Sep 17 00:00:00 2001 From: Luca Patrignani Date: Thu, 11 Jul 2024 23:32:35 +0200 Subject: [PATCH 26/38] Represent else branch as tuple (else condition, None) --- dace/frontend/python/newast.py | 9 +++---- dace/sdfg/state.py | 44 +++++++++++++++++++--------------- dace/sdfg/validation.py | 5 ++-- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index b3b12c371f..877dcc3f18 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2572,15 +2572,16 @@ def visit_If(self, node: ast.If): # Visit recursively self._recursive_visit(node.body, 'if', node.lineno, if_body, False) - else_body = ControlFlowRegion("", sdfg=self.sdfg) - cond_region.branches.append((CodeBlock(cond_else), else_body)) - else_body.parent_graph = cond_region # Process 'else'/'elif' statements if len(node.orelse) > 0: - else_body.label = f"{cond_region.label}_else_{node.orelse[0].lineno}" + else_body = ControlFlowRegion(f"{cond_region.label}_else_{node.orelse[0].lineno}", sdfg=self.sdfg) + cond_region.branches.append((CodeBlock(cond_else), else_body)) + else_body.parent_graph = cond_region # Visit recursively self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False) + else: + cond_region.branches.append((CodeBlock(cond_else), None)) def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index d1e787930b..d8afbb22fe 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2753,8 +2753,9 @@ def all_states(self) -> Iterator[SDFGState]: elif isinstance(block, ControlFlowRegion): yield from block.all_states() elif isinstance(block, ConditionalRegion): - for _, cfr in block.branches: - yield from cfr.all_states() + for _, region in block.branches: + if region is not None: + yield from region.all_states() def all_control_flow_blocks(self, recursive=False) -> Iterator[ControlFlowBlock]: """ Iterate over all control flow blocks in this control flow graph. """ @@ -3196,7 +3197,7 @@ def __init__(self, label: str): self.branches: List[Tuple[CodeBlock, ControlFlowRegion]] = [] def nodes(self) -> List['ControlFlowBlock']: - return [node for _, node in self.branches] + return [node for _, node in self.branches if node is not None] def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: return [] @@ -3211,13 +3212,14 @@ def _used_symbols_internal(self, free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment - for condition, cfg in self.branches: - free_syms |= condition.get_free_symbols(defined_syms) - b_free_symbols, b_defined_symbols, b_used_before_assignment = cfg._used_symbols_internal( - all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping) - free_syms |= b_free_symbols - defined_syms |= b_defined_symbols - used_before_assignment |= b_used_before_assignment + for condition, region in self.branches: + if region is not None: + free_syms |= condition.get_free_symbols(defined_syms) + b_free_symbols, b_defined_symbols, b_used_before_assignment = region._used_symbols_internal( + all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping) + free_syms |= b_free_symbols + defined_syms |= b_defined_symbols + used_before_assignment |= b_used_before_assignment defined_syms -= used_before_assignment free_syms -= defined_syms @@ -3233,8 +3235,9 @@ def replace_dict(self, from dace.sdfg.replace import replace_properties_dict replace_properties_dict(self, repl, symrepl) - for _, cfg in self.branches: - cfg.replace_dict(repl, symrepl, replace_in_graph) + for _, region in self.branches: + if region is not None: + region.replace_dict(repl, symrepl, replace_in_graph) def to_json(self, parent=None): json = super().to_json(parent) @@ -3245,8 +3248,11 @@ def to_json(self, parent=None): def from_json(cls, json_obj, context=None): cond_region = ConditionalRegion(json_obj["label"]) cond_region.is_collapsed = json_obj["collapsed"] - cond_region.branches = [(CodeBlock.from_json(condition), ControlFlowRegion.from_json(cfg, context)) - for condition, cfg in json_obj["branches"]] + for condition, region in json_obj["branches"]: + if region is not None: + cond_region.branches.append((CodeBlock.from_json(condition), ControlFlowRegion.from_json(region, context))) + else: + cond_region.branches.append((CodeBlock.from_json(condition), None)) return cond_region def inline(self) -> Tuple[bool, Any]: @@ -3273,11 +3279,11 @@ def inline(self) -> Tuple[bool, Any]: parent.remove_edge(a_edge) from dace.sdfg.sdfg import InterstateEdge - for condition, cfg in self.branches: - if cfg.number_of_nodes() > 0: - parent.add_node(cfg) - parent.add_edge(guard_state, cfg, InterstateEdge(condition=condition)) - parent.add_edge(cfg, end_state, InterstateEdge()) + for condition, region in self.branches: + if region is not None: + parent.add_node(region) + parent.add_edge(guard_state, region, InterstateEdge(condition=condition)) + parent.add_edge(region, end_state, InterstateEdge()) else: parent.add_edge(guard_state, end_state, InterstateEdge(condition=condition)) diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 4f03b35b3d..edbd0cd349 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -119,8 +119,9 @@ def validate_control_flow_region(sdfg: 'SDFG', validate_state(edge.dst, region.node_id(edge.dst), sdfg, symbols, initialized_transients, references, **context) elif isinstance(edge.dst, ConditionalRegion): - for _, cfg in edge.dst.branches: - validate_control_flow_region(sdfg, cfg, initialized_transients, symbols, references, **context) + for _, r in edge.dst.branches: + if r is not None: + validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) elif isinstance(edge.dst, ControlFlowRegion): validate_control_flow_region(sdfg, edge.dst, initialized_transients, symbols, references, **context) # End of block DFS From 5cfcfbb4f05155cd54e818854605f7d7f421638f Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 25 Sep 2024 10:40:19 +0200 Subject: [PATCH 27/38] Lots of bugfixes --- dace/codegen/codegen.py | 4 +- dace/codegen/control_flow.py | 113 ++++++++---------- dace/frontend/python/interface.py | 2 + dace/frontend/python/newast.py | 34 +++--- dace/sdfg/analysis/cutout.py | 19 ++- dace/sdfg/state.py | 101 ++++++++++------ dace/sdfg/utils.py | 7 +- dace/sdfg/validation.py | 4 +- .../conditional_regions_test.py | 89 ++++++++++++++ tests/python_frontend/loops_test.py | 5 +- tests/sdfg/conditional_region_test.py | 10 +- tests/sdfg/work_depth_test.py | 1 - 12 files changed, 247 insertions(+), 142 deletions(-) create mode 100644 tests/python_frontend/conditional_regions_test.py diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py index 864362fe88..d1427bf037 100644 --- a/dace/codegen/codegen.py +++ b/dace/codegen/codegen.py @@ -9,7 +9,7 @@ from dace.codegen.targets import framecode from dace.codegen.codeobject import CodeObject from dace.config import Config -from dace.sdfg import infer_types, utils as sdutils +from dace.sdfg import infer_types # Import CPU code generator. TODO: Remove when refactored from dace.codegen.targets import cpp, cpu @@ -158,8 +158,6 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]: """ from dace.codegen.targets.target import TargetCodeGenerator # Avoid import loop - sdutils.inline_conditional_regions(sdfg) - sdutils.inline_control_flow_regions(sdfg) # Before compiling, validate SDFG correctness if validate: sdfg.validate() diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index ae9351fc43..08ffa4d6ee 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -62,7 +62,7 @@ import sympy as sp from dace import dtypes from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import (BreakBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, +from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, ReturnBlock, SDFGState) from dace.sdfg.sdfg import SDFG, InterstateEdge from dace.sdfg.graph import Edge @@ -601,6 +601,43 @@ def children(self) -> List[ControlFlow]: return list(self.cases.values()) +@dataclass +class GeneralConditionalScope(ControlFlow): + """ General conditional block based on a conditional control flow region. """ + + conditional: ConditionalBlock + branch_bodies: List[Tuple[Optional[CodeBlock], ControlFlow]] + + def as_cpp(self, codegen, symbols) -> str: + sdfg = self.conditional.sdfg + expr = '' + for i in range(len(self.branch_bodies)): + branch = self.branch_bodies[i] + if branch[0] is not None: + cond = unparse_interstate_edge(branch[0].code, sdfg, codegen=codegen, symbols=symbols) + cond = cond.strip(';') + if i == 0: + expr += f'if ({cond}) {{\n' + else: + expr += f'}} else if ({cond}) {{\n' + else: + if i < len(self.branch_bodies) - 1 or i == 0: + raise RuntimeError('Missing branch condition for non-final conditional branch') + expr += '} else {\n' + expr += branch[1].as_cpp(codegen, symbols) + if i == len(self.branch_bodies) - 1: + expr += '}\n' + return expr + + @property + def first_block(self) -> ControlFlowBlock: + return self.conditional + + @property + def children(self) -> List[ControlFlow]: + return [b for _, b in self.branch_bodies] + + def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge], leave_edge: Edge[InterstateEdge], back_edges: List[Edge[InterstateEdge]], dispatch_state: Callable[[SDFGState], @@ -973,7 +1010,6 @@ def _structured_control_flow_traversal_with_regions(cfg: ControlFlowRegion, if branch_merges is None: branch_merges = cfg_analysis.branch_merges(cfg) - if ptree is None: ptree = cfg_analysis.block_parent_tree(cfg, with_loops=False) @@ -1004,6 +1040,14 @@ def make_empty_block(): cfg_block = ContinueCFBlock(dispatch_state, parent_block, True, node) elif isinstance(node, ReturnBlock): cfg_block = ReturnCFBlock(dispatch_state, parent_block, True, node) + elif isinstance(node, ConditionalBlock): + cfg_block = GeneralConditionalScope(dispatch_state, parent_block, False, node, []) + for cond, branch in node.branches: + if branch is not None: + body = make_empty_block() + body.parent = cfg_block + _structured_control_flow_traversal_with_regions(branch, dispatch_state, body) + cfg_block.branch_bodies.append((cond, body)) elif isinstance(node, ControlFlowRegion): if isinstance(node, LoopRegion): body = make_empty_block() @@ -1027,69 +1071,8 @@ def make_empty_block(): stack.append(oe[0].dst) parent_block.elements.append(cfg_block) continue - - # Potential branch or loop - if node in branch_merges: - mergeblock = branch_merges[node] - - # Add branching node and ignore outgoing edges - parent_block.elements.append(cfg_block) - parent_block.gotos_to_ignore.extend(oe) # TODO: why? - parent_block.assignments_to_ignore.extend(oe) # TODO: why? - cfg_block.last_block = True - - # Parse all outgoing edges recursively first - cblocks: Dict[Edge[InterstateEdge], GeneralBlock] = {} - for branch in oe: - if branch.dst is mergeblock: - # If we hit the merge state (if without else), defer to end of branch traversal - continue - cblocks[branch] = make_empty_block() - _structured_control_flow_traversal_with_regions(cfg=cfg, - dispatch_state=dispatch_state, - parent_block=cblocks[branch], - start=branch.dst, - stop=mergeblock, - generate_children_of=node, - branch_merges=branch_merges, - ptree=ptree, - visited=visited) - - # Classify branch type: - branch_block = None - # If there are 2 out edges, one negation of the other: - # * if/else in case both branches are not merge state - # * if without else in case one branch is merge state - if (len(oe) == 2 and oe[0].data.condition_sympy() == sp.Not(oe[1].data.condition_sympy())): - if oe[0].dst is mergeblock: - # If without else - branch_block = IfScope(dispatch_state, parent_block, False, node, oe[1].data.condition, - cblocks[oe[1]]) - elif oe[1].dst is mergeblock: - branch_block = IfScope(dispatch_state, parent_block, False, node, oe[0].data.condition, - cblocks[oe[0]]) - else: - branch_block = IfScope(dispatch_state, parent_block, False, node, oe[0].data.condition, - cblocks[oe[0]], cblocks[oe[1]]) - else: - # If there are 2 or more edges (one is not the negation of the - # other): - switch = _cases_from_branches(oe, cblocks) - if switch: - # If all edges are of form "x == y" for a single x and - # integer y, it is a switch/case - branch_block = SwitchCaseScope(dispatch_state, parent_block, False, node, switch[0], switch[1]) - else: - # Otherwise, create if/else if/.../else goto exit chain - branch_block = IfElseChain(dispatch_state, parent_block, False, node, - [(e.data.condition, cblocks[e] if e in cblocks else make_empty_block()) - for e in oe]) - # End of branch classification - parent_block.elements.append(branch_block) - if mergeblock != stop: - stack.append(mergeblock) - - else: # No merge state: Unstructured control flow + else: + # Unstructured control flow. parent_block.sequential = False parent_block.elements.append(cfg_block) stack.extend([e.dst for e in oe]) diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index c4f6c827a9..14164054d3 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -69,6 +69,8 @@ def program(f: F, not depend on internal variables are constant. This will hardcode their return values into the resulting program. + :param use_experimental_cfg_blocks: If True, makes use of experimental CFG blocks susch as loop and conditional + regions. :note: If arguments are defined with type hints, the program can be compiled ahead-of-time with ``.compile()``. """ diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index ffd4dc1db3..b0c9f11c62 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3,7 +3,6 @@ from collections import OrderedDict import copy import itertools -import inspect import networkx as nx import re import sys @@ -25,14 +24,13 @@ from dace.frontend.python.astutils import ExtNodeVisitor, ExtNodeTransformer from dace.frontend.python.astutils import rname from dace.frontend.python import nested_call, replacements, preprocessing -from dace.frontend.python.memlet_parser import (DaceSyntaxError, parse_memlet, pyexpr_to_symbolic, ParseMemlet, - inner_eval_ast, MemletExpr) -from dace.sdfg import nodes, utils as sdutil +from dace.frontend.python.memlet_parser import DaceSyntaxError, parse_memlet, ParseMemlet, inner_eval_ast, MemletExpr +from dace.sdfg import nodes from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states from dace.memlet import Memlet from dace.properties import LambdaProperty, CodeBlock from dace.sdfg import SDFG, SDFGState -from dace.sdfg.state import (BreakBlock, ConditionalRegion, ContinueBlock, ControlFlowBlock, FunctionCallRegion, +from dace.sdfg.state import (BranchRegion, BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, FunctionCallRegion, LoopRegion, ControlFlowRegion, NamedRegion) from dace.sdfg.replace import replace_datadesc_names from dace.symbolic import pystr_to_symbolic, inequal_symbols @@ -2552,37 +2550,35 @@ def visit_Break(self, node: ast.Break): def visit_Continue(self, node: ast.Continue): if not self._has_loop_ancestor(self.cfg_target): - raise DaceSyntaxError(self, node, "Continue block outside loop region") + raise DaceSyntaxError(self, node, 'Continue block outside loop region') continue_block = ContinueBlock(f'continue_{node.lineno}') self.cfg_target.add_node(continue_block, ensure_unique_name=True) self._on_block_added(continue_block) def visit_If(self, node: ast.If): # Generate conditions - cond, cond_else, _ = self._visit_test(node.test) + cond, _, _ = self._visit_test(node.test) # Add conditional region - cond_region = ConditionalRegion(f"if_{node.lineno}") - self.cfg_target.add_node(cond_region) - self._on_block_added(cond_region) + cond_block = ConditionalBlock(f'if_{node.lineno}') + self.cfg_target.add_node(cond_block) + self._on_block_added(cond_block) - if_body = ControlFlowRegion(cond_region.label + "_body", sdfg=self.sdfg) - cond_region.branches.append((CodeBlock(cond), if_body)) - if_body.parent_graph = cond_region + if_body = BranchRegion(cond_block.label + '_body', sdfg=self.sdfg) + cond_block.branches.append((CodeBlock(cond), if_body)) + if_body.parent_graph = self.cfg_target # Visit recursively self._recursive_visit(node.body, 'if', node.lineno, if_body, False) - # Process 'else'/'elif' statements if len(node.orelse) > 0: - else_body = ControlFlowRegion(f"{cond_region.label}_else_{node.orelse[0].lineno}", sdfg=self.sdfg) - cond_region.branches.append((CodeBlock(cond_else), else_body)) - else_body.parent_graph = cond_region + else_body = BranchRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', sdfg=self.sdfg) + #cond_block.branches.append((CodeBlock(cond_else), else_body)) + cond_block.branches.append((None, else_body)) + else_body.parent_graph = self.cfg_target # Visit recursively self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False) - else: - cond_region.branches.append((CodeBlock(cond_else), None)) def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): diff --git a/dace/sdfg/analysis/cutout.py b/dace/sdfg/analysis/cutout.py index 50272167bb..5d2eae7c6f 100644 --- a/dace/sdfg/analysis/cutout.py +++ b/dace/sdfg/analysis/cutout.py @@ -13,7 +13,7 @@ from dace.sdfg import nodes as nd, SDFG, SDFGState, utils as sdutil, InterstateEdge from dace.memlet import Memlet from dace.sdfg.graph import Edge, MultiConnectorEdge -from dace.sdfg.state import StateSubgraphView, SubgraphView +from dace.sdfg.state import ControlFlowBlock, StateSubgraphView, SubgraphView from dace.transformation.transformation import (MultiStateTransformation, PatternTransformation, SubgraphTransformation, @@ -321,7 +321,8 @@ def singlestate_cutout(cls, @classmethod def multistate_cutout(cls, *states: SDFGState, - make_side_effects_global: bool = True) -> Union['SDFGCutout', SDFG]: + make_side_effects_global: bool = True, + override_start_block: Optional[ControlFlowBlock] = None) -> Union['SDFGCutout', SDFG]: """ Cut out a multi-state subgraph from an SDFG to run separately for localized testing or optimization. @@ -336,6 +337,9 @@ def multistate_cutout(cls, :param make_side_effects_global: If True, all transient data containers which are read inside the cutout but may be written to _before_ the cutout, or any data containers which are written to inside the cutout but may be read _after_ the cutout, are made global. + :param override_start_block: If set, explicitly force a given control flow block to be the start block. If left + None (default), the start block is automatically determined based on domination + relationships in the original graph. :return: The created SDFGCutout or the original SDFG where no smaller cutout could be obtained. """ create_element = copy.deepcopy @@ -350,10 +354,13 @@ def multistate_cutout(cls, # Determine the start state and ensure there IS a unique start state. If there is no unique start state, keep # adding states from the predecessor frontier in the state machine until a unique start state can be determined. start_state: Optional[SDFGState] = None - for state in cutout_states: - if state == sdfg.start_state: - start_state = state - break + if override_start_block is not None: + start_state = override_start_block + else: + for state in cutout_states: + if state == sdfg.start_state: + start_state = state + break if start_state is None: bfs_queue: Deque[Tuple[Set[SDFGState], Set[Edge[InterstateEdge]]]] = deque() diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index c190471c0e..eb2e46ea4e 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1140,6 +1140,11 @@ def set_default_lineinfo(self, lineinfo: dace.dtypes.DebugInfo): """ self._default_lineinfo = lineinfo + def view(self): + from dace.sdfg.analysis.cutout import SDFGCutout + cutout = SDFGCutout.multistate_cutout(self, make_side_effects_global=False, override_start_block=self) + cutout.view() + def to_json(self, parent=None): tmp = { 'type': self.__class__.__name__, @@ -2752,6 +2757,9 @@ def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegi yield from node.sdfg.all_control_flow_regions(recursive=recursive) elif isinstance(block, ControlFlowRegion): yield from block.all_control_flow_regions(recursive=recursive) + elif isinstance(block, ConditionalBlock): + for _, branch in block.branches: + yield from branch.all_control_flow_regions(recursive=recursive) def all_sdfgs_recursive(self) -> Iterator['SDFG']: """ Iterate over this and all nested SDFGs. """ @@ -2766,7 +2774,7 @@ def all_states(self) -> Iterator[SDFGState]: yield block elif isinstance(block, ControlFlowRegion): yield from block.all_states() - elif isinstance(block, ConditionalRegion): + elif isinstance(block, ConditionalBlock): for _, region in block.branches: if region is not None: yield from region.all_states() @@ -2803,7 +2811,7 @@ def _used_symbols_internal(self, for block in ordered_blocks: state_symbols = set() - if isinstance(block, (ControlFlowRegion, ConditionalRegion)): + if isinstance(block, (ControlFlowRegion, ConditionalBlock)): b_free_syms, b_defined_syms, b_used_before_syms = block._used_symbols_internal(all_symbols, defined_syms, free_syms, @@ -3035,7 +3043,7 @@ def inline(self) -> Tuple[bool, Any]: # and return are inlined correctly. def recursive_inline_cf_regions(region: ControlFlowRegion) -> None: for block in region.nodes(): - if (isinstance(block, ControlFlowRegion) or isinstance(block, ConditionalRegion)) and not isinstance(block, LoopRegion): + if (isinstance(block, ControlFlowRegion) or isinstance(block, ConditionalBlock)) and not isinstance(block, LoopRegion): recursive_inline_cf_regions(block) block.inline() recursive_inline_cf_regions(self) @@ -3204,36 +3212,56 @@ def has_return(self) -> bool: return True return False -@dace.serialize.serializable -class ConditionalRegion(ControlFlowBlock, ControlGraphView): + +class BranchRegion(ControlFlowRegion): + + def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None): + super().__init__(label, sdfg) + + +@make_properties +class ConditionalBlock(ControlFlowBlock, ControlGraphView): + + _branches: List[Tuple[Optional[CodeBlock], BranchRegion]] + def __init__(self, label: str): super().__init__(label) - self.branches: List[Tuple[CodeBlock, ControlFlowRegion]] = [] + self._branches = [] + + def __str__(self): + return self._label + + def __repr__(self) -> str: + return f'ConditionalBlock ({self.label})' + + @property + def branches(self) -> List[Tuple[Optional[CodeBlock], BranchRegion]]: + return self._branches def nodes(self) -> List['ControlFlowBlock']: - return [node for _, node in self.branches if node is not None] + return [node for _, node in self._branches if node is not None] def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: return [] def _used_symbols_internal(self, - all_symbols: bool, - defined_syms: Optional[Set] = None, - free_syms: Optional[Set] = None, - used_before_assignment: Optional[Set] = None, - keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + all_symbols: bool, + defined_syms: Optional[Set] = None, + free_syms: Optional[Set] = None, + used_before_assignment: Optional[Set] = None, + keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment - for condition, region in self.branches: - if region is not None: + for condition, region in self._branches: + if condition is not None: free_syms |= condition.get_free_symbols(defined_syms) - b_free_symbols, b_defined_symbols, b_used_before_assignment = region._used_symbols_internal( - all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping) - free_syms |= b_free_symbols - defined_syms |= b_defined_symbols - used_before_assignment |= b_used_before_assignment + b_free_symbols, b_defined_symbols, b_used_before_assignment = region._used_symbols_internal( + all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping) + free_syms |= b_free_symbols + defined_syms |= b_defined_symbols + used_before_assignment |= b_used_before_assignment defined_syms -= used_before_assignment free_syms -= defined_syms @@ -3241,32 +3269,33 @@ def _used_symbols_internal(self, return free_syms, defined_syms, used_before_assignment def replace_dict(self, - repl: Dict[str, str], - symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, - replace_in_graph: bool = True, - replace_keys: bool = True): + repl: Dict[str, str], + symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, + replace_in_graph: bool = True, + replace_keys: bool = True): if replace_keys: from dace.sdfg.replace import replace_properties_dict replace_properties_dict(self, repl, symrepl) - for _, region in self.branches: - if region is not None: - region.replace_dict(repl, symrepl, replace_in_graph) + for _, region in self._branches: + region.replace_dict(repl, symrepl, replace_in_graph) def to_json(self, parent=None): json = super().to_json(parent) - json["branches"] = [(condition.to_json(), cfg.to_json()) for condition, cfg in self.branches] + json['branches'] = [(condition.to_json() if condition is not None else None, cfg.to_json()) + for condition, cfg in self._branches] return json @classmethod def from_json(cls, json_obj, context=None): - cond_region = ConditionalRegion(json_obj["label"]) - cond_region.is_collapsed = json_obj["collapsed"] - for condition, region in json_obj["branches"]: - if region is not None: - cond_region.branches.append((CodeBlock.from_json(condition), ControlFlowRegion.from_json(region, context))) + cond_region = ConditionalBlock(json_obj['label']) + cond_region.is_collapsed = json_obj['collapsed'] + for condition, region in json_obj['branches']: + if condition is not None: + cond_region._branches.append((CodeBlock.from_json(condition), + BranchRegion.from_json(region, context))) else: - cond_region.branches.append((CodeBlock.from_json(condition), None)) + cond_region._branches.append((None, BranchRegion.from_json(region, context))) return cond_region def inline(self) -> Tuple[bool, Any]: @@ -3293,7 +3322,7 @@ def inline(self) -> Tuple[bool, Any]: parent.remove_edge(a_edge) from dace.sdfg.sdfg import InterstateEdge - for condition, region in self.branches: + for condition, region in self._branches: if region is not None: parent.add_node(region) parent.add_edge(guard_state, region, InterstateEdge(condition=condition)) @@ -3311,14 +3340,18 @@ def inline(self) -> Tuple[bool, Any]: @make_properties class NamedRegion(ControlFlowRegion): + debuginfo = DebugInfoProperty() + def __init__(self, label: str, sdfg: Optional['SDFG']=None, debuginfo: Optional[dtypes.DebugInfo]=None): super().__init__(label, sdfg) self.debuginfo = debuginfo @make_properties class FunctionCallRegion(ControlFlowRegion): + arguments = DictProperty(str, str) + def __init__(self, label: str, arguments: Dict[str, str] = {}, sdfg: 'SDFG' = None): super().__init__(label, sdfg) self.arguments = arguments diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 4b3830257e..87039fa27a 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -13,7 +13,7 @@ from dace.sdfg.graph import MultiConnectorEdge from dace.sdfg.sdfg import SDFG from dace.sdfg.nodes import Node, NestedSDFG -from dace.sdfg.state import ConditionalRegion, SDFGState, StateSubgraphView, LoopRegion, ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, SDFGState, StateSubgraphView, LoopRegion, ControlFlowRegion from dace.sdfg.scope import ScopeSubgraphView from dace.sdfg import nodes as nd, graph as gr, propagation from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs @@ -1262,8 +1262,7 @@ def inline_loop_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = No def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: - blocks = [n for n, _ in sdfg.all_nodes_recursive() - if isinstance(n, ControlFlowRegion) and not isinstance(n, (LoopRegion, SDFG))] + blocks = [n for n, _ in sdfg.all_nodes_recursive() if type(n) is ControlFlowRegion] count = 0 for _block in optional_progressbar(reversed(blocks), title='Inlining control flow blocks', @@ -1275,7 +1274,7 @@ def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress: return count def inline_conditional_regions(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: - blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ConditionalRegion)] + blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)] count = 0 for _block in optional_progressbar(reversed(blocks), title='Inlining conditional regions', diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index edbd0cd349..d2752f086d 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -34,7 +34,7 @@ def validate_control_flow_region(sdfg: 'SDFG', symbols: dict, references: Set[int] = None, **context: bool): - from dace.sdfg.state import SDFGState, ControlFlowRegion, ConditionalRegion + from dace.sdfg.state import SDFGState, ControlFlowRegion, ConditionalBlock from dace.sdfg.scope import is_in_scope if len(region.source_nodes()) > 1 and region.start_block is None: @@ -118,7 +118,7 @@ def validate_control_flow_region(sdfg: 'SDFG', if isinstance(edge.dst, SDFGState): validate_state(edge.dst, region.node_id(edge.dst), sdfg, symbols, initialized_transients, references, **context) - elif isinstance(edge.dst, ConditionalRegion): + elif isinstance(edge.dst, ConditionalBlock): for _, r in edge.dst.branches: if r is not None: validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) diff --git a/tests/python_frontend/conditional_regions_test.py b/tests/python_frontend/conditional_regions_test.py new file mode 100644 index 0000000000..d9463abce7 --- /dev/null +++ b/tests/python_frontend/conditional_regions_test.py @@ -0,0 +1,89 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +import dace +from dace.sdfg.state import ConditionalBlock + + +def test_dataflow_if_check(): + + @dace.program + def dataflow_if_check(A: dace.float64, i: dace.int64): + if A[i] < 10: + return 0 + elif A[i] == 10: + return 10 + return 100 + + dataflow_if_check.use_experimental_cfg_blocks = True + sdfg = dataflow_if_check.to_sdfg() + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + A = [0., 0., 0., 0., 10., 100., 0.] + assert sdfg(A, 0)[0] == 0 + assert sdfg(A, 4)[0] == 10 + assert sdfg(A, 5)[0] == 100 + assert sdfg(A, 6)[0] == 0 + + +def test_nested_if_chain(): + + @dace.program + def nested_if_chain(i: dace.int64): + if i < 2: + return 0 + else: + if i < 4: + return 1 + else: + if i < 6: + return 2 + else: + if i < 8: + return 3 + else: + return 4 + + nested_if_chain.use_experimental_cfg_blocks = True + sdfg = nested_if_chain.to_sdfg() + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + assert nested_if_chain(0)[0] == 0 + assert nested_if_chain(2)[0] == 1 + assert nested_if_chain(4)[0] == 2 + assert nested_if_chain(7)[0] == 3 + assert nested_if_chain(15)[0] == 4 + + +def test_elif_chain(): + + @dace.program + def elif_chain(i: dace.int64): + if i < 2: + return 0 + elif i < 4: + return 1 + elif i < 6: + return 2 + elif i < 8: + return 3 + else: + return 4 + + elif_chain.use_experimental_cfg_blocks = True + sdfg = elif_chain.to_sdfg() + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + assert elif_chain(0)[0] == 0 + assert elif_chain(2)[0] == 1 + assert elif_chain(4)[0] == 2 + assert elif_chain(7)[0] == 3 + assert elif_chain(15)[0] == 4 + + +if __name__ == '__main__': + test_dataflow_if_check() + test_nested_if_chain() + test_elif_chain() diff --git a/tests/python_frontend/loops_test.py b/tests/python_frontend/loops_test.py index 2d8b2dc83c..d4a54a0456 100644 --- a/tests/python_frontend/loops_test.py +++ b/tests/python_frontend/loops_test.py @@ -416,7 +416,7 @@ def test_nested_map_with_symbol(): reason='Control flow detection issues through extraneous states, needs control flow detection fix') def test_for_else(): - @dace.program(use_experimental_cfg_blocks=True) + @dace.program() def for_else(A: dace.float64[20]): for i in range(1, 20): if A[i] >= 10: @@ -496,12 +496,10 @@ def branch_in_while(cond: dace.int32): def test_branch_in_while(): sdfg = branch_in_while.to_sdfg(simplify=False) - sdfg.save("branch_in_while.sdfg") assert len(sdfg.source_nodes()) == 1 if __name__ == "__main__": - test_branch_in_while() test_for_loop() test_for_loop_with_break_continue() test_nested_for_loop() @@ -523,3 +521,4 @@ def test_branch_in_while(): test_for_else() test_while_else() test_branch_in_for() + test_branch_in_while() diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py index 767f81983b..4e4eda3f44 100644 --- a/tests/sdfg/conditional_region_test.py +++ b/tests/sdfg/conditional_region_test.py @@ -4,7 +4,7 @@ import dace from dace.properties import CodeBlock from dace.sdfg.sdfg import SDFG, InterstateEdge -from dace.sdfg.state import ConditionalRegion, ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion import dace.serialize @@ -14,7 +14,7 @@ def test_cond_region_if(): sdfg.add_symbol("i", dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) - if1 = ConditionalRegion("if1") + if1 = ConditionalBlock("if1") sdfg.add_node(if1) sdfg.add_edge(state0, if1, InterstateEdge()) @@ -37,7 +37,7 @@ def test_cond_region_if(): def test_serialization(): sdfg = SDFG("test_serialization") - cond_region = ConditionalRegion("cond_region") + cond_region = ConditionalBlock("cond_region") sdfg.add_node(cond_region, is_start_block=True) sdfg.add_symbol("i", dace.int32) @@ -49,7 +49,7 @@ def test_serialization(): new_sdfg = SDFG.from_json(sdfg.to_json()) assert new_sdfg.is_valid() - new_cond_region: ConditionalRegion = new_sdfg.nodes()[0] + new_cond_region: ConditionalBlock = new_sdfg.nodes()[0] for j in range(10): condition, cfg = new_cond_region.branches[j] assert condition == CodeBlock(f"i == {j}") @@ -61,7 +61,7 @@ def test_if_else(): sdfg.add_symbol("i", dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) - if1 = ConditionalRegion("if1") + if1 = ConditionalBlock("if1") sdfg.add_node(if1) sdfg.add_edge(state0, if1, InterstateEdge()) diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index b4a986aa1d..c360a59529 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -317,7 +317,6 @@ def test_assumption_system_contradictions(assumptions): if __name__ == '__main__': - test_work_depth("unbounded_while_do") for test_name in work_depth_test_cases.keys(): test_work_depth(test_name) From a72bc917cb9e0f9bacca82dfd8153ceb862fa829 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 25 Sep 2024 12:37:35 +0200 Subject: [PATCH 28/38] Bugfixes --- dace/sdfg/state.py | 3 +-- tests/python_frontend/conditional_regions_test.py | 7 +++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index eb2e46ea4e..3d95f635a7 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2776,8 +2776,7 @@ def all_states(self) -> Iterator[SDFGState]: yield from block.all_states() elif isinstance(block, ConditionalBlock): for _, region in block.branches: - if region is not None: - yield from region.all_states() + yield from region.all_states() def all_control_flow_blocks(self, recursive=False) -> Iterator[ControlFlowBlock]: """ Iterate over all control flow blocks in this control flow graph. """ diff --git a/tests/python_frontend/conditional_regions_test.py b/tests/python_frontend/conditional_regions_test.py index d9463abce7..07e214653c 100644 --- a/tests/python_frontend/conditional_regions_test.py +++ b/tests/python_frontend/conditional_regions_test.py @@ -1,13 +1,14 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import dace +import numpy as np from dace.sdfg.state import ConditionalBlock def test_dataflow_if_check(): @dace.program - def dataflow_if_check(A: dace.float64, i: dace.int64): + def dataflow_if_check(A: dace.int32[10], i: dace.int64): if A[i] < 10: return 0 elif A[i] == 10: @@ -19,7 +20,9 @@ def dataflow_if_check(A: dace.float64, i: dace.int64): assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) - A = [0., 0., 0., 0., 10., 100., 0.] + A = np.zeros((10,), np.int32) + A[4] = 10 + A[5] = 100 assert sdfg(A, 0)[0] == 0 assert sdfg(A, 4)[0] == 10 assert sdfg(A, 5)[0] == 100 From 6d41d635c841a5b87229d459019d0443482c095b Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 25 Sep 2024 13:03:45 +0200 Subject: [PATCH 29/38] Codegen bugfix --- dace/codegen/control_flow.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index 08ffa4d6ee..7701a19ec2 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -236,14 +236,18 @@ def first_block(self) -> ReturnBlock: @dataclass -class GeneralBlock(ControlFlow): - """ - General (or unrecognized) control flow block with gotos between blocks. - """ +class RegionBlock(ControlFlow): # The control flow region that this block corresponds to (may be the SDFG in the absence of hierarchical regions). region: Optional[ControlFlowRegion] + +@dataclass +class GeneralBlock(RegionBlock): + """ + General (or unrecognized) control flow block with gotos between blocks. + """ + # List of children control flow blocks elements: List[ControlFlow] @@ -270,7 +274,7 @@ def as_cpp(self, codegen, symbols) -> str: for i, elem in enumerate(self.elements): expr += elem.as_cpp(codegen, symbols) # In a general block, emit transitions and assignments after each individual block or region. - if isinstance(elem, BasicCFBlock) or (isinstance(elem, GeneralBlock) and elem.region): + if isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region): cfg = elem.state.parent_graph if isinstance(elem, BasicCFBlock) else elem.region.parent_graph sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg out_edges = cfg.out_edges(elem.state) if isinstance(elem, BasicCFBlock) else cfg.out_edges(elem.region) @@ -514,10 +518,9 @@ def children(self) -> List[ControlFlow]: @dataclass -class GeneralLoopScope(ControlFlow): +class GeneralLoopScope(RegionBlock): """ General loop block based on a loop control flow region. """ - loop: LoopRegion body: ControlFlow def as_cpp(self, codegen, symbols) -> str: @@ -565,6 +568,10 @@ def as_cpp(self, codegen, symbols) -> str: return expr + @property + def loop(self) -> LoopRegion: + return self.region + @property def first_block(self) -> ControlFlowBlock: return self.loop.start_block @@ -602,10 +609,9 @@ def children(self) -> List[ControlFlow]: @dataclass -class GeneralConditionalScope(ControlFlow): +class GeneralConditionalScope(RegionBlock): """ General conditional block based on a conditional control flow region. """ - conditional: ConditionalBlock branch_bodies: List[Tuple[Optional[CodeBlock], ControlFlow]] def as_cpp(self, codegen, symbols) -> str: @@ -629,6 +635,10 @@ def as_cpp(self, codegen, symbols) -> str: expr += '}\n' return expr + @property + def conditional(self) -> ConditionalBlock: + return self.region + @property def first_block(self) -> ControlFlowBlock: return self.conditional From eb8719abed75ffa8854ed906089341e025cd7d17 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 25 Sep 2024 13:27:29 +0200 Subject: [PATCH 30/38] Remove unnecessary BranchRegion type --- dace/frontend/python/newast.py | 7 ++++--- dace/sdfg/state.py | 14 ++++---------- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index b0c9f11c62..e5528a1bd2 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -30,7 +30,7 @@ from dace.memlet import Memlet from dace.properties import LambdaProperty, CodeBlock from dace.sdfg import SDFG, SDFGState -from dace.sdfg.state import (BranchRegion, BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, FunctionCallRegion, +from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, FunctionCallRegion, LoopRegion, ControlFlowRegion, NamedRegion) from dace.sdfg.replace import replace_datadesc_names from dace.symbolic import pystr_to_symbolic, inequal_symbols @@ -2564,7 +2564,7 @@ def visit_If(self, node: ast.If): self.cfg_target.add_node(cond_block) self._on_block_added(cond_block) - if_body = BranchRegion(cond_block.label + '_body', sdfg=self.sdfg) + if_body = ControlFlowRegion(cond_block.label + '_body', sdfg=self.sdfg) cond_block.branches.append((CodeBlock(cond), if_body)) if_body.parent_graph = self.cfg_target @@ -2573,7 +2573,8 @@ def visit_If(self, node: ast.If): # Process 'else'/'elif' statements if len(node.orelse) > 0: - else_body = BranchRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', sdfg=self.sdfg) + else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', + sdfg=self.sdfg) #cond_block.branches.append((CodeBlock(cond_else), else_body)) cond_block.branches.append((None, else_body)) else_body.parent_graph = self.cfg_target diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 3d95f635a7..3747c83a4d 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3212,16 +3212,10 @@ def has_return(self) -> bool: return False -class BranchRegion(ControlFlowRegion): - - def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None): - super().__init__(label, sdfg) - - @make_properties class ConditionalBlock(ControlFlowBlock, ControlGraphView): - _branches: List[Tuple[Optional[CodeBlock], BranchRegion]] + _branches: List[Tuple[Optional[CodeBlock], ControlFlowRegion]] def __init__(self, label: str): super().__init__(label) @@ -3234,7 +3228,7 @@ def __repr__(self) -> str: return f'ConditionalBlock ({self.label})' @property - def branches(self) -> List[Tuple[Optional[CodeBlock], BranchRegion]]: + def branches(self) -> List[Tuple[Optional[CodeBlock], ControlFlowRegion]]: return self._branches def nodes(self) -> List['ControlFlowBlock']: @@ -3292,9 +3286,9 @@ def from_json(cls, json_obj, context=None): for condition, region in json_obj['branches']: if condition is not None: cond_region._branches.append((CodeBlock.from_json(condition), - BranchRegion.from_json(region, context))) + ControlFlowRegion.from_json(region, context))) else: - cond_region._branches.append((None, BranchRegion.from_json(region, context))) + cond_region._branches.append((None, ControlFlowRegion.from_json(region, context))) return cond_region def inline(self) -> Tuple[bool, Any]: From 3e21bceee187c41643b68fbf2f6bea61fd97640f Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 25 Sep 2024 14:19:14 +0200 Subject: [PATCH 31/38] Fix conditional inlining --- dace/frontend/python/astutils.py | 42 ++++++++++++++++++++++++++++++++ dace/sdfg/state.py | 27 +++++++++++++++++--- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py index c9a400e5f1..425e94cd9f 100644 --- a/dace/frontend/python/astutils.py +++ b/dace/frontend/python/astutils.py @@ -384,6 +384,48 @@ def negate_expr(node): return ast.fix_missing_locations(newexpr) +def and_expr(node_a, node_b): + """ Generates the logical AND of two AST expressions. + """ + if type(node_a) is not type(node_b): + raise ValueError('Node types do not match') + + # Support for SymPy expressions + if isinstance(node_a, sympy.Basic): + return sympy.And(node_a, node_b) + # Support for numerical constants + if isinstance(node_a, (numbers.Number, numpy.bool_)): + return str(node_a and node_b) + # Support for strings (most likely dace.Data.Scalar names) + if isinstance(node_a, str): + return f'({node_a}) and ({node_b})' + + from dace.properties import CodeBlock # Avoid import loop + if isinstance(node_a, CodeBlock): + node_a = node_a.code + node_b = node_b.code + + if hasattr(node_a, "__len__"): + if len(node_a) > 1: + raise ValueError("and_expr only expects single expressions, got: {}".format(node_a)) + if len(node_b) > 1: + raise ValueError("and_expr only expects single expressions, got: {}".format(node_b)) + expr_a = node_a[0] + expr_b = node_b[0] + else: + expr_a = node_a + expr_b = node_b + + if isinstance(expr_a, ast.Expr): + expr_a = expr_a.value + if isinstance(expr_b, ast.Expr): + expr_b = expr_b.value + + newexpr = ast.Expr(value=ast.BinOp(left=copy_tree(expr_a), op=ast.And, right=copy_tree(expr_b))) + newexpr = ast.copy_location(newexpr, expr_a) + return ast.fix_missing_locations(newexpr) + + def copy_tree(node: ast.AST) -> ast.AST: """ Copies an entire AST without copying the non-AST parts (e.g., constant values). diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 3747c83a4d..f6dae42843 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -11,7 +11,10 @@ from typing import (TYPE_CHECKING, Any, AnyStr, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, overload) +import sympy + import dace +from dace.frontend.python import astutils import dace.serialize from dace import data as dt from dace import dtypes @@ -3315,13 +3318,31 @@ def inline(self) -> Tuple[bool, Any]: parent.remove_edge(a_edge) from dace.sdfg.sdfg import InterstateEdge + else_branch = None + full_cond_expression: Optional[List[ast.AST]] = None for condition, region in self._branches: - if region is not None: + if condition is None: + else_branch = region + else: + if full_cond_expression is None: + full_cond_expression = condition.code[0] + else: + full_cond_expression = astutils.and_expr(full_cond_expression, condition.code[0]) parent.add_node(region) parent.add_edge(guard_state, region, InterstateEdge(condition=condition)) parent.add_edge(region, end_state, InterstateEdge()) - else: - parent.add_edge(guard_state, end_state, InterstateEdge(condition=condition)) + if full_cond_expression is not None: + negative_full_cond = astutils.negate_expr(full_cond_expression) + negative_cond = CodeBlock([negative_full_cond]) + else: + negative_cond = CodeBlock('1') + + if else_branch is not None: + parent.add_node(else_branch) + parent.add_edge(guard_state, else_branch, InterstateEdge(condition=negative_cond)) + parent.add_edge(region, end_state, InterstateEdge()) + else: + parent.add_edge(guard_state, end_state, InterstateEdge(condition=negative_cond)) parent.remove_node(self) From 426fceca29a7f62c44b8b122acbd4e0398c87421 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 25 Sep 2024 15:09:40 +0200 Subject: [PATCH 32/38] Fixes in inlining, again --- dace/frontend/python/parser.py | 3 +-- dace/sdfg/state.py | 7 ++++--- dace/sdfg/utils.py | 10 +++++----- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 1c3510c51f..b0ef56907f 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -495,9 +495,8 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF if not self.use_experimental_cfg_blocks: for nsdfg in sdfg.all_sdfgs_recursive(): - sdutils.inline_conditional_regions(nsdfg) + sdutils.inline_conditional_blocks(nsdfg) sdutils.inline_control_flow_regions(nsdfg) - sdutils.inline_loop_blocks(nsdfg) sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks # Apply simplification pass automatically diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index f6dae42843..26ed099ccd 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3362,10 +3362,11 @@ def __init__(self, label: str, sdfg: Optional['SDFG']=None, debuginfo: Optional[ self.debuginfo = debuginfo @make_properties -class FunctionCallRegion(ControlFlowRegion): +class FunctionCallRegion(NamedRegion): arguments = DictProperty(str, str) - def __init__(self, label: str, arguments: Dict[str, str] = {}, sdfg: 'SDFG' = None): - super().__init__(label, sdfg) + def __init__(self, label: str, arguments: Dict[str, str] = {}, sdfg: 'SDFG' = None, + debuginfo: Optional[dtypes.DebugInfo]=None): + super().__init__(label, sdfg, debuginfo) self.arguments = arguments diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 87039fa27a..5b9ce1a431 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1262,10 +1262,10 @@ def inline_loop_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = No def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: - blocks = [n for n, _ in sdfg.all_nodes_recursive() if type(n) is ControlFlowRegion] + blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ControlFlowRegion)] count = 0 - for _block in optional_progressbar(reversed(blocks), title='Inlining control flow blocks', + for _block in optional_progressbar(reversed(blocks), title='Inlining control flow regions', n=len(blocks), progress=progress): block: ControlFlowRegion = _block if block.inline()[0]: @@ -1273,13 +1273,13 @@ def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress: return count -def inline_conditional_regions(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: +def inline_conditional_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)] count = 0 - for _block in optional_progressbar(reversed(blocks), title='Inlining conditional regions', + for _block in optional_progressbar(reversed(blocks), title='Inlining conditional blocks', n=len(blocks), progress=progress): - block: ControlFlowRegion = _block + block: ConditionalBlock = _block if block.inline()[0]: count += 1 From fa91519dee2a49bf35922e7f6bf95d56855b1773 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 26 Sep 2024 08:59:29 +0200 Subject: [PATCH 33/38] Serialization fix --- dace/sdfg/state.py | 22 ++++++++++++------- .../transformation/interstate/state_fusion.py | 6 ++--- tests/python_frontend/loops_test.py | 2 +- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 26ed099ccd..6c1b1168e2 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3220,8 +3220,8 @@ class ConditionalBlock(ControlFlowBlock, ControlGraphView): _branches: List[Tuple[Optional[CodeBlock], ControlFlowRegion]] - def __init__(self, label: str): - super().__init__(label) + def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, parent: Optional['ControlFlowRegion'] = None): + super().__init__(label, sdfg, parent) self._branches = [] def __str__(self): @@ -3284,15 +3284,21 @@ def to_json(self, parent=None): @classmethod def from_json(cls, json_obj, context=None): - cond_region = ConditionalBlock(json_obj['label']) - cond_region.is_collapsed = json_obj['collapsed'] + context = context or {'sdfg': None, 'parent_graph': None} + _type = json_obj['type'] + if _type != cls.__name__: + raise TypeError('Class type mismatch') + + ret = cls(label=json_obj['label'], sdfg=context['sdfg']) + + dace.serialize.set_properties_from_json(ret, json_obj) + for condition, region in json_obj['branches']: if condition is not None: - cond_region._branches.append((CodeBlock.from_json(condition), - ControlFlowRegion.from_json(region, context))) + ret._branches.append((CodeBlock.from_json(condition), ControlFlowRegion.from_json(region, context))) else: - cond_region._branches.append((None, ControlFlowRegion.from_json(region, context))) - return cond_region + ret._branches.append((None, ControlFlowRegion.from_json(region, context))) + return ret def inline(self) -> Tuple[bool, Any]: """ diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index 5362695b49..3abbe085f5 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -461,8 +461,6 @@ def apply(self, _, sdfg): graph = first_state.parent_graph - start_block = graph.start_block - # Remove interstate edge(s) edges = graph.edges_between(first_state, second_state) for edge in edges: @@ -475,7 +473,7 @@ def apply(self, _, sdfg): if first_state.is_empty(): sdutil.change_edge_dest(graph, first_state, second_state) graph.remove_node(first_state) - if start_block == first_state: + if graph.start_block == first_state: graph.start_block = graph.node_id(second_state) return @@ -484,7 +482,7 @@ def apply(self, _, sdfg): sdutil.change_edge_src(graph, second_state, first_state) sdutil.change_edge_dest(graph, second_state, first_state) graph.remove_node(second_state) - if start_block == second_state: + if graph.start_block == second_state: graph.start_block = graph.node_id(first_state) return diff --git a/tests/python_frontend/loops_test.py b/tests/python_frontend/loops_test.py index d4a54a0456..e0c869f20c 100644 --- a/tests/python_frontend/loops_test.py +++ b/tests/python_frontend/loops_test.py @@ -416,7 +416,7 @@ def test_nested_map_with_symbol(): reason='Control flow detection issues through extraneous states, needs control flow detection fix') def test_for_else(): - @dace.program() + @dace.program def for_else(A: dace.float64[20]): for i in range(1, 20): if A[i] >= 10: From a8dc0f78df900471c27fe18abca2b77ed5cfa443 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 26 Sep 2024 09:35:44 +0200 Subject: [PATCH 34/38] Revert unnecessary test change --- tests/sdfg/work_depth_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index c360a59529..e677cca752 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -192,7 +192,7 @@ def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.f 'nested_if_else': (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), 'max_of_positive_symbols': (max_of_positive_symbol, (3 * N**2, 3 * N)), 'multiple_array_sizes': (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_3') * N, sp.Symbol('num_execs_0_3'))), + 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_2') * N, sp.Symbol('num_execs_0_2'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))), 'break_for_loop': (break_for_loop, (N**2, N)), From 3632cc5f95618385ed977e5562fa323622dc8fca Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 26 Sep 2024 09:42:48 +0200 Subject: [PATCH 35/38] Fix codegen not detecting existence of experimental blocks --- dace/codegen/targets/framecode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index da25816f9b..488c1c7fbd 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -483,7 +483,7 @@ def dispatch_state(state: SDFGState) -> str: states_generated.add(state) # For sanity check return stream.getvalue() - if sdfg.root_sdfg.using_experimental_blocks: + if sdfg.root_sdfg.recheck_using_experimental_blocks(): # Use control flow blocks embedded in the SDFG to generate control flow. cft = cflow.structured_control_flow_tree_with_regions(sdfg, dispatch_state) elif config.Config.get_bool('optimizer', 'detect_control_flow'): From 40f3aea1cc1ad9ff89622fbe0a43a72d44e2948b Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 26 Sep 2024 11:55:17 +0200 Subject: [PATCH 36/38] Fix shared transients --- dace/sdfg/analysis/cfg.py | 9 ++++++++- dace/sdfg/sdfg.py | 13 ++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/cfg.py b/dace/sdfg/analysis/cfg.py index 1d5b1e50eb..c96ef5aff0 100644 --- a/dace/sdfg/analysis/cfg.py +++ b/dace/sdfg/analysis/cfg.py @@ -6,7 +6,7 @@ import sympy as sp from typing import Dict, Iterator, List, Optional, Set -from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion def acyclic_dominance_frontier(cfg: ControlFlowRegion, idom=None) -> Dict[ControlFlowBlock, Set[ControlFlowBlock]]: @@ -374,6 +374,13 @@ def blockorder_topological_sort(cfg: ControlFlowRegion, yield block if recursive: yield from blockorder_topological_sort(block, recursive, ignore_nonstate_blocks) + elif isinstance(block, ConditionalBlock): + if not ignore_nonstate_blocks: + yield block + for _, branch in block.branches: + if not ignore_nonstate_blocks: + yield branch + yield from blockorder_topological_sort(branch, recursive, ignore_nonstate_blocks) elif isinstance(block, SDFGState): yield block else: diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 84d7189ebd..017532b135 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -23,7 +23,7 @@ from dace.config import Config from dace.frontend.python import astutils from dace.sdfg import nodes as nd -from dace.sdfg.state import ControlFlowBlock, SDFGState, ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, SDFGState, ControlFlowRegion from dace.distr_types import ProcessGrid, SubArray, RedistrArray from dace.dtypes import validate_name from dace.properties import (DebugInfoProperty, EnumProperty, ListProperty, make_properties, Property, CodeProperty, @@ -1488,6 +1488,17 @@ def shared_transients(self, check_toplevel: bool = True, include_nested_data: bo seen[sym] = interstate_edge shared.append(sym) + # The same goes for the conditions of conditional blocks. + for block in self.all_control_flow_blocks(): + if isinstance(block, ConditionalBlock): + for cond, _ in block.branches: + if cond is not None: + cond_symbols = set(map(str, dace.symbolic.symbols_in_ast(cond.code[0]))) + for sym in cond_symbols: + if sym in self.arrays and self.arrays[sym].transient: + seen[sym] = block + shared.append(sym) + # If transient is accessed in more than one state, it is shared for state in self.states(): for node in state.data_nodes(): From e1fc2354d34f0457916e043be11fdf1f29801efb Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 27 Sep 2024 08:57:48 +0200 Subject: [PATCH 37/38] Address review comments --- dace/frontend/python/newast.py | 10 +++++----- dace/sdfg/state.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index f8c1c59b1f..0d40e13282 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2514,10 +2514,10 @@ def _connect_break_blocks(self, loop_region: LoopRegion): for node, parent in loop_region.all_nodes_recursive(lambda n, _: not isinstance(n, (LoopRegion, SDFGState))): if isinstance(node, BreakBlock): for in_edge in parent.in_edges(node): - in_edge.data.assignments["did_break_" + loop_region.label] = "1" + in_edge.data.assignments['__dace_did_break_' + loop_region.label] = '1' def _generate_orelse(self, loop_region: LoopRegion, postloop_block: ControlFlowBlock): - did_break_symbol = 'did_break_' + loop_region.label + did_break_symbol = '__dace_did_break_' + loop_region.label self.sdfg.add_symbol(did_break_symbol, dace.int32) for iedge in self.cfg_target.in_edges(loop_region): iedge.data.assignments[did_break_symbol] = '0' @@ -2527,14 +2527,14 @@ def _generate_orelse(self, loop_region: LoopRegion, postloop_block: ControlFlowB intermediate = self.cfg_target.add_state(f'{loop_region.label}_normal_exit') self.cfg_target.add_edge(loop_region, intermediate, - dace.InterstateEdge(condition=f"(not {did_break_symbol} == 1)")) + dace.InterstateEdge(condition=f'(not {did_break_symbol} == 1)')) oedge = oedges[0] self.cfg_target.add_edge(intermediate, oedge.dst, copy.deepcopy(oedge.data)) self.cfg_target.remove_edge(oedge) - self.cfg_target.add_edge(loop_region, postloop_block, dace.InterstateEdge(condition=f"{did_break_symbol} == 1")) + self.cfg_target.add_edge(loop_region, postloop_block, dace.InterstateEdge(condition=f'{did_break_symbol} == 1')) def _has_loop_ancestor(self, node: ControlFlowBlock) -> bool: - while node is not None and node != self.sdfg: + while node is not None and node is not self.sdfg: if isinstance(node, LoopRegion): return True node = node.parent_graph diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 6c1b1168e2..8d443e6beb 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2607,6 +2607,7 @@ def inline(self) -> Tuple[bool, Any]: for node in to_connect: parent.add_edge(node, end_state, dace.InterstateEdge()) else: + # TODO: Move this to dead state elimination. dead_blocks = [succ for succ in parent.successors(self) if parent.in_degree(succ) == 1] while dead_blocks: layer = list(dead_blocks) From f09fbe158ba5596055971e9e71481f266f44ca51 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 27 Sep 2024 09:54:45 +0200 Subject: [PATCH 38/38] Fix opt_einsum package upgrade --- dace/frontend/common/einsum.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index e2cc2be88b..407e9eb91c 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -3,7 +3,9 @@ from functools import reduce from itertools import chain from string import ascii_letters -from typing import Dict, Optional +from typing import Dict, List, Optional + +import numpy as np import dace from dace import dtypes, subsets, symbolic @@ -180,6 +182,19 @@ def create_einsum_sdfg(pv: 'dace.frontend.python.newast.ProgramVisitor', beta=beta)[0] +def _build_einsum_views(tensors: str, dimension_dict: dict) -> List[np.ndarray]: + """ + Function taken and adjusted from opt_einsum package version 3.3.0 following unexpected removal in vesion 3.4.0. + Reference: https://github.com/dgasmith/opt_einsum/blob/v3.3.0/opt_einsum/helpers.py#L18 + """ + views = [] + terms = tensors.split('->')[0].split(',') + for term in terms: + dims = [dimension_dict[x] for x in term] + views.append(np.random.rand(*dims)) + return views + + def _create_einsum_internal(sdfg: SDFG, state: SDFGState, einsum_string: str, @@ -231,7 +246,7 @@ def _create_einsum_internal(sdfg: SDFG, # Create optimal contraction path # noinspection PyTypeChecker - _, path_info = oe.contract_path(einsum_string, *oe.helpers.build_views(einsum_string, chardict)) + _, path_info = oe.contract_path(einsum_string, *_build_einsum_views(einsum_string, chardict)) input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays} result_node = None