Add basic JAX support #84

Merged (20 commits) Feb 9, 2024
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -15,7 +15,7 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest numpy torch dask[array]
python -m pip install pytest numpy torch dask[array] jax[cpu]

- name: Run Tests
run: |
25 changes: 23 additions & 2 deletions README.md
@@ -2,7 +2,7 @@

This is a small wrapper around common array libraries that is compatible with
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
NumPy, CuPy, and PyTorch are supported. If you want support for other array
NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
libraries, or if you encounter any issues, please [open an
issue](https://github.com/data-apis/array-api-compat/issues).

@@ -56,7 +56,17 @@ import array_api_compat.cupy as cp
import array_api_compat.torch as torch
```

Each will include all the functions from the normal NumPy/CuPy/PyTorch
```py
import array_api_compat.dask as da
```

> [!NOTE]
> There is no `array_api_compat.jax` submodule. JAX's array API support is
> contained in JAX itself, in the `jax.experimental.array_api` module, and
> array-api-compat uses that module directly rather than wrapping it. The main
> JAX support in this package consists of supporting JAX arrays in the
> [helper functions](#helper-functions) defined below.

Each will include all the functions from the normal NumPy/CuPy/PyTorch/dask.array
namespace, except that functions that are part of the array API are wrapped so
that they have the correct array API behavior. In each case, the array object
used will be the same array object from the wrapped library.
@@ -99,6 +109,11 @@ part of the specification but which are useful for using the array API:
- `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array
object.

- `is_numpy_array(x)`, `is_cupy_array(x)`, `is_torch_array(x)`,
`is_dask_array(x)`, `is_jax_array(x)`: return `True` if `x` is an array from
the corresponding library. These functions do not import the underlying
library if it has not already been imported, so they are cheap to use.

- `array_namespace(*xs)`: Get the corresponding array API namespace for the
arrays `xs`. For example, if the arrays are NumPy arrays, the returned
namespace will be `array_api_compat.numpy`. Note that this function will
@@ -219,6 +234,12 @@ version.

The minimum supported PyTorch version is 1.13.

### JAX

Unlike the other libraries supported here, JAX's array API support is contained
entirely within the JAX library itself. Progress on that support is tracked at
https://github.com/google/jax/issues/18353.
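Since there is no wrapper submodule, using JAX through the standard is just an import of JAX's own namespace. A minimal, hedged usage sketch — it assumes `jax[cpu]` is installed and that the experimental module still exists, and degrades to a no-op otherwise:

```py
# There is no array_api_compat.jax; the array API namespace ships in JAX itself.
try:
    import jax.experimental.array_api as xp
except ImportError:  # jax not installed, or a later jax removed the module
    xp = None

if xp is not None:
    x = xp.asarray([1.0, 2.0, 3.0])
    total = x.sum()
```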

## Vendoring

This library supports vendoring as an installation method. To vendor the
4 changes: 2 additions & 2 deletions array_api_compat/common/_aliases.py
@@ -13,7 +13,7 @@
from types import ModuleType
import inspect

from ._helpers import _check_device, _is_numpy_array, array_namespace
from ._helpers import _check_device, is_numpy_array, array_namespace

# These functions are modified from the NumPy versions.

@@ -309,7 +309,7 @@ def _asarray(
raise ValueError("Unrecognized namespace argument to asarray()")

_check_device(xp, device)
if _is_numpy_array(obj):
if is_numpy_array(obj):
import numpy as np
if hasattr(np, '_CopyMode'):
# Not present in older NumPys
72 changes: 53 additions & 19 deletions array_api_compat/common/_helpers.py
**Member (author):** @jakevdp, I could mostly use your review for the changes in this file.

@@ -9,8 +9,9 @@

import sys
import math
import inspect

def _is_numpy_array(x):
def is_numpy_array(x):
# Avoid importing NumPy if it isn't already
if 'numpy' not in sys.modules:
return False
@@ -20,7 +21,7 @@ def _is_numpy_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, (np.ndarray, np.generic))

def _is_cupy_array(x):
def is_cupy_array(x):
# Avoid importing CuPy if it isn't already
if 'cupy' not in sys.modules:
return False
@@ -30,7 +31,7 @@ def _is_cupy_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, (cp.ndarray, cp.generic))

def _is_torch_array(x):
def is_torch_array(x):
# Avoid importing torch if it isn't already
if 'torch' not in sys.modules:
return False
@@ -40,7 +41,7 @@ def _is_torch_array(x):
# TODO: Should we reject Tensor subclasses?
return isinstance(x, torch.Tensor)

def _is_dask_array(x):
def is_dask_array(x):
# Avoid importing dask if it isn't already
if 'dask.array' not in sys.modules:
return False
@@ -49,14 +50,24 @@ def _is_dask_array(x):

return isinstance(x, dask.array.Array)

def is_jax_array(x):
# Avoid importing jax if it isn't already
if 'jax' not in sys.modules:
return False

import jax

return isinstance(x, jax.Array)
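Because of the `sys.modules` guard, this predicate can be exercised even in an environment where JAX was never imported. A self-contained sketch of the same logic (reproduced here so the snippet runs without array-api-compat installed):

```py
import sys

def is_jax_array(x):
    # Same pattern as above: if jax has never been imported by anyone,
    # nothing can be a JAX array, so we answer False without importing it.
    if 'jax' not in sys.modules:
        return False
    import jax  # cheap: already in sys.modules
    return isinstance(x, jax.Array)

print(is_jax_array([1, 2, 3]))  # False, and jax was not imported to decide it
```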
**Contributor:** Should this be more guarded? e.g. what if someone has a module named `jax.py` in their path? It could lead to an AttributeError here and make array_api_compat unusable.

**Member (author):** I don't know. That sort of thing tends to just break everything anyway. I've never really felt that libraries should protect against that sort of thing.

Anyway, the whole point of this function is to be guarded. It won't import jax unless it's already been imported, because jax is a slow import and unnecessary if a user is using another array library.

**Contributor (@jakevdp, Feb 7, 2024):** More precisely, it won't import jax unless something named jax has been imported, and then array_api_compat will irrecoverably error if that jax module isn't the one we're expecting, or is an older version that doesn't define the attributes we reference.

A similar issue exists for every other package name referenced in this module.

My feeling is: it costs virtually nothing to wrap this all in an appropriate `try ... except`, so why not do so?

**Member (author):** I'm still not so sure this is a good idea. Usually if you have that sort of thing it will be an error for a lot of things, not just array_api_compat. My worry here is that guarding isn't as straightforward as it might seem. Wrapping everything in try/except could mean we end up silencing legitimate errors.

**Contributor:** OK, sounds good.

def is_array_api_obj(x):
"""
Check if x is an array API compatible array object.
"""
return _is_numpy_array(x) \
or _is_cupy_array(x) \
or _is_torch_array(x) \
or _is_dask_array(x) \
return is_numpy_array(x) \
or is_cupy_array(x) \
or is_torch_array(x) \
or is_dask_array(x) \
or is_jax_array(x) \
or hasattr(x, '__array_namespace__')

def _check_api_version(api_version):
Expand All @@ -81,37 +92,43 @@ def your_function(x, y):
"""
namespaces = set()
for x in xs:
if _is_numpy_array(x):
if is_numpy_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import numpy as numpy_namespace
namespaces.add(numpy_namespace)
else:
import numpy as np
namespaces.add(np)
elif _is_cupy_array(x):
elif is_cupy_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import cupy as cupy_namespace
namespaces.add(cupy_namespace)
else:
import cupy as cp
namespaces.add(cp)
elif _is_torch_array(x):
elif is_torch_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import torch as torch_namespace
namespaces.add(torch_namespace)
else:
import torch
namespaces.add(torch)
elif _is_dask_array(x):
elif is_dask_array(x):
_check_api_version(api_version)
if _use_compat:
from ..dask import array as dask_namespace
namespaces.add(dask_namespace)
else:
raise TypeError("_use_compat cannot be False if input array is a dask array!")
elif is_jax_array(x):
_check_api_version(api_version)
# jax.experimental.array_api is already an array namespace. We do
# not have a wrapper submodule for it.
import jax.experimental.array_api as jnp
**Member (author):** Will this import go away at some point? Should we guard against that?

**Contributor:** It's not true that jax.numpy is already an array namespace. jax.experimental.array_api is an array namespace, and we hope to make jax.numpy an array namespace in the future.

**Member (author):** OK, I fixed the comment. The question still remains, though: should I add a guard here like

```py
try:
    import jax.experimental.array_api as jnp
except ImportError:
    import jax.numpy as jnp
```

for a future JAX version when jax.experimental.array_api gets removed? Or will it never be removed?

**Contributor:** Yes, I think that's probably a reasonable way to future-proof this.

**Member (author):** The downside of that particular line of code is that it will also pass through jax.numpy for older JAX versions that don't have jax.experimental.array_api. I like to avoid explicit version checks in this library if I can, but maybe that's the best thing to do here.

Or maybe we can just change the logic to this once JAX starts to remove (or deprecate) the experimental import.
namespaces.add(jnp)
elif hasattr(x, '__array_namespace__'):
namespaces.add(x.__array_namespace__(api_version=api_version))
else:
@@ -136,7 +153,7 @@ def _check_device(xp, device):
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device for NumPy: {device!r}")

# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray
# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in array objects of this library
# because this library just reuses the respective ndarray classes without
# wrapping or subclassing them. These helper functions can be used instead of
@@ -156,8 +173,17 @@ def device(x: "Array", /) -> "Device":
out: device
a ``device`` object (see the "Device Support" section of the array API specification).
"""
if _is_numpy_array(x):
if is_numpy_array(x):
return "cpu"
if is_jax_array(x):
# JAX has .device() as a method, but it is being deprecated so that it
# can become a property, in accordance with the standard. In order for
# this function to not break when JAX makes the flip, we check for
# both here.
**Member (author):** Does this logic seem OK?

**Contributor:** This looks OK, but it will not work with the to_device function as currently defined in jax.experimental.array_api. That expects to be passed the bound method object; it's a hack, and one of the reasons that this is still considered experimental.

**Contributor (@jakevdp, Feb 6, 2024):** The other problem here is that, in general, JAX arrays can live on multiple devices (in which case arr.device() will error). It's not clear to me from the Array API design how multi-device arrays are meant to be supported.

**Member (author):** I don't know the answer to that. I would bring it up on the array API repo: https://github.com/data-apis/array-api/. As far as I know it hasn't really been discussed.

**Member (author):** It looks like both forms work:

```py
>>> import jax.experimental.array_api as xp
>>> x = xp.asarray([1, 2, 3])
>>> x.to_device(x.device)
Array([1, 2, 3], dtype=int32)
>>> x.to_device(x.device())
Array([1, 2, 3], dtype=int32)
```
if inspect.ismethod(x.device):
return x.device()
else:
return x.device
return x.device
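The method-versus-property dance can be exercised without JAX. A standalone sketch of the same `inspect.ismethod` check using two dummy classes (all names here are illustrative, not JAX's):

```py
import inspect

class DeviceAsMethod:
    # Mimics older JAX, where .device was a bound method.
    def device(self):
        return "cpu:0"

class DeviceAsProperty:
    # Mimics the standard (and newer JAX), where .device is a property.
    @property
    def device(self):
        return "cpu:0"

def get_device(x):
    # The same compatibility check as above: call .device while it is still
    # a method, read it as an attribute once it becomes a property.
    if inspect.ismethod(x.device):
        return x.device()
    return x.device

print(get_device(DeviceAsMethod()), get_device(DeviceAsProperty()))  # cpu:0 cpu:0
```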

# Based on cupy.array_api.Array.to_device
@@ -225,24 +251,30 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
.. note::
If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
"""
if _is_numpy_array(x):
if is_numpy_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
if device == 'cpu':
return x
raise ValueError(f"Unsupported device {device!r}")
elif _is_cupy_array(x):
elif is_cupy_array(x):
# cupy does not yet have to_device
return _cupy_to_device(x, device, stream=stream)
elif _is_torch_array(x):
elif is_torch_array(x):
return _torch_to_device(x, device, stream=stream)
elif _is_dask_array(x):
elif is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
# TODO: What if our array is on the GPU already?
if device == 'cpu':
return x
raise ValueError(f"Unsupported device {device!r}")
elif is_jax_array(x):
# This import adds to_device to x
import jax.experimental.array_api
if device == 'cpu':
device = jax.devices('cpu')[0]
return x.to_device(device, stream=stream)
return x.to_device(device, stream=stream)

def size(x):
@@ -253,4 +285,6 @@ def size(x):
return None
return math.prod(x.shape)

__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size']
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device',
'to_device', 'size', 'is_numpy_array', 'is_cupy_array',
'is_torch_array', 'is_dask_array', 'is_jax_array']
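The `size` helper whose diff is collapsed above returns None when any dimension is unknown. A runnable sketch of that behavior with a stand-in object — `Shaped` is illustrative, and the body is reconstructed from the visible fragment, not copied from the collapsed hunk:

```py
import math

def size(x):
    # Sketch of the helper: None anywhere in the shape means the total
    # number of elements is unknown (e.g. an unchunked dask dimension).
    if None in x.shape:
        return None
    return math.prod(x.shape)

class Shaped:
    # Minimal stand-in with just a .shape attribute, for illustration.
    def __init__(self, shape):
        self.shape = shape

print(size(Shaped((2, 3))), size(Shaped((2, None))))  # 6 None
```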
23 changes: 0 additions & 23 deletions tests/test_common.py

This file was deleted.

73 changes: 73 additions & 0 deletions tests/test_helpers.py
@@ -0,0 +1,73 @@
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array,
is_dask_array, is_jax_array, is_array_api_obj,
device, to_device)

from ._helpers import import_

import pytest
import numpy as np
from numpy.testing import assert_allclose

is_functions = {
'numpy': 'is_numpy_array',
'cupy': 'is_cupy_array',
'torch': 'is_torch_array',
'dask.array': 'is_dask_array',
'jax.numpy': 'is_jax_array',
}

@pytest.mark.parametrize('library', is_functions.keys())
@pytest.mark.parametrize('func', is_functions.values())
def test_is_xp_array(library, func):
lib = import_(library)
is_func = globals()[func]

x = lib.asarray([1, 2, 3])

assert is_func(x) == (func == is_functions[library])

assert is_array_api_obj(x)

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
def test_device(library):
if library == "dask.array":
pytest.xfail("device() needs to be fixed for dask")

if library == "jax.numpy":
xp = import_('jax.experimental.array_api')
else:
xp = import_('array_api_compat.' + library)

# We can't test much for device() and to_device() other than that
# x.to_device(x.device) works.

x = xp.asarray([1, 2, 3])
dev = device(x)

x2 = to_device(x, dev)
assert device(x) == device(x2)


@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
def test_to_device_host(library):
# Test that "cpu" device works. Note: this isn't actually supported by the
# standard yet. See https://github.com/data-apis/array-api/issues/626.

# different libraries have different semantics
# for DtoH transfers; ensure that we support a portable
# shim for common array libs
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
if library == "jax.numpy":
xp = import_('jax.experimental.array_api')
else:
xp = import_('array_api_compat.' + library)

expected = np.array([1, 2, 3])
x = xp.asarray([1, 2, 3])
x = to_device(x, "cpu")
# torch will return a genuine Device object, but
# the other libs will do something different with
# a `device(x)` query; however, what's really important
# here is that we can test portably after calling
# to_device(x, "cpu") to return to host
assert_allclose(x, expected)
14 changes: 10 additions & 4 deletions tests/test_isdtype.py
@@ -64,9 +64,12 @@ def isdtype_(dtype_, kind):
assert type(res) is bool
return res

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
def test_isdtype_spec_dtypes(library):
xp = import_('array_api_compat.' + library)
if library == "jax.numpy":
xp = import_('jax.experimental.array_api')
else:
xp = import_('array_api_compat.' + library)

isdtype = xp.isdtype

@@ -98,10 +101,13 @@ def test_isdtype_spec_dtypes(library):
'bfloat16',
]

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
@pytest.mark.parametrize("dtype_", additional_dtypes)
def test_isdtype_additional_dtypes(library, dtype_):
xp = import_('array_api_compat.' + library)
if library == "jax.numpy":
xp = import_('jax.experimental.array_api')
else:
xp = import_('array_api_compat.' + library)

isdtype = xp.isdtype
