Skip to content

Commit

Permalink
[TKW] Add overflow flags to index calculations (#218)
Browse files Browse the repository at this point in the history
Index calculations should never overflow. Add corresponding flags to
`arith.addi/muli`.

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
  • Loading branch information
Hardcode84 authored Oct 15, 2024
1 parent bacfdcd commit 6d6f117
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 103 deletions.
43 changes: 27 additions & 16 deletions iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,52 +249,63 @@ def get_const_val(arg):

return None

def muli_fold(lhs, rhs):
overflow_flags = arith_d.IntegerOverflowFlags.nsw | arith_d.IntegerOverflowFlags.nuw

def muli(lhs, rhs):
if get_const_val(lhs) == 1:
return rhs

if get_const_val(rhs) == 1:
return lhs

return arith_d.muli(lhs, rhs)
return arith_d.muli(lhs, rhs, overflow_flags=overflow_flags)

def addi(lhs, rhs):
if get_const_val(lhs) == 0:
return rhs

if get_const_val(rhs) == 0:
return lhs

return arith_d.addi(lhs, rhs, overflow_flags=overflow_flags)

# `x + (a/b)` transformed into `(x*b + a) / b`
def _add(lhs, rhs):
is_rational_lhs = isinstance(lhs, _Rational)
is_rational_rhs = isinstance(rhs, _Rational)
if is_rational_lhs and not is_rational_rhs:
numerator = muli_fold(*_broadcast(lhs.denominator, rhs))
numerator = arith_d.addi(*_broadcast(numerator, lhs.numerator))
numerator = muli(*_broadcast(lhs.denominator, rhs))
numerator = addi(*_broadcast(numerator, lhs.numerator))
return _Rational(numerator, lhs.denominator)
elif not is_rational_lhs and is_rational_rhs:
numerator = muli_fold(*_broadcast(lhs, rhs.denominator))
numerator = arith_d.addi(*_broadcast(numerator, rhs.numerator))
numerator = muli(*_broadcast(lhs, rhs.denominator))
numerator = addi(*_broadcast(numerator, rhs.numerator))
return _Rational(numerator, rhs.denominator)
elif is_rational_lhs and is_rational_rhs:
lhs_numerator = muli_fold(*_broadcast(lhs.numerator, rhs.denominator))
rhs_numerator = muli_fold(*_broadcast(rhs.numerator, lhs.denominator))
numerator = arith_d.addi(*_broadcast(lhs_numerator, rhs_numerator))
denominator = muli_fold(*_broadcast(lhs.denominator, rhs.denominator))
lhs_numerator = muli(*_broadcast(lhs.numerator, rhs.denominator))
rhs_numerator = muli(*_broadcast(rhs.numerator, lhs.denominator))
numerator = addi(*_broadcast(lhs_numerator, rhs_numerator))
denominator = muli(*_broadcast(lhs.denominator, rhs.denominator))
return _Rational(numerator, denominator)
else:
return arith_d.addi(*_broadcast(lhs, rhs))
return addi(*_broadcast(lhs, rhs))

# `x * (a/b)` transformed into `(x * a) / b`
def _mul(lhs, rhs):
is_rational_lhs = isinstance(lhs, _Rational)
is_rational_rhs = isinstance(rhs, _Rational)
if is_rational_lhs and not is_rational_rhs:
numerator = muli_fold(*_broadcast(lhs.numerator, rhs))
numerator = muli(*_broadcast(lhs.numerator, rhs))
return _Rational(numerator, lhs.denominator)
elif not is_rational_lhs and is_rational_rhs:
numerator = muli_fold(*_broadcast(lhs, rhs.numerator))
numerator = muli(*_broadcast(lhs, rhs.numerator))
return _Rational(numerator, rhs.denominator)
elif is_rational_lhs and is_rational_rhs:
numerator = muli_fold(*_broadcast(lhs.numerator, rhs.numerator))
denominator = muli_fold(*_broadcast(lhs.denominator, rhs.denominator))
numerator = muli(*_broadcast(lhs.numerator, rhs.numerator))
denominator = muli(*_broadcast(lhs.denominator, rhs.denominator))
return _Rational(numerator, denominator)
else:
return muli_fold(*_broadcast(lhs, rhs))
return muli(*_broadcast(lhs, rhs))

def _floor(value):
if isinstance(value, _Rational):
Expand Down
Loading

0 comments on commit 6d6f117

Please sign in to comment.