diff --git a/dace/runtime/include/dace/pyinterop.h b/dace/runtime/include/dace/pyinterop.h index e8f255af70..f93cbab770 100644 --- a/dace/runtime/include/dace/pyinterop.h +++ b/dace/runtime/include/dace/pyinterop.h @@ -52,5 +52,10 @@ template static DACE_HDFI T Abs(T val) { return abs(val); } +template +DACE_CONSTEXPR DACE_HDFI typename std::common_type::type IfExpr(bool condition, const T& iftrue, const U& iffalse) +{ + return condition ? iftrue : iffalse; +} #endif // __DACE_INTEROP_H diff --git a/dace/symbolic.py b/dace/symbolic.py index 01440d465e..0ab6e3f6ff 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -658,6 +658,21 @@ def eval(cls, x, y): def _eval_is_boolean(self): return True +class IfExpr(sympy.Function): + + @classmethod + def eval(cls, x, y, z): + """ + Evaluates a ternary operator. + + :param x: Predicate. + :param y: If true return this. + :param z: If false return this. + :return: Return value (literal or symbolic). + """ + if x.is_Boolean: + return (y if x else z) + class BitwiseAnd(sympy.Function): pass @@ -968,6 +983,9 @@ def visit_Constant(self, node): def visit_NameConstant(self, node): return self.visit_Constant(node) + def visit_IfExp(self, node): + new_node = ast.Call(func=ast.Name(id='IfExpr', ctx=ast.Load), args=[node.test, node.body, node.orelse], keywords=[]) + return ast.copy_location(new_node, node) class BitwiseOpConverter(ast.NodeTransformer): """ @@ -1050,6 +1068,7 @@ def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic: 'RightShift': RightShift, 'int_floor': int_floor, 'int_ceil': int_ceil, + 'IfExpr': IfExpr, 'Mod': sympy.Mod, } # _clash1 enables all one-letter variables like N as symbols @@ -1059,7 +1078,7 @@ def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic: if isinstance(expr, str): # Sympy processes "not/and/or" as direct evaluation. Replace with # And/Or(x, y), Not(x) - if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b', expr): + if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b|\bif\b', expr): expr = unparse(SympyBooleanConverter().visit(ast.parse(expr).body[0])) # NOTE: If the expression contains bitwise operations, replace them with user-functions. diff --git a/tests/passes/scalar_to_symbol_test.py b/tests/passes/scalar_to_symbol_test.py index 9ec23e3886..02cc57a204 100644 --- a/tests/passes/scalar_to_symbol_test.py +++ b/tests/passes/scalar_to_symbol_test.py @@ -666,6 +666,32 @@ def prog(inp: dace.int32[4, 2], out: dace.float64[5, 5]): sdfg.compile() +@pytest.mark.parametrize('compile_time_evaluatable', (False, True)) +def test_ternary_expression(compile_time_evaluatable): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('N', dace.int32) + sdfg.add_symbol('M', dace.int32) + sdfg.add_scalar('a', dace.int32, transient=True) + state = sdfg.add_state() + + if compile_time_evaluatable: + expr = '1 if N > N else 2' + else: + expr = '1 if N > M else 2' + + # Test that symbolic conversion works + symexpr = dace.symbolic.pystr_to_symbolic(expr) + if compile_time_evaluatable: + assert symexpr == 2 + + t = state.add_tasklet('doit', {}, {'out'}, f'out = {expr}') + state.add_edge(t, 'out', state.add_access('a'), None, dace.Memlet('a[0]')) + + promoted = scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {}) + assert promoted == {'a'} + sdfg.compile() + + if __name__ == '__main__': test_find_promotable() test_promote_simple() @@ -687,3 +713,5 @@ def prog(inp: dace.int32[4, 2], out: dace.float64[5, 5]): test_multiple_boolop() test_multidim_cpp() test_dynamic_mapind() + test_ternary_expression(False) + test_ternary_expression(True)