diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index bb53e8e..ec43ae3 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -389,42 +389,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) -# sum() and prod() should always upcast when dtype=None -def sum( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, - **kwargs, -) -> ndarray: - # `xp.sum` already upcasts integers, but not floats or complexes - if dtype is None: - if x.dtype == xp.float32: - dtype = xp.float64 - elif x.dtype == xp.complex64: - dtype = xp.complex128 - return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs) - -def prod( - x: ndarray, - /, - xp, - *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, - keepdims: bool = False, - **kwargs, -) -> ndarray: - if dtype is None: - if x.dtype == xp.float32: - dtype = xp.float64 - elif x.dtype == xp.complex64: - dtype = xp.complex128 - return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs) - # ceil, floor, and trunc return integers for integer inputs def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: @@ -525,6 +489,6 @@ def isdtype( 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort', - 'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc', - 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] + 'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', + 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', + 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index dc2b69d..bfa1f1b 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray: - if dtype is None: - if x.dtype == xp.float32: - dtype = xp.float64 - elif x.dtype == xp.complex64: - dtype = xp.complex128 return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index d7e78fd..68ff378 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -53,8 +53,6 @@ argsort = get_xp(cp)(_aliases.argsort) sort = get_xp(cp)(_aliases.sort) nonzero = get_xp(cp)(_aliases.nonzero) -sum = get_xp(cp)(_aliases.sum) -prod = get_xp(cp)(_aliases.prod) ceil = get_xp(cp)(_aliases.ceil) floor = get_xp(cp)(_aliases.floor) trunc = get_xp(cp)(_aliases.trunc) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d26ec6a..5d89aa1 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -102,8 +102,6 @@ def _dask_arange( vecdot = get_xp(da)(_aliases.vecdot) nonzero = get_xp(da)(_aliases.nonzero) -sum = get_xp(np)(_aliases.sum) -prod = get_xp(np)(_aliases.prod) ceil = get_xp(np)(_aliases.ceil) floor = get_xp(np)(_aliases.floor) trunc = get_xp(np)(_aliases.trunc) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index e29b075..4fd6a68 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -53,8 +53,6 @@ argsort = get_xp(np)(_aliases.argsort) sort = get_xp(np)(_aliases.sort) nonzero = get_xp(np)(_aliases.nonzero) -sum = get_xp(np)(_aliases.sum) -prod = get_xp(np)(_aliases.prod) ceil = get_xp(np)(_aliases.ceil) floor = get_xp(np)(_aliases.floor) trunc = get_xp(np)(_aliases.trunc)