Skip to content

Commit

Permalink
Add clip() wrapper for NumPy and CuPy
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Jul 9, 2024
1 parent 090f570 commit 398fdf1
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
54 changes: 52 additions & 2 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import NamedTuple
import inspect

from ._helpers import _check_device
from ._helpers import array_namespace, _check_device

# These functions are modified from the NumPy versions.

Expand Down Expand Up @@ -264,6 +264,56 @@ def var(
) -> ndarray:
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)


# The min and max argument names in clip are different and not optional in numpy, and type
# promotion behavior is different.
def clip(
x: ndarray,
/,
min: Optional[Union[int, float, ndarray]] = None,
max: Optional[Union[int, float, ndarray]] = None,
*,
xp,
# TODO: np.clip has other ufunc kwargs
out: Optional[ndarray] = None,
) -> ndarray:
def _isscalar(a):
return isinstance(a, (int, float, type(None)))
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)

wrapped_xp = array_namespace(x)

# np.clip does type promotion but the array API clip requires that the
# output have the same dtype as x. We do this instead of just downcasting
# the result of xp.clip() to handle some corner cases better (e.g.,
# avoiding uint64 -> float64 promotion).

# Note: cases where min or max overflow (integer) or round (float) in the
# wrong direction when downcasting to x.dtype are unspecified. This code
# just does whatever NumPy does when it downcasts in the assignment, but
# other behavior could be preferred, especially for integers. For example,
# this code produces:

# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
# -128

# but an answer of 0 might be preferred. See
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
if out is None:
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
if min is not None:
a = xp.broadcast_to(xp.asarray(min), result_shape)
ia = (out < a) | xp.isnan(a)
out[ia] = a[ia]
if max is not None:
b = xp.broadcast_to(xp.asarray(max), result_shape)
ib = (out > b) | xp.isnan(b)
out[ib] = b[ib]
# Return a scalar for 0-D
return out[()]

# Unlike transpose(), the axes argument to permute_dims() is required.
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
return xp.transpose(x, axes)
Expand Down Expand Up @@ -465,6 +515,6 @@ def isdtype(
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
1 change: 1 addition & 0 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
astype = _aliases.astype
std = get_xp(cp)(_aliases.std)
var = get_xp(cp)(_aliases.var)
clip = get_xp(cp)(_aliases.clip)
permute_dims = get_xp(cp)(_aliases.permute_dims)
reshape = get_xp(cp)(_aliases.reshape)
argsort = get_xp(cp)(_aliases.argsort)
Expand Down
1 change: 1 addition & 0 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
astype = _aliases.astype
std = get_xp(np)(_aliases.std)
var = get_xp(np)(_aliases.var)
clip = get_xp(np)(_aliases.clip)
permute_dims = get_xp(np)(_aliases.permute_dims)
reshape = get_xp(np)(_aliases.reshape)
argsort = get_xp(np)(_aliases.argsort)
Expand Down

0 comments on commit 398fdf1

Please sign in to comment.