Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic JAX support #84

Merged
merged 20 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions array_api_compat/common/_helpers.py
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jakevdp, I could mostly use your review for the changes in this file.

Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def is_jax_array(x):
if 'jax' not in sys.modules:
return False

import jax.numpy
import jax

return isinstance(x, jax.numpy.ndarray)
return isinstance(x, jax.Array)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be more guarded? e.g. what if someone has a module named jax.py in their path, it could lead to an AttributeError here and make array_api_compat unusable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know. That sort of thing tends to just break everything anyway. I've never really felt that libraries should protect against that sort of thing.

Anyway, the whole point of this function is to be guarded. It won't import jax unless it's already been imported, because jax is a slow import and unnecessary if a user is using another array library.

Copy link
Contributor

@jakevdp jakevdp Feb 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More precisely, it won't import jax unless something named jax has been imported, and then array_api_compat will irrecoverably error if that jax module isn't the one we're expecting, or is an older version that doesn't define the attributes we reference.

A similar issue exists for every other package name referenced in this module.

My feeling is: it costs virtually nothing to wrap this all in an appropriate try ... except, so why not do so?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still not so sure this is a good idea. Usually if you have that sort of thing it will be an error for a lot of things, not just array_api_compat. My worry here is that guarding isn't as straightforward as it might seem. Wrapping everything in try/except could mean we end up silencing legitimate errors.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, sounds good.


def is_array_api_obj(x):
"""
Expand Down Expand Up @@ -153,7 +153,7 @@ def _check_device(xp, device):
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device for NumPy: {device!r}")

# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray
# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in array objects of this library
# because this library just reuses the respective ndarray classes without
# wrapping or subclassing them. These helper functions can be used instead of
Expand Down Expand Up @@ -230,12 +230,6 @@ def _torch_to_device(x, device, /, stream=None):
raise NotImplementedError
return x.to(device)

def _jax_to_device(x, device, /, stream=None):
import jax
if stream is not None:
raise NotImplementedError
return jax.device_put(x, device)

def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
"""
Copy the array from the device on which it currently resides to the specified ``device``.
Expand Down Expand Up @@ -276,7 +270,9 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
return x
raise ValueError(f"Unsupported device {device!r}")
elif is_jax_array(x):
return _jax_to_device(x, device, stream=stream)
# This import adds to_device to x
import jax.experimental.array_api
return x.to_device(device, stream=stream)
return x.to_device(device, stream=stream)

def size(x):
Expand Down
19 changes: 18 additions & 1 deletion tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array,
is_dask_array, is_jax_array, is_array_api_obj)
is_dask_array, is_jax_array, is_array_api_obj,
device, to_device)

from ._helpers import import_

Expand All @@ -24,3 +25,19 @@ def test_is_xp_array(library, func):
assert is_func(x) == (func == is_functions[library])

assert is_array_api_obj(x)

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
def test_device(library):
if library == "jax.numpy":
xp = import_('jax.experimental.array_api')
else:
xp = import_('array_api_compat.' + library)

# We can't test much for device() and to_device() other than that
# x.to_device(x.device) works.

x = xp.asarray([1, 2, 3])
dev = device(x)

x2 = to_device(x, dev)
assert device(x) == device(x2)
Loading