Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conditional Blocks #1666

Merged
merged 44 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
9e01539
Create ConditionalRegion and implement its serialization
luca-patrignani Jun 30, 2024
4430c10
Inline conditional regions and add test for serialization
luca-patrignani Jul 1, 2024
c6f8cc8
Clean up test
luca-patrignani Jul 1, 2024
e49beac
Test if else
luca-patrignani Jul 1, 2024
e258d64
Fix typo
luca-patrignani Jul 1, 2024
281a620
Inline conditional regions
luca-patrignani Jul 2, 2024
c310a82
Parse complex tests region before adding the conditional region
luca-patrignani Jul 2, 2024
58451fa
Allow break and continue states inside cfgs nested in loops and remov…
luca-patrignani Jul 3, 2024
ebb69be
Call start_block before removing the current start block
luca-patrignani Jul 3, 2024
7f429ea
Avoid double additions of return blocks
luca-patrignani Jul 3, 2024
9fe2532
Do not invalidate start_block cache everytime a new node is added
luca-patrignani Jul 4, 2024
d91ef62
Fix typo
luca-patrignani Jul 4, 2024
c8b1bee
Inline all cfg for each nested sdfg and remove dead states after all …
luca-patrignani Jul 4, 2024
ad5e574
Set did_break_symbol to 1 in loop region inline method
luca-patrignani Jul 5, 2024
8b4a116
Merge branch 'master' into if-2
luca-patrignani Jul 6, 2024
25e52bb
Use codegen generation for loop regions
luca-patrignani Jul 9, 2024
e5aa40a
Raise exception when creating a break or continue block outside a loo…
luca-patrignani Jul 9, 2024
e4d9a85
Fix _used_symbols_internal in conditional region and remove dead bloc…
luca-patrignani Jul 10, 2024
e8deac1
Fix use of start_block in state fusion transformation
luca-patrignani Jul 10, 2024
9c506bd
Fix symbols internal for conditional region
luca-patrignani Jul 10, 2024
c614bcb
Connect arrays to views in all states for each cfg
luca-patrignani Jul 10, 2024
f4a0298
Merge branch 'master' into if-2
luca-patrignani Jul 10, 2024
7a4584e
Fix from_json
luca-patrignani Jul 10, 2024
feed8aa
Fix _used_symbols_internal in ControlFlowRegion for handling correctl…
luca-patrignani Jul 11, 2024
faad41d
Remove unused file
luca-patrignani Jul 11, 2024
39a183e
Revert ControlFlowRegion add_node change
luca-patrignani Jul 11, 2024
cbc51cd
Fix test
luca-patrignani Jul 11, 2024
575f351
Represent else branch as tuple (else condition, None)
luca-patrignani Jul 11, 2024
65e9486
Merge branch 'master' into if-2
phschaad Aug 20, 2024
b7aed06
Merge branch 'master' into users/phschaad/conditional_regions
phschaad Sep 24, 2024
08765ae
Merge remote-tracking branch 'origin/master' into users/phschaad/cond…
phschaad Sep 24, 2024
5cfcfbb
Lots of bugfixes
phschaad Sep 25, 2024
a72bc91
Bugfixes
phschaad Sep 25, 2024
6d41d63
Codegen bugfix
phschaad Sep 25, 2024
eb8719a
Remove unnecessary BranchRegion type
phschaad Sep 25, 2024
3e21bce
Fix conditional inlining
phschaad Sep 25, 2024
426fcec
Fixes in inlining, again
phschaad Sep 25, 2024
fa91519
Serialization fix
phschaad Sep 26, 2024
a8dc0f7
Revert unnecessary test change
phschaad Sep 26, 2024
3632cc5
Fix codegen not detecting existence of experimental blocks
phschaad Sep 26, 2024
40f3aea
Fix shared transients
phschaad Sep 26, 2024
bebb2c4
Merge branch 'master' into users/phschaad/conditional_regions
phschaad Sep 27, 2024
e1fc235
Address review comments
phschaad Sep 27, 2024
f09fbe1
Fix opt_einsum package upgrade
phschaad Sep 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 65 additions & 72 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
import sympy as sp
from dace import dtypes
from dace.sdfg.analysis import cfg as cfg_analysis
from dace.sdfg.state import (BreakBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion,
from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion,
ReturnBlock, SDFGState)
from dace.sdfg.sdfg import SDFG, InterstateEdge
from dace.sdfg.graph import Edge
Expand Down Expand Up @@ -236,14 +236,18 @@ def first_block(self) -> ReturnBlock:


@dataclass
class GeneralBlock(ControlFlow):
"""
General (or unrecognized) control flow block with gotos between blocks.
"""
class RegionBlock(ControlFlow):

# The control flow region that this block corresponds to (may be the SDFG in the absence of hierarchical regions).
region: Optional[ControlFlowRegion]


@dataclass
class GeneralBlock(RegionBlock):
"""
General (or unrecognized) control flow block with gotos between blocks.
"""

# List of children control flow blocks
elements: List[ControlFlow]

Expand All @@ -270,7 +274,7 @@ def as_cpp(self, codegen, symbols) -> str:
for i, elem in enumerate(self.elements):
expr += elem.as_cpp(codegen, symbols)
# In a general block, emit transitions and assignments after each individual block or region.
if isinstance(elem, BasicCFBlock) or (isinstance(elem, GeneralBlock) and elem.region):
if isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region):
cfg = elem.state.parent_graph if isinstance(elem, BasicCFBlock) else elem.region.parent_graph
sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg
out_edges = cfg.out_edges(elem.state) if isinstance(elem, BasicCFBlock) else cfg.out_edges(elem.region)
Expand Down Expand Up @@ -514,10 +518,9 @@ def children(self) -> List[ControlFlow]:


@dataclass
class GeneralLoopScope(ControlFlow):
class GeneralLoopScope(RegionBlock):
""" General loop block based on a loop control flow region. """

loop: LoopRegion
body: ControlFlow

def as_cpp(self, codegen, symbols) -> str:
Expand Down Expand Up @@ -565,6 +568,10 @@ def as_cpp(self, codegen, symbols) -> str:

return expr

@property
def loop(self) -> LoopRegion:
return self.region

@property
def first_block(self) -> ControlFlowBlock:
return self.loop.start_block
Expand Down Expand Up @@ -601,6 +608,46 @@ def children(self) -> List[ControlFlow]:
return list(self.cases.values())


@dataclass
class GeneralConditionalScope(RegionBlock):
""" General conditional block based on a conditional control flow region. """

branch_bodies: List[Tuple[Optional[CodeBlock], ControlFlow]]

def as_cpp(self, codegen, symbols) -> str:
sdfg = self.conditional.sdfg
expr = ''
for i in range(len(self.branch_bodies)):
branch = self.branch_bodies[i]
if branch[0] is not None:
cond = unparse_interstate_edge(branch[0].code, sdfg, codegen=codegen, symbols=symbols)
cond = cond.strip(';')
if i == 0:
expr += f'if ({cond}) {{\n'
else:
expr += f'}} else if ({cond}) {{\n'
else:
if i < len(self.branch_bodies) - 1 or i == 0:
raise RuntimeError('Missing branch condition for non-final conditional branch')
expr += '} else {\n'
expr += branch[1].as_cpp(codegen, symbols)
if i == len(self.branch_bodies) - 1:
expr += '}\n'
return expr

@property
def conditional(self) -> ConditionalBlock:
return self.region

@property
def first_block(self) -> ControlFlowBlock:
return self.conditional

@property
def children(self) -> List[ControlFlow]:
return [b for _, b in self.branch_bodies]


def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge],
leave_edge: Edge[InterstateEdge], back_edges: List[Edge[InterstateEdge]],
dispatch_state: Callable[[SDFGState],
Expand Down Expand Up @@ -973,7 +1020,6 @@ def _structured_control_flow_traversal_with_regions(cfg: ControlFlowRegion,
if branch_merges is None:
branch_merges = cfg_analysis.branch_merges(cfg)


if ptree is None:
ptree = cfg_analysis.block_parent_tree(cfg, with_loops=False)

Expand Down Expand Up @@ -1004,6 +1050,14 @@ def make_empty_block():
cfg_block = ContinueCFBlock(dispatch_state, parent_block, True, node)
elif isinstance(node, ReturnBlock):
cfg_block = ReturnCFBlock(dispatch_state, parent_block, True, node)
elif isinstance(node, ConditionalBlock):
cfg_block = GeneralConditionalScope(dispatch_state, parent_block, False, node, [])
for cond, branch in node.branches:
if branch is not None:
body = make_empty_block()
body.parent = cfg_block
_structured_control_flow_traversal_with_regions(branch, dispatch_state, body)
cfg_block.branch_bodies.append((cond, body))
elif isinstance(node, ControlFlowRegion):
if isinstance(node, LoopRegion):
body = make_empty_block()
Expand All @@ -1027,69 +1081,8 @@ def make_empty_block():
stack.append(oe[0].dst)
parent_block.elements.append(cfg_block)
continue

# Potential branch or loop
if node in branch_merges:
mergeblock = branch_merges[node]

# Add branching node and ignore outgoing edges
parent_block.elements.append(cfg_block)
parent_block.gotos_to_ignore.extend(oe) # TODO: why?
parent_block.assignments_to_ignore.extend(oe) # TODO: why?
cfg_block.last_block = True

# Parse all outgoing edges recursively first
cblocks: Dict[Edge[InterstateEdge], GeneralBlock] = {}
for branch in oe:
if branch.dst is mergeblock:
# If we hit the merge state (if without else), defer to end of branch traversal
continue
cblocks[branch] = make_empty_block()
_structured_control_flow_traversal_with_regions(cfg=cfg,
dispatch_state=dispatch_state,
parent_block=cblocks[branch],
start=branch.dst,
stop=mergeblock,
generate_children_of=node,
branch_merges=branch_merges,
ptree=ptree,
visited=visited)

# Classify branch type:
branch_block = None
# If there are 2 out edges, one negation of the other:
# * if/else in case both branches are not merge state
# * if without else in case one branch is merge state
if (len(oe) == 2 and oe[0].data.condition_sympy() == sp.Not(oe[1].data.condition_sympy())):
if oe[0].dst is mergeblock:
# If without else
branch_block = IfScope(dispatch_state, parent_block, False, node, oe[1].data.condition,
cblocks[oe[1]])
elif oe[1].dst is mergeblock:
branch_block = IfScope(dispatch_state, parent_block, False, node, oe[0].data.condition,
cblocks[oe[0]])
else:
branch_block = IfScope(dispatch_state, parent_block, False, node, oe[0].data.condition,
cblocks[oe[0]], cblocks[oe[1]])
else:
# If there are 2 or more edges (one is not the negation of the
# other):
switch = _cases_from_branches(oe, cblocks)
if switch:
# If all edges are of form "x == y" for a single x and
# integer y, it is a switch/case
branch_block = SwitchCaseScope(dispatch_state, parent_block, False, node, switch[0], switch[1])
else:
# Otherwise, create if/else if/.../else goto exit chain
branch_block = IfElseChain(dispatch_state, parent_block, False, node,
[(e.data.condition, cblocks[e] if e in cblocks else make_empty_block())
for e in oe])
# End of branch classification
parent_block.elements.append(branch_block)
if mergeblock != stop:
stack.append(mergeblock)

else: # No merge state: Unstructured control flow
else:
# Unstructured control flow.
parent_block.sequential = False
parent_block.elements.append(cfg_block)
stack.extend([e.dst for e in oe])
Expand Down
2 changes: 1 addition & 1 deletion dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def dispatch_state(state: SDFGState) -> str:
states_generated.add(state) # For sanity check
return stream.getvalue()

if sdfg.root_sdfg.using_experimental_blocks:
if sdfg.root_sdfg.recheck_using_experimental_blocks():
# Use control flow blocks embedded in the SDFG to generate control flow.
cft = cflow.structured_control_flow_tree_with_regions(sdfg, dispatch_state)
elif config.Config.get_bool('optimizer', 'detect_control_flow'):
Expand Down
19 changes: 17 additions & 2 deletions dace/frontend/common/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from functools import reduce
from itertools import chain
from string import ascii_letters
from typing import Dict, Optional
from typing import Dict, List, Optional

import numpy as np

import dace
from dace import dtypes, subsets, symbolic
Expand Down Expand Up @@ -180,6 +182,19 @@ def create_einsum_sdfg(pv: 'dace.frontend.python.newast.ProgramVisitor',
beta=beta)[0]


def _build_einsum_views(tensors: str, dimension_dict: dict) -> List[np.ndarray]:
"""
Function taken and adjusted from opt_einsum package version 3.3.0 following unexpected removal in vesion 3.4.0.
Reference: https://github.com/dgasmith/opt_einsum/blob/v3.3.0/opt_einsum/helpers.py#L18
"""
views = []
terms = tensors.split('->')[0].split(',')
for term in terms:
dims = [dimension_dict[x] for x in term]
views.append(np.random.rand(*dims))
return views


def _create_einsum_internal(sdfg: SDFG,
state: SDFGState,
einsum_string: str,
Expand Down Expand Up @@ -231,7 +246,7 @@ def _create_einsum_internal(sdfg: SDFG,

# Create optimal contraction path
# noinspection PyTypeChecker
_, path_info = oe.contract_path(einsum_string, *oe.helpers.build_views(einsum_string, chardict))
_, path_info = oe.contract_path(einsum_string, *_build_einsum_views(einsum_string, chardict))

input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}
result_node = None
Expand Down
42 changes: 42 additions & 0 deletions dace/frontend/python/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,48 @@ def negate_expr(node):
return ast.fix_missing_locations(newexpr)


def and_expr(node_a, node_b):
""" Generates the logical AND of two AST expressions.
"""
if type(node_a) is not type(node_b):
raise ValueError('Node types do not match')

# Support for SymPy expressions
if isinstance(node_a, sympy.Basic):
return sympy.And(node_a, node_b)
# Support for numerical constants
if isinstance(node_a, (numbers.Number, numpy.bool_)):
return str(node_a and node_b)
# Support for strings (most likely dace.Data.Scalar names)
if isinstance(node_a, str):
return f'({node_a}) and ({node_b})'

from dace.properties import CodeBlock # Avoid import loop
if isinstance(node_a, CodeBlock):
node_a = node_a.code
node_b = node_b.code

if hasattr(node_a, "__len__"):
if len(node_a) > 1:
raise ValueError("and_expr only expects single expressions, got: {}".format(node_a))
if len(node_b) > 1:
raise ValueError("and_expr only expects single expressions, got: {}".format(node_b))
expr_a = node_a[0]
expr_b = node_b[0]
else:
expr_a = node_a
expr_b = node_b

if isinstance(expr_a, ast.Expr):
expr_a = expr_a.value
if isinstance(expr_b, ast.Expr):
expr_b = expr_b.value

newexpr = ast.Expr(value=ast.BinOp(left=copy_tree(expr_a), op=ast.And, right=copy_tree(expr_b)))
newexpr = ast.copy_location(newexpr, expr_a)
return ast.fix_missing_locations(newexpr)


def copy_tree(node: ast.AST) -> ast.AST:
"""
Copies an entire AST without copying the non-AST parts (e.g., constant values).
Expand Down
6 changes: 5 additions & 1 deletion dace/frontend/python/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def program(f: F,
recompile: bool = True,
distributed_compilation: bool = False,
constant_functions=False,
use_experimental_cfg_blocks=False,
**kwargs) -> Callable[..., parser.DaceProgram]:
"""
Entry point to a data-centric program. For methods and ``classmethod``s, use
Expand All @@ -68,6 +69,8 @@ def program(f: F,
not depend on internal variables are constant.
This will hardcode their return values into the
resulting program.
:param use_experimental_cfg_blocks: If True, makes use of experimental CFG blocks susch as loop and conditional
regions.
:note: If arguments are defined with type hints, the program can be compiled
ahead-of-time with ``.compile()``.
"""
Expand All @@ -83,7 +86,8 @@ def program(f: F,
recreate_sdfg=recreate_sdfg,
regenerate_code=regenerate_code,
recompile=recompile,
distributed_compilation=distributed_compilation)
distributed_compilation=distributed_compilation,
use_experimental_cfg_blocks=use_experimental_cfg_blocks)


function = program
Expand Down
Loading
Loading