Skip to content

Commit

Permalink
Improvements to the clip wrapper
Browse files Browse the repository at this point in the history
- Ensure the arrays that are created are created on the same device as x.
  (fixes #177)

- Make clip() work with dask.array. The workaround to avoid uint64 -> float64
  promotion does not work here. (fixes #176)

- Fix loss of precision when clipping a float64 tensor with torch due to the
  scalar being converted to a float32 tensor.
  • Loading branch information
asmeurer committed Aug 12, 2024
1 parent b96e84b commit 5f8b5d6
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 8 deletions.
21 changes: 14 additions & 7 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import NamedTuple
import inspect

from ._helpers import array_namespace, _check_device
from ._helpers import array_namespace, _check_device, device, is_torch_array

# These functions are modified from the NumPy versions.

Expand Down Expand Up @@ -281,10 +281,11 @@ def _isscalar(a):
return isinstance(a, (int, float, type(None)))
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)

wrapped_xp = array_namespace(x)

result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)

# np.clip does type promotion but the array API clip requires that the
# output have the same dtype as x. We do this instead of just downcasting
# the result of xp.clip() to handle some corner cases better (e.g.,
Expand All @@ -305,20 +306,26 @@ def _isscalar(a):

# At least handle the case of Python integers correctly (see
# https://github.com/numpy/numpy/pull/26892).
if type(min) is int and min <= xp.iinfo(x.dtype).min:
if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
min = None
if type(max) is int and max >= xp.iinfo(x.dtype).max:
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
max = None

if out is None:
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
copy=True, device=device(x))
if min is not None:
a = xp.broadcast_to(xp.asarray(min), result_shape)
if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
# Avoid loss of precision due to torch defaulting to float32
min = wrapped_xp.asarray(min, dtype=xp.float64)
a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
ia = (out < a) | xp.isnan(a)
# torch requires an explicit cast here
out[ia] = wrapped_xp.astype(a[ia], out.dtype)
if max is not None:
b = xp.broadcast_to(xp.asarray(max), result_shape)
if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
max = wrapped_xp.asarray(max, dtype=xp.float64)
b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
ib = (out > b) | xp.isnan(b)
out[ib] = wrapped_xp.astype(b[ib], out.dtype)
# Return a scalar for 0-D
Expand Down
38 changes: 37 additions & 1 deletion array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def _dask_arange(
permute_dims = get_xp(da)(_aliases.permute_dims)
std = get_xp(da)(_aliases.std)
var = get_xp(da)(_aliases.var)
clip = get_xp(da)(_aliases.clip)
empty = get_xp(da)(_aliases.empty)
empty_like = get_xp(da)(_aliases.empty_like)
full = get_xp(da)(_aliases.full)
Expand Down Expand Up @@ -167,6 +166,43 @@ def asarray(
concatenate as concat,
)

# dask.array.clip does not work unless all three arguments are provided.
# Furthermore, the masking workaround in common._aliases.clip cannot work with
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
# now).
@get_xp(da)
def clip(
    x: Array,
    /,
    min: Optional[Union[int, float, Array]] = None,
    max: Optional[Union[int, float, Array]] = None,
    *,
    xp,
) -> Array:
    """Clamp the values of *x* to the range [min, max].

    Either bound may be None (no bound on that side); scalar bounds are
    broadcast against the shape of *x*. The result is cast back to
    ``x.dtype`` so type promotion from the bounds does not leak through.
    """
    def _is_scalar_bound(bound):
        return isinstance(bound, (int, float, type(None)))

    lo_shape = () if _is_scalar_bound(min) else min.shape
    hi_shape = () if _is_scalar_bound(max) else max.shape

    # TODO: This won't handle dask unknown shapes
    import numpy as np
    target_shape = np.broadcast_shapes(x.shape, lo_shape, hi_shape)

    if min is not None:
        min = xp.broadcast_to(xp.asarray(min), target_shape)
    if max is not None:
        max = xp.broadcast_to(xp.asarray(max), target_shape)

    # With no bounds at all, return a (positive-identity) copy of x.
    if min is None and max is None:
        return xp.positive(x)

    # Apply the lower bound first, then the upper bound, matching the
    # minimum(maximum(...)) nesting of the original implementation.
    result = x
    if min is not None:
        result = xp.maximum(result, min)
    if max is not None:
        result = xp.minimum(result, max)
    return astype(result, x.dtype)

# exclude these from all since
_da_unsupported = ['sort', 'argsort']

Expand Down

0 comments on commit 5f8b5d6

Please sign in to comment.