Skip to content

Commit

Permalink
Remove floating-point promotion from sum, prod, and trace
Browse files Browse the repository at this point in the history
Fixes #152
  • Loading branch information
asmeurer committed Jul 29, 2024
1 parent d57c671 commit 0734064
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 50 deletions.
42 changes: 3 additions & 39 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,42 +389,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
raise ValueError("nonzero() does not support zero-dimensional arrays")
return xp.nonzero(x, **kwargs)

# sum() and prod() should always upcast when dtype=None
def sum(
x: ndarray,
/,
xp,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype: Optional[Dtype] = None,
keepdims: bool = False,
**kwargs,
) -> ndarray:
# `xp.sum` already upcasts integers, but not floats or complexes
if dtype is None:
if x.dtype == xp.float32:
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)

def prod(
x: ndarray,
/,
xp,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype: Optional[Dtype] = None,
keepdims: bool = False,
**kwargs,
) -> ndarray:
if dtype is None:
if x.dtype == xp.float32:
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)

# ceil, floor, and trunc return integers for integer inputs

def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
Expand Down Expand Up @@ -525,6 +489,6 @@ def isdtype(
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape',
'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul',
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
5 changes: 0 additions & 5 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)

def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
if dtype is None:
if x.dtype == xp.float32:
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))

__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
Expand Down
2 changes: 0 additions & 2 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@
argsort = get_xp(cp)(_aliases.argsort)
sort = get_xp(cp)(_aliases.sort)
nonzero = get_xp(cp)(_aliases.nonzero)
sum = get_xp(cp)(_aliases.sum)
prod = get_xp(cp)(_aliases.prod)
ceil = get_xp(cp)(_aliases.ceil)
floor = get_xp(cp)(_aliases.floor)
trunc = get_xp(cp)(_aliases.trunc)
Expand Down
2 changes: 0 additions & 2 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ def _dask_arange(
vecdot = get_xp(da)(_aliases.vecdot)

nonzero = get_xp(da)(_aliases.nonzero)
sum = get_xp(np)(_aliases.sum)
prod = get_xp(np)(_aliases.prod)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
Expand Down
2 changes: 0 additions & 2 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@
argsort = get_xp(np)(_aliases.argsort)
sort = get_xp(np)(_aliases.sort)
nonzero = get_xp(np)(_aliases.nonzero)
sum = get_xp(np)(_aliases.sum)
prod = get_xp(np)(_aliases.prod)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
Expand Down

0 comments on commit 0734064

Please sign in to comment.