Skip to content

Commit

Permalink
Various stability improvements and convenience APIs (#1724)
Browse files Browse the repository at this point in the history
* Various minor code generation and runtime fixes
* Minor API improvements to CompiledSDFG, sdfg.view(), and subset
offsetting
* Minor memlet propagation fix
* Various simplify pass fixes that pertain to use of views, references,
and tasklets with side effects
* Symbolic support for shift and ternary expressions (fixes #1315)
* Pass permissiveness into transformations
  • Loading branch information
tbennun authored Nov 4, 2024
1 parent b27024b commit 64d7679
Show file tree
Hide file tree
Showing 28 changed files with 440 additions and 84 deletions.
2 changes: 1 addition & 1 deletion dace/cli/sdfgcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def main():
sdfg = SDFGOptimizer(sdfg).optimize()

# Compile SDFG
sdfg.compile(outpath)
sdfg.compile(outpath, return_program_handle=False)

# Copying header file to optional path
if outpath is not None:
Expand Down
6 changes: 5 additions & 1 deletion dace/cli/sdfv.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None, verbose: b
):
fd, filename = tempfile.mkstemp(suffix='.sdfg')
sdfg.save(filename)
os.system(f'code {filename}')
if platform.system() == 'Darwin':
# Special case for MacOS
os.system(f'open {filename}')
else:
os.system(f'code {filename}')
os.close(fd)
return

Expand Down
5 changes: 5 additions & 0 deletions dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,9 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
# Otherwise, None values are passed as null pointers below
elif isinstance(arg, ctypes._Pointer):
pass
elif isinstance(arg, str):
# Cast to bytes
arglist[i] = ctypes.c_char_p(arg.encode('utf-8'))
else:
raise TypeError(f'Passing an object (type {type(arg).__name__}) to an array in argument "{a}"')
elif is_array and not is_dtArray:
Expand Down Expand Up @@ -550,6 +553,8 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
pass
elif isinstance(arg, float) and atype.dtype.type == np.float64:
pass
elif isinstance(arg, bool) and atype.dtype.type == np.bool_:
pass
elif (isinstance(arg, str) or arg is None) and atype.dtype == dtypes.string:
if arg is None:
arglist[i] = ctypes.c_char_p(None)
Expand Down
2 changes: 1 addition & 1 deletion dace/codegen/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def configure_and_compile(program_folder, program_name=None, output_stream=None)
# Clean CMake directory and try once more
if Config.get_bool('debugprint'):
print('Cleaning CMake build folder and retrying...')
shutil.rmtree(build_folder)
shutil.rmtree(build_folder, ignore_errors=True)
os.makedirs(build_folder)
try:
_run_liveoutput(cmake_command, shell=True, cwd=build_folder, output_stream=output_stream)
Expand Down
8 changes: 7 additions & 1 deletion dace/codegen/cppunparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,11 @@ def _write_constant(self, value):
if result.find("b'") >= 0:
self.write(result)
else:
self.write(result.replace('\'', '\"'))
towrite = result
if result.startswith("'"):
towrite = result[1:-1].replace('"', '\\"')
towrite = f'"{towrite}"'
self.write(towrite)

def _Constant(self, t):
value = t.value
Expand Down Expand Up @@ -1187,6 +1191,8 @@ def py2cpp(code, expr_semicolon=True, defined_symbols=None):
return cppunparse(ast.parse(symbolic.symstr(code, cpp_mode=True)),
expr_semicolon,
defined_symbols=defined_symbols)
elif isinstance(code, int):
return str(code)
elif code.__class__.__name__ == 'function':
try:
code_str = inspect.getsource(code)
Expand Down
2 changes: 2 additions & 0 deletions dace/codegen/tools/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ def _Compare(t, symbols, inferred_symbols):
for o, e in zip(t.ops, t.comparators):
if o.__class__.__name__ not in cppunparse.CPPUnparser.cmpops:
continue
if isinstance(e, ast.Constant) and e.value is None:
continue
inf_type = _dispatch(e, symbols, inferred_symbols)
if isinstance(inf_type, dtypes.vector):
# Make sure all occuring vectors are of same size
Expand Down
2 changes: 2 additions & 0 deletions dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ def __init__(self, wrapped_type, typename=None):
wrapped_type = numpy.bool_
elif getattr(wrapped_type, '__name__', '') == 'bool_' and typename is None:
typename = 'bool'
elif wrapped_type is type(None):
wrapped_type = None

self.type = wrapped_type # Type in Python
self.ctype = _CTYPES[wrapped_type] # Type in C
Expand Down
4 changes: 2 additions & 2 deletions dace/memlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,9 +555,9 @@ def used_symbols(self, all_symbols: bool, edge=None) -> Set[str]:
from dace.sdfg import nodes
if isinstance(edge.dst, nodes.CodeNode) or isinstance(edge.src, nodes.CodeNode):
view_edge = True
elif edge.dst_conn == 'views' and isinstance(edge.dst, nodes.AccessNode):
elif edge.dst_conn and isinstance(edge.dst, nodes.AccessNode):
view_edge = True
elif edge.src_conn == 'views' and isinstance(edge.src, nodes.AccessNode):
elif edge.src_conn and isinstance(edge.src, nodes.AccessNode):
view_edge = True

if not view_edge:
Expand Down
2 changes: 1 addition & 1 deletion dace/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def initialize_properties(obj, *args, **kwargs):
for name, prop in own_properties.items():
# Only assign our own properties, so we don't overwrite what's been
# set by the base class
if hasattr(obj, name):
if hasattr(obj, '_' + name):
raise PropertyError("Property {} already assigned in {}".format(name, type(obj).__name__))
if not prop.indirected:
if prop.allow_none or prop.default is not None:
Expand Down
8 changes: 4 additions & 4 deletions dace/runtime/include/dace/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ namespace dace {

template <int CHUNKSIZE>
struct Consume {
template <template <typename, bool> typename StreamT, typename T, bool ALIGNED,
template <template <typename, bool> class StreamT, typename T, bool ALIGNED,
typename Functor>
static void consume(StreamT<T, ALIGNED>& stream, unsigned num_threads,
Functor&& contents) {
Expand All @@ -359,7 +359,7 @@ namespace dace {
for (auto& t : threads) t.join();
}

template <template <typename, bool> typename StreamT, typename T, bool ALIGNED,
template <template <typename, bool> class StreamT, typename T, bool ALIGNED,
typename CondFunctor, typename Functor>
static void consume_cond(StreamT<T, ALIGNED>& stream, unsigned num_threads,
CondFunctor&& quiescence, Functor&& contents) {
Expand All @@ -384,7 +384,7 @@ namespace dace {
// Specialization for consumption of 1 element
template<>
struct Consume<1> {
template <template <typename, bool> typename StreamT, typename T, bool ALIGNED,
template <template <typename, bool> class StreamT, typename T, bool ALIGNED,
typename Functor>
static void consume(StreamT<T, ALIGNED>& stream, unsigned num_threads,
Functor&& contents) {
Expand All @@ -404,7 +404,7 @@ namespace dace {
for (auto& t : threads) t.join();
}

template <template <typename, bool> typename StreamT, typename T, bool ALIGNED,
template <template <typename, bool> class StreamT, typename T, bool ALIGNED,
typename CondFunctor, typename Functor>
static void consume_cond(StreamT<T, ALIGNED>& stream, unsigned num_threads,
CondFunctor&& quiescence, Functor&& contents) {
Expand Down
2 changes: 1 addition & 1 deletion dace/sdfg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def all_simple_paths(self,
for path in map(nx.utils.pairwise, nx.all_simple_paths(self._nx, source_node, dest_node)):
yield [Edge(e[0], e[1], self._nx.edges[e]['data']) for e in path]
else:
return nx.all_simple_paths(self._nx, source_node, dest_node)
yield from nx.all_simple_paths(self._nx, source_node, dest_node)

def all_nodes_between(self, begin: NodeT, end: NodeT) -> Sequence[NodeT]:
"""Finds all nodes between begin and end. Returns None if there is any
Expand Down
6 changes: 4 additions & 2 deletions dace/sdfg/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, di
matches = rngelem.match(cst)
if matches is None or len(matches) != 1:
return False
if not matches[cst].is_constant():
if matches[cst].free_symbols:
return False

else: # Single element case
Expand All @@ -386,7 +386,7 @@ def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, di
matches = dexpr.match(cst)
if matches is None or len(matches) != 1:
return False
if not matches[cst].is_constant():
if matches[cst].free_symbols:
return False

return True
Expand Down Expand Up @@ -1296,6 +1296,8 @@ def align_memlet(state, e: gr.MultiConnectorEdge[Memlet], dst: bool) -> Memlet:
# Fix memlet fields
result.data = node.data
result.subset = e.data.other_subset
if result.subset is None:
result.subset = subsets.Range.from_array(state.sdfg.arrays[result.data])
result.other_subset = e.data.subset
result._is_data_src = not is_src
return result
Expand Down
9 changes: 6 additions & 3 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2266,14 +2266,16 @@ def is_loaded(self) -> bool:
dll = cs.ReloadableDLL(binary_filename, self.name)
return dll.is_loaded()

def compile(self, output_file=None, validate=True) -> 'CompiledSDFG':
def compile(self, output_file=None, validate=True,
return_program_handle=True) -> 'CompiledSDFG':
""" Compiles a runnable binary from this SDFG.
:param output_file: If not None, copies the output library file to
the specified path.
:param validate: If True, validates the SDFG prior to generating
code.
:return: A callable CompiledSDFG object.
:param return_program_handle: If False, does not load the generated library.
:return: A callable CompiledSDFG object, or None if ``return_program_handle=False``.
"""

# Importing these outside creates an import loop
Expand Down Expand Up @@ -2336,7 +2338,8 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG':
shutil.copyfile(shared_library, output_file)

# Get the function handle
return compiler.get_program_handle(shared_library, sdfg)
if return_program_handle:
return compiler.get_program_handle(shared_library, sdfg)

def argument_typecheck(self, args, kwargs, types_only=False):
""" Checks if arguments and keyword arguments match the SDFG
Expand Down
67 changes: 56 additions & 11 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,9 +629,13 @@ def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge):
e = curedge.edge
state.remove_edge(e)
if inwards:
neighbors = [] if not e.src_conn else [neighbor for neighbor in state.out_edges_by_connector(e.src, e.src_conn)]
neighbors = [] if not e.src_conn else [
neighbor for neighbor in state.out_edges_by_connector(e.src, e.src_conn)
]
else:
neighbors = [] if not e.dst_conn else [neighbor for neighbor in state.in_edges_by_connector(e.dst, e.dst_conn)]
neighbors = [] if not e.dst_conn else [
neighbor for neighbor in state.in_edges_by_connector(e.dst, e.dst_conn)
]
if len(neighbors) > 0: # There are still edges connected, leave as-is
break

Expand Down Expand Up @@ -796,6 +800,33 @@ def get_all_view_nodes(state: SDFGState, view: nd.AccessNode) -> List[nd.AccessN
return result


def get_all_view_edges(state: SDFGState, view: nd.AccessNode) -> List[gr.MultiConnectorEdge[mm.Memlet]]:
"""
Given a view access node, returns a list of viewed access nodes as edges
if existent, else None
"""
sdfg = state.parent
node = view
desc = sdfg.arrays[node.data]
result = []
while isinstance(desc, dt.View):
edge = get_view_edge(state, node)
if edge is None:
break
old_node = node
if edge.dst is view:
node = edge.src
else:
node = edge.dst
if node is old_node:
break
if not isinstance(node, nd.AccessNode):
break
desc = sdfg.arrays[node.data]
result.append(edge)
return result


def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdge[mm.Memlet]:
"""
Given a view access node, returns the
Expand All @@ -818,8 +849,18 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg
# If there is one edge (in/out) that leads (via memlet path) to an access
# node, and the other side (out/in) has a different number of edges.
if len(in_edges) == 1 and len(out_edges) != 1:
# If the edge is not leading to an access node, fail
mpath = state.memlet_path(in_edges[0])
if not isinstance(mpath[0].src, nd.AccessNode):
return None

return in_edges[0]
if len(out_edges) == 1 and len(in_edges) != 1:
# If the edge is not leading to an access node, fail
mpath = state.memlet_path(out_edges[0])
if not isinstance(mpath[-1].dst, nd.AccessNode):
return None

return out_edges[0]
if len(out_edges) == len(in_edges) and len(out_edges) != 1:
return None
Expand All @@ -843,7 +884,7 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg
return out_edge
if not src_is_data and not dst_is_data:
return None

# Check if there is a 'views' connector
if in_edge.dst_conn and in_edge.dst_conn == 'views':
return in_edge
Expand Down Expand Up @@ -1227,8 +1268,8 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) ->
progress = True
pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter)

if (u in skip_nodes or v in skip_nodes or not isinstance(v, SDFGState) or
not isinstance(u, SDFGState)):
if (u in skip_nodes or v in skip_nodes or not isinstance(v, SDFGState)
or not isinstance(u, SDFGState)):
continue
candidate = {StateFusion.first_state: u, StateFusion.second_state: v}
sf = StateFusion()
Expand All @@ -1252,8 +1293,7 @@ def inline_loop_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = No
blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, LoopRegion)]
count = 0

for _block in optional_progressbar(reversed(blocks), title='Inlining Loops',
n=len(blocks), progress=progress):
for _block in optional_progressbar(reversed(blocks), title='Inlining Loops', n=len(blocks), progress=progress):
block: LoopRegion = _block
if block.inline()[0]:
count += 1
Expand All @@ -1265,20 +1305,25 @@ def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress:
blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ControlFlowRegion)]
count = 0

for _block in optional_progressbar(reversed(blocks), title='Inlining control flow regions',
n=len(blocks), progress=progress):
for _block in optional_progressbar(reversed(blocks),
title='Inlining control flow regions',
n=len(blocks),
progress=progress):
block: ControlFlowRegion = _block
if block.inline()[0]:
count += 1

return count


def inline_conditional_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int:
blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)]
count = 0

for _block in optional_progressbar(reversed(blocks), title='Inlining conditional blocks',
n=len(blocks), progress=progress):
for _block in optional_progressbar(reversed(blocks),
title='Inlining conditional blocks',
n=len(blocks),
progress=progress):
block: ConditionalBlock = _block
if block.inline()[0]:
count += 1
Expand Down
4 changes: 3 additions & 1 deletion dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,14 +836,16 @@ def validate_state(state: 'dace.sdfg.SDFGState',
dst_expr = (e.data.dst_subset.num_elements() * sdfg.arrays[dst_node.data].veclen)
if symbolic.inequal_symbols(src_expr, dst_expr):
error = InvalidSDFGEdgeError('Dimensionality mismatch between src/dst subsets', sdfg, state_id, eid)
# NOTE: Make an exception for Views
# NOTE: Make an exception for Views and reference sets
from dace.sdfg import utils
if (isinstance(sdfg.arrays[src_node.data], dt.View) and utils.get_view_edge(state, src_node) is e):
warnings.warn(error.message)
continue
if (isinstance(sdfg.arrays[dst_node.data], dt.View) and utils.get_view_edge(state, dst_node) is e):
warnings.warn(error.message)
continue
if e.dst_conn == 'set':
continue
raise error

if Config.get_bool('experimental.check_race_conditions'):
Expand Down
Loading

0 comments on commit 64d7679

Please sign in to comment.