Skip to content

Commit

Permalink
Fix symbolic parsing for ternary operators (#1346)
Browse files Browse the repository at this point in the history
Co-authored-by: acalotoiu <61420859+acalotoiu@users.noreply.github.com>
  • Loading branch information
tbennun and acalotoiu authored Aug 4, 2023
1 parent 425fed6 commit 20240a8
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 1 deletion.
5 changes: 5 additions & 0 deletions dace/runtime/include/dace/pyinterop.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,10 @@ template <typename T>
static DACE_HDFI T Abs(T val) {
return abs(val);
}
template <typename T, typename U>
DACE_CONSTEXPR DACE_HDFI typename std::common_type<T, U>::type IfExpr(bool condition, const T& iftrue, const U& iffalse)
{
return condition ? iftrue : iffalse;
}

#endif // __DACE_INTEROP_H
21 changes: 20 additions & 1 deletion dace/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,21 @@ def eval(cls, x, y):
def _eval_is_boolean(self):
return True

class IfExpr(sympy.Function):

@classmethod
def eval(cls, x, y, z):
"""
Evaluates a ternary operator.
:param x: Predicate.
:param y: If true return this.
:param z: If false return this.
:return: Return value (literal or symbolic).
"""
if x.is_Boolean:
return (y if x else z)


class BitwiseAnd(sympy.Function):
pass
Expand Down Expand Up @@ -968,6 +983,9 @@ def visit_Constant(self, node):
def visit_NameConstant(self, node):
return self.visit_Constant(node)

def visit_IfExp(self, node):
new_node = ast.Call(func=ast.Name(id='IfExpr', ctx=ast.Load), args=[node.test, node.body, node.orelse], keywords=[])
return ast.copy_location(new_node, node)

class BitwiseOpConverter(ast.NodeTransformer):
"""
Expand Down Expand Up @@ -1050,6 +1068,7 @@ def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic:
'RightShift': RightShift,
'int_floor': int_floor,
'int_ceil': int_ceil,
'IfExpr': IfExpr,
'Mod': sympy.Mod,
}
# _clash1 enables all one-letter variables like N as symbols
Expand All @@ -1059,7 +1078,7 @@ def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic:
if isinstance(expr, str):
# Sympy processes "not/and/or" as direct evaluation. Replace with
# And/Or(x, y), Not(x)
if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b', expr):
if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b|\bif\b', expr):
expr = unparse(SympyBooleanConverter().visit(ast.parse(expr).body[0]))

# NOTE: If the expression contains bitwise operations, replace them with user-functions.
Expand Down
28 changes: 28 additions & 0 deletions tests/passes/scalar_to_symbol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,32 @@ def prog(inp: dace.int32[4, 2], out: dace.float64[5, 5]):
sdfg.compile()


@pytest.mark.parametrize('compile_time_evaluatable', (False, True))
def test_ternary_expression(compile_time_evaluatable):
sdfg = dace.SDFG('tester')
sdfg.add_symbol('N', dace.int32)
sdfg.add_symbol('M', dace.int32)
sdfg.add_scalar('a', dace.int32, transient=True)
state = sdfg.add_state()

if compile_time_evaluatable:
expr = '1 if N > N else 2'
else:
expr = '1 if N > M else 2'

# Test that symbolic conversion works
symexpr = dace.symbolic.pystr_to_symbolic(expr)
if compile_time_evaluatable:
assert symexpr == 2

t = state.add_tasklet('doit', {}, {'out'}, f'out = {expr}')
state.add_edge(t, 'out', state.add_access('a'), None, dace.Memlet('a[0]'))

promoted = scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {})
assert promoted == {'a'}
sdfg.compile()


if __name__ == '__main__':
test_find_promotable()
test_promote_simple()
Expand All @@ -687,3 +713,5 @@ def prog(inp: dace.int32[4, 2], out: dace.float64[5, 5]):
test_multiple_boolop()
test_multidim_cpp()
test_dynamic_mapind()
test_ternary_expression(False)
test_ternary_expression(True)

0 comments on commit 20240a8

Please sign in to comment.