From 2ee6902c03a0c76a1460adbec7ed4ca5d53cc661 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 14:58:18 -0700 Subject: [PATCH 01/19] Make is_numpy_array, is_cupy_array, is_torch_array, and is_dask_array public --- README.md | 5 ++++ array_api_compat/common/_aliases.py | 4 +-- array_api_compat/common/_helpers.py | 38 +++++++++++++++-------------- tests/test_helpers.py | 25 +++++++++++++++++++ 4 files changed, 52 insertions(+), 20 deletions(-) create mode 100644 tests/test_helpers.py diff --git a/README.md b/README.md index 2c0ce59a..dd90f340 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,11 @@ part of the specification but which are useful for using the array API: - `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array object. +- `is_numpy_array(x)`, `is_cupy_array(x)`, `is_torch_array(x)`, + `is_dask_array(x)`: return `True` if `x` is an array from the corresponding + library. These functions do not import the underlying library if it has not + already been imported, so they are cheap to use. + - `array_namespace(*xs)`: Get the corresponding array API namespace for the arrays `xs`. For example, if the arrays are NumPy arrays, the returned namespace will be `array_api_compat.numpy`. Note that this function will diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 7713213e..0f67387f 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -13,7 +13,7 @@ from types import ModuleType import inspect -from ._helpers import _check_device, _is_numpy_array, array_namespace +from ._helpers import _check_device, is_numpy_array, array_namespace # These functions are modified from the NumPy versions. @@ -309,7 +309,7 @@ def _asarray( raise ValueError("Unrecognized namespace argument to asarray()") _check_device(xp, device) - if _is_numpy_array(obj): + if is_numpy_array(obj): import numpy as np if hasattr(np, '_CopyMode'): # Not present in older NumPys diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 82bf47c1..b5197467 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -10,7 +10,7 @@ import sys import math -def _is_numpy_array(x): +def is_numpy_array(x): # Avoid importing NumPy if it isn't already if 'numpy' not in sys.modules: return False @@ -20,7 +20,7 @@ def _is_numpy_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, (np.ndarray, np.generic)) -def _is_cupy_array(x): +def is_cupy_array(x): # Avoid importing NumPy if it isn't already if 'cupy' not in sys.modules: return False @@ -30,7 +30,7 @@ def _is_cupy_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, (cp.ndarray, cp.generic)) -def _is_torch_array(x): +def is_torch_array(x): # Avoid importing torch if it isn't already if 'torch' not in sys.modules: return False @@ -40,7 +40,7 @@ def _is_torch_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, torch.Tensor) -def _is_dask_array(x): +def is_dask_array(x): # Avoid importing dask if it isn't already if 'dask.array' not in sys.modules: return False @@ -53,10 +53,10 @@ def is_array_api_obj(x): """ Check if x is an array API compatible array object. """ - return _is_numpy_array(x) \ - or _is_cupy_array(x) \ - or _is_torch_array(x) \ - or _is_dask_array(x) \ + return is_numpy_array(x) \ + or is_cupy_array(x) \ + or is_torch_array(x) \ + or is_dask_array(x) \ or hasattr(x, '__array_namespace__') def _check_api_version(api_version): @@ -81,7 +81,7 @@ def your_function(x, y): """ namespaces = set() for x in xs: - if _is_numpy_array(x): + if is_numpy_array(x): _check_api_version(api_version) if _use_compat: from .. import numpy as numpy_namespace @@ -89,7 +89,7 @@ def your_function(x, y): else: import numpy as np namespaces.add(np) - elif _is_cupy_array(x): + elif is_cupy_array(x): _check_api_version(api_version) if _use_compat: from .. import cupy as cupy_namespace @@ -97,7 +97,7 @@ def your_function(x, y): else: import cupy as cp namespaces.add(cp) - elif _is_torch_array(x): + elif is_torch_array(x): _check_api_version(api_version) if _use_compat: from .. import torch as torch_namespace @@ -105,7 +105,7 @@ def your_function(x, y): else: import torch namespaces.add(torch) - elif _is_dask_array(x): + elif is_dask_array(x): _check_api_version(api_version) if _use_compat: from ..dask import array as dask_namespace @@ -156,7 +156,7 @@ def device(x: "Array", /) -> "Device": out: device a ``device`` object (see the "Device Support" section of the array API specification). """ - if _is_numpy_array(x): + if is_numpy_array(x): return "cpu" return x.device @@ -225,18 +225,18 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A .. note:: If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation. """ - if _is_numpy_array(x): + if is_numpy_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") if device == 'cpu': return x raise ValueError(f"Unsupported device {device!r}") - elif _is_cupy_array(x): + elif is_cupy_array(x): # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) - elif _is_torch_array(x): + elif is_torch_array(x): return _torch_to_device(x, device, stream=stream) - elif _is_dask_array(x): + elif is_dask_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") # TODO: What if our array is on the GPU already? @@ -253,4 +253,6 @@ def size(x): return None return math.prod(x.shape) -__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size'] +__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', + 'to_device', 'size', 'is_numpy_array', 'is_cupy_array', + 'is_torch_array', 'is_dask_array'] diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 00000000..c1cc59d1 --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,25 @@ +from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, + is_dask_array, is_array_api_obj) + +from ._helpers import import_ + +import pytest + +is_functions = { + 'numpy': 'is_numpy_array', + 'cupy': 'is_cupy_array', + 'torch': 'is_torch_array', + 'dask.array': 'is_dask_array', +} + +@pytest.mark.parametrize('library', is_functions.keys()) +@pytest.mark.parametrize('func', is_functions.values()) +def test_is_xp_array(library, func): + lib = import_(library) + is_func = globals()[func] + + x = lib.asarray([1, 2, 3]) + + assert is_func(x) == (func == is_functions[library]) + + assert is_array_api_obj(x) From 12b5294cd7003494c72e307083b42dbe7daf7350 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 15:00:25 -0700 Subject: [PATCH 02/19] Note dask in a couple of places in the README --- README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index dd90f340..5fe19381 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This is a small wrapper around common array libraries that is compatible with the [Array API standard](https://data-apis.org/array-api/latest/). Currently, -NumPy, CuPy, and PyTorch are supported. If you want support for other array +NumPy, CuPy, PyTorch, and Dask are supported. If you want support for other array libraries, or if you encounter any issues, please [open an issue](https://github.com/data-apis/array-api-compat/issues). @@ -56,7 +56,11 @@ import array_api_compat.cupy as cp import array_api_compat.torch as torch ``` -Each will include all the functions from the normal NumPy/CuPy/PyTorch +```py +import array_api_compat.dask as da +``` + +Each will include all the functions from the normal NumPy/CuPy/PyTorch/dask.array namespace, except that functions that are part of the array API are wrapped so that they have the correct array API behavior. In each case, the array object used will be the same array object from the wrapped library. From 6bf5dad170d9a996c62e204e8e84bad339ebedc2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 15:24:06 -0700 Subject: [PATCH 03/19] Add JAX support Unlike other libraries, there is no wrapping for JAX. Actual JAX array_api support is in JAX itself in the jax.experimental.array_api submodule. This just adds JAX support to the various helper functions. This also means that we do not run array-api-tests on JAX. Closes #83. --- README.md | 20 ++++++++++++---- array_api_compat/common/_helpers.py | 37 ++++++++++++++++++++++++++++- tests/test_helpers.py | 3 ++- tests/test_isdtype.py | 14 +++++++---- 4 files changed, 64 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 5fe19381..5be86271 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This is a small wrapper around common array libraries that is compatible with the [Array API standard](https://data-apis.org/array-api/latest/). Currently, -NumPy, CuPy, PyTorch, and Dask are supported. If you want support for other array +NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array libraries, or if you encounter any issues, please [open an issue](https://github.com/data-apis/array-api-compat/issues). @@ -60,6 +60,12 @@ import array_api_compat.torch as torch import array_api_compat.dask as da ``` +> [!NOTE] +> There is no `array_api_compat.jax` submodule. JAX support is contained +> in JAX itself in the `jax.experimental.array_api` module. array-api-compat simply +> wraps that submodule. The main JAX support in this module consists of +> supporting it in the [helper functions](#helper-functions) defined below. + Each will include all the functions from the normal NumPy/CuPy/PyTorch/dask.array namespace, except that functions that are part of the array API are wrapped so that they have the correct array API behavior. In each case, the array object @@ -104,9 +110,9 @@ part of the specification but which are useful for using the array API: object. - `is_numpy_array(x)`, `is_cupy_array(x)`, `is_torch_array(x)`, - `is_dask_array(x)`: return `True` if `x` is an array from the corresponding - library. These functions do not import the underlying library if it has not - already been imported, so they are cheap to use. + `is_dask_array(x)`, `is_jax_array(x)`: return `True` if `x` is an array from + the corresponding library. These functions do not import the underlying + library if it has not already been imported, so they are cheap to use. - `array_namespace(*xs)`: Get the corresponding array API namespace for the arrays `xs`. For example, if the arrays are NumPy arrays, the returned @@ -228,6 +234,12 @@ version. The minimum supported PyTorch version is 1.13. +### JAX + +Unlike the other libraries supported here, JAX array API support is contained +entirely in the JAX library. The JAX array API support is tracked at +https://github.com/google/jax/issues/18353. + ## Vendoring This library supports vendoring as an installation method. To vendor the diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index b5197467..0d1e6d27 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -9,6 +9,7 @@ import sys import math +import inspect def is_numpy_array(x): # Avoid importing NumPy if it isn't already @@ -49,6 +50,15 @@ def is_dask_array(x): return isinstance(x, dask.array.Array) +def is_jax_array(x): + # Avoid importing jax if it isn't already + if 'jax' not in sys.modules: + return False + + import jax.numpy + + return isinstance(x, jax.numpy.ndarray) + def is_array_api_obj(x): """ Check if x is an array API compatible array object. @@ -57,6 +67,7 @@ def is_array_api_obj(x): or is_cupy_array(x) \ or is_torch_array(x) \ or is_dask_array(x) \ + or is_jax_array(x) \ or hasattr(x, '__array_namespace__') def _check_api_version(api_version): @@ -112,6 +123,13 @@ def your_function(x, y): namespaces.add(dask_namespace) else: raise TypeError("_use_compat cannot be False if input array is a dask array!") + elif is_jax_array(x): + _check_api_version(api_version) + # jax.numpy is already an array namespace, but requires this + # side-effecting import for __array_namespace__ and some other + # things to be defined. + import jax.experimental.array_api as jnp + namespaces.add(jnp) elif hasattr(x, '__array_namespace__'): namespaces.add(x.__array_namespace__(api_version=api_version)) else: @@ -158,6 +176,15 @@ def device(x: "Array", /) -> "Device": """ if is_numpy_array(x): return "cpu" + if is_jax_array(x): + # JAX has .device() as a method, but it is being deprecated so that it + # can become a property, in accordance with the standard. In order for + # this function to not break when JAX makes the flip, we check for + # both here. + if inspect.ismethod(x.device): + return x.device() + else: + return x.device return x.device # Based on cupy.array_api.Array.to_device @@ -204,6 +231,12 @@ def _torch_to_device(x, device, /, stream=None): raise NotImplementedError return x.to(device) +def _jax_to_device(x, device, /, stream=None): + import jax + if stream is not None: + raise NotImplementedError + return jax.device_put(x, device) + def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array": """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -243,6 +276,8 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A if device == 'cpu': return x raise ValueError(f"Unsupported device {device!r}") + elif is_jax_array(x): + return _jax_to_device(x, device, stream=stream) return x.to_device(device, stream=stream) def size(x): @@ -255,4 +290,4 @@ def size(x): __all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size', 'is_numpy_array', 'is_cupy_array', - 'is_torch_array', 'is_dask_array'] + 'is_torch_array', 'is_dask_array', 'is_jax_array'] diff --git a/tests/test_helpers.py b/tests/test_helpers.py index c1cc59d1..65de2276 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,5 +1,5 @@ from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, - is_dask_array, is_array_api_obj) + is_dask_array, is_jax_array, is_array_api_obj) from ._helpers import import_ @@ -10,6 +10,7 @@ 'cupy': 'is_cupy_array', 'torch': 'is_torch_array', 'dask.array': 'is_dask_array', + 'jax.numpy': 'is_jax_array', } @pytest.mark.parametrize('library', is_functions.keys()) diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index 77e7ce72..ff615bd7 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -64,9 +64,12 @@ def isdtype_(dtype_, kind): assert type(res) is bool return res -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) def test_isdtype_spec_dtypes(library): - xp = import_('array_api_compat.' + library) + if library == "jax.numpy": + xp = import_('jax.experimental.array_api') + else: + xp = import_('array_api_compat.' + library) isdtype = xp.isdtype @@ -98,10 +101,13 @@ def test_isdtype_spec_dtypes(library): 'bfloat16', ] -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) @pytest.mark.parametrize("dtype_", additional_dtypes) def test_isdtype_additional_dtypes(library, dtype_): - xp = import_('array_api_compat.' + library) + if library == "jax.numpy": + xp = import_('jax.experimental.array_api') + else: + xp = import_('array_api_compat.' + library) isdtype = xp.isdtype From 583f6bbe858e116dc6e77647853997930cfb8f36 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 15:27:33 -0700 Subject: [PATCH 04/19] Install JAX on CI --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 71083fbc..fa06bba5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest numpy torch dask[array] + python -m pip install pytest numpy torch dask[array] jax - name: Run Tests run: | From 9c8bed69cc8c360c81d952209e0e5d9d9309d0bd Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 15:33:13 -0700 Subject: [PATCH 05/19] Install JAX as jax[cpu] --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fa06bba5..2877bf06 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest numpy torch dask[array] jax + python -m pip install pytest numpy torch dask[array] jax[cpu] - name: Run Tests run: | From 6d59ae8d7a438f0881863b4faa037241b54508cb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 15:36:04 -0700 Subject: [PATCH 06/19] Fix a comment --- array_api_compat/common/_helpers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 0d1e6d27..c71ba206 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -125,9 +125,8 @@ def your_function(x, y): raise TypeError("_use_compat cannot be False if input array is a dask array!") elif is_jax_array(x): _check_api_version(api_version) - # jax.numpy is already an array namespace, but requires this - # side-effecting import for __array_namespace__ and some other - # things to be defined. + # jax.experimental.array_api is already an array namespace. We do + # not have a wrapper submodule for it. import jax.experimental.array_api as jnp namespaces.add(jnp) elif hasattr(x, '__array_namespace__'): From ce07cd9d7a36f8de3a3aab0f9e17d931b4e9878f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 16:20:39 -0700 Subject: [PATCH 07/19] Use jax.Array instead of jax.numpy.ndarray in is_jax_array --- array_api_compat/common/_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index c71ba206..657645dd 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -55,9 +55,9 @@ def is_jax_array(x): if 'jax' not in sys.modules: return False - import jax.numpy + import jax - return isinstance(x, jax.numpy.ndarray) + return isinstance(x, jax.Array) def is_array_api_obj(x): """ From ddb313eac30f7b61921f222738be9c323e5b6f72 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 16:40:35 -0700 Subject: [PATCH 08/19] Use the native jax to_device() method --- array_api_compat/common/_helpers.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 657645dd..73c1fc19 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -153,7 +153,7 @@ def _check_device(xp, device): if device not in ["cpu", None]: raise ValueError(f"Unsupported device for NumPy: {device!r}") -# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray +# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library # because this library just reuses the respective ndarray classes without # wrapping or subclassing them. These helper functions can be used instead of @@ -230,12 +230,6 @@ def _torch_to_device(x, device, /, stream=None): raise NotImplementedError return x.to(device) -def _jax_to_device(x, device, /, stream=None): - import jax - if stream is not None: - raise NotImplementedError - return jax.device_put(x, device) - def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array": """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -276,7 +270,9 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A return x raise ValueError(f"Unsupported device {device!r}") elif is_jax_array(x): - return _jax_to_device(x, device, stream=stream) + # This import adds to_device to x + import jax.experimental.array_api + return x.to_device(device, stream=stream) return x.to_device(device, stream=stream) def size(x): From 6004b97f71f7c90fb6d1d1d0c98ea29df1ed3af4 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 16:40:54 -0700 Subject: [PATCH 09/19] Add a basic test for device() and to_device() --- tests/test_helpers.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 65de2276..4c560286 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,5 +1,6 @@ from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, - is_dask_array, is_jax_array, is_array_api_obj) + is_dask_array, is_jax_array, is_array_api_obj, + device, to_device) from ._helpers import import_ @@ -24,3 +25,19 @@ def test_is_xp_array(library, func): assert is_func(x) == (func == is_functions[library]) assert is_array_api_obj(x) + +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) +def test_device(library): + if library == "jax.numpy": + xp = import_('jax.experimental.array_api') + else: + xp = import_('array_api_compat.' + library) + + # We can't test much for device() and to_device() other than that + # x.to_device(x.device) works. + + x = xp.asarray([1, 2, 3]) + dev = device(x) + + x2 = to_device(x, dev) + assert device(x) == device(x2) From 701a5ef2b8a493a08e09e6741fd167893f7e357c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 16:45:31 -0700 Subject: [PATCH 10/19] XFAIL the device test for dask.array --- tests/test_helpers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 4c560286..df2d7695 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -28,6 +28,9 @@ def test_is_xp_array(library, func): @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) def test_device(library): + if library == "dask.array": + pytest.xfail("device() needs to be fixed for dask") + if library == "jax.numpy": xp = import_('jax.experimental.array_api') else: From 049d5571c8af3ec69b890b9028c7bf082cac1273 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 16:47:45 -0700 Subject: [PATCH 11/19] Allow to_device(x, "cpu") for JAX arrays --- array_api_compat/common/_helpers.py | 2 ++ tests/test_common.py | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 73c1fc19..2d13604e 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -272,6 +272,8 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A elif is_jax_array(x): # This import adds to_device to x import jax.experimental.array_api + if device == 'cpu': + device = jax.devices('cpu')[0] return x.to_device(device, stream=stream) return x.to_device(device, stream=stream) diff --git a/tests/test_common.py b/tests/test_common.py index f98a717a..0a2162fb 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,13 +5,17 @@ import numpy as np from numpy.testing import assert_allclose -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) def test_to_device_host(library): # different libraries have different semantics # for DtoH transfers; ensure that we support a portable # shim for common array libs # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 - xp = import_('array_api_compat.' + library) + if library == "jax.numpy": + xp = import_('jax.experimental.array_api') + else: + xp = import_('array_api_compat.' + library) + expected = np.array([1, 2, 3]) x = xp.asarray([1, 2, 3]) x = to_device(x, "cpu") From db667ea3997a630105ef8b1a4a34e391a47590e2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Feb 2024 16:49:13 -0700 Subject: [PATCH 12/19] Move tests in test_common.py to test_helpers.py --- tests/test_common.py | 27 --------------------------- tests/test_helpers.py | 27 +++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 27 deletions(-) delete mode 100644 tests/test_common.py diff --git a/tests/test_common.py b/tests/test_common.py deleted file mode 100644 index 0a2162fb..00000000 --- a/tests/test_common.py +++ /dev/null @@ -1,27 +0,0 @@ -from ._helpers import import_ -from array_api_compat import to_device, device - -import pytest -import numpy as np -from numpy.testing import assert_allclose - -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) -def test_to_device_host(library): - # different libraries have different semantics - # for DtoH transfers; ensure that we support a portable - # shim for common array libs - # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 - if library == "jax.numpy": - xp = import_('jax.experimental.array_api') - else: - xp = import_('array_api_compat.' + library) - - expected = np.array([1, 2, 3]) - x = xp.asarray([1, 2, 3]) - x = to_device(x, "cpu") - # torch will return a genuine Device object, but - # the other libs will do something different with - # a `device(x)` query; however, what's really important - # here is that we can test portably after calling - # to_device(x, "cpu") to return to host - assert_allclose(x, expected) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index df2d7695..e4018840 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -5,6 +5,8 @@ from ._helpers import import_ import pytest +import numpy as np +from numpy.testing import assert_allclose is_functions = { 'numpy': 'is_numpy_array', @@ -44,3 +46,28 @@ def test_device(library): x2 = to_device(x, dev) assert device(x) == device(x2) + + +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) +def test_to_device_host(library): + # Test that "cpu" device works. Note: this isn't actually supported by the + # standard yet. See https://github.com/data-apis/array-api/issues/626. + + # different libraries have different semantics + # for DtoH transfers; ensure that we support a portable + # shim for common array libs + # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 + if library == "jax.numpy": + xp = import_('jax.experimental.array_api') + else: + xp = import_('array_api_compat.' + library) + + expected = np.array([1, 2, 3]) + x = xp.asarray([1, 2, 3]) + x = to_device(x, "cpu") + # torch will return a genuine Device object, but + # the other libs will do something different with + # a `device(x)` query; however, what's really important + # here is that we can test portably after calling + # to_device(x, "cpu") to return to host + assert_allclose(x, expected) From aafbbaa941b419684700914652ec97afe9e4f858 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 8 Feb 2024 13:05:57 -0700 Subject: [PATCH 13/19] Remove "cpu" device from jax to_device() --- array_api_compat/common/_helpers.py | 2 -- tests/test_helpers.py | 10 ++-------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 2d13604e..73c1fc19 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -272,8 +272,6 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A elif is_jax_array(x): # This import adds to_device to x import jax.experimental.array_api - if device == 'cpu': - device = jax.devices('cpu')[0] return x.to_device(device, stream=stream) return x.to_device(device, stream=stream) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index e4018840..a79b8512 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -48,19 +48,13 @@ def test_device(library): assert device(x) == device(x2) -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) def test_to_device_host(library): - # Test that "cpu" device works. Note: this isn't actually supported by the - # standard yet. See https://github.com/data-apis/array-api/issues/626. - # different libraries have different semantics # for DtoH transfers; ensure that we support a portable # shim for common array libs # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 - if library == "jax.numpy": - xp = import_('jax.experimental.array_api') - else: - xp = import_('array_api_compat.' + library) + xp = import_('array_api_compat.' + library) expected = np.array([1, 2, 3]) x = xp.asarray([1, 2, 3]) From fa758f7f2b802ed9e0d3aecbb0451152c17e8569 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 8 Feb 2024 13:25:45 -0700 Subject: [PATCH 14/19] Rename import_or_skip_cupy to import_ and move jax logic into it --- tests/_helpers.py | 11 +++++++++-- tests/test_array_namespace.py | 4 ++-- tests/test_helpers.py | 11 ++++------- tests/test_isdtype.py | 12 +++--------- 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/tests/_helpers.py b/tests/_helpers.py index 69952118..a070a53f 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -3,7 +3,14 @@ import pytest -def import_or_skip_cupy(library): - if "cupy" in library: +def import_(library, wrapper=False): + if library == 'cupy': return pytest.importorskip(library) + + if wrapper: + if 'jax' in library: + library = 'jax.experimental.array_api' + else: + library = 'array_api_compat.' + library + return import_module(library) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 2c596d70..9fbf5656 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -5,13 +5,13 @@ import array_api_compat from array_api_compat import array_namespace -from ._helpers import import_or_skip_cupy +from ._helpers import import_ @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) @pytest.mark.parametrize("api_version", [None, "2021.12"]) def test_array_namespace(library, api_version): - xp = import_or_skip_cupy(library) + xp = import_(library) array = xp.asarray([1.0, 2.0, 3.0]) namespace = array_api_compat.array_namespace(array, api_version=api_version) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index daab5d58..730c0903 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -2,7 +2,7 @@ is_dask_array, is_jax_array, is_array_api_obj, device, to_device) -from ._helpers import import_or_skip_cupy +from ._helpers import import_ import pytest import numpy as np @@ -19,7 +19,7 @@ @pytest.mark.parametrize('library', is_functions.keys()) @pytest.mark.parametrize('func', is_functions.values()) def test_is_xp_array(library, func): - lib = import_or_skip_cupy(library) + lib = import_(library) is_func = globals()[func] x = lib.asarray([1, 2, 3]) @@ -33,10 +33,7 @@ def test_device(library): if library == "dask.array": pytest.xfail("device() needs to be fixed for dask") - if library == "jax.numpy": - xp = import_or_skip_cupy('jax.experimental.array_api') - else: - xp = import_or_skip_cupy('array_api_compat.' + library) + xp = import_(library, wrapper=True) # We can't test much for device() and to_device() other than that # x.to_device(x.device) works. @@ -54,7 +51,7 @@ def test_to_device_host(library): # for DtoH transfers; ensure that we support a portable # shim for common array libs # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 - xp = import_or_skip_cupy('array_api_compat.' + library) + xp = import_(library, wrapper=True) expected = np.array([1, 2, 3]) x = xp.asarray([1, 2, 3]) diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index a7eb8c10..f4c245f4 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -5,7 +5,7 @@ import pytest -from ._helpers import import_or_skip_cupy +from ._helpers import import_ # Check the known dtypes by their string names @@ -66,10 +66,7 @@ def isdtype_(dtype_, kind): @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) def test_isdtype_spec_dtypes(library): - if library == "jax.numpy": - xp = import_or_skip_cupy('jax.experimental.array_api') - else: - xp = import_or_skip_cupy('array_api_compat.' + library) + xp = import_(library, wrapper=True) isdtype = xp.isdtype @@ -104,10 +101,7 @@ def test_isdtype_spec_dtypes(library): @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) @pytest.mark.parametrize("dtype_", additional_dtypes) def test_isdtype_additional_dtypes(library, dtype_): - if library == "jax.numpy": - xp = import_or_skip_cupy('jax.experimental.array_api') - else: - xp = import_or_skip_cupy('array_api_compat.' + library) + xp = import_(library, wrapper=True) isdtype = xp.isdtype From 6c338cabc082b32b2513b0d2b12691fe1af61801 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 8 Feb 2024 13:26:46 -0700 Subject: [PATCH 15/19] Skip JAX tests in Python 3.8 --- tests/_helpers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/_helpers.py b/tests/_helpers.py index a070a53f..e05ae86c 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,11 +1,15 @@ from importlib import import_module +import sys + import pytest def import_(library, wrapper=False): if library == 'cupy': return pytest.importorskip(library) + if 'jax' in library and sys.version_info <= (3, 8): + pytest.skip('JAX array API support does not support Python 3.8') if wrapper: if 'jax' in library: From 244462fd968a78b8ccf9e93bbde1bbe27566f0ce Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 8 Feb 2024 13:27:06 -0700 Subject: [PATCH 16/19] Rename test_helpers.py to test_common.py test_helpers.py is too confusing alongside tests/_helpers.py. --- tests/{test_helpers.py => test_common.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_helpers.py => test_common.py} (100%) diff --git a/tests/test_helpers.py b/tests/test_common.py similarity index 100% rename from tests/test_helpers.py rename to tests/test_common.py From bff9bf295a1865c613ccb3bfd6f63157855e3ea9 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 8 Feb 2024 13:29:19 -0700 Subject: [PATCH 17/19] Fix ruff warnings --- array_api_compat/common/_helpers.py | 4 ++-- tests/test_common.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 5e1671a2..5e59c7ea 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from typing import Optional, Union, Any - from ._typing import Array, Device + from ._typing import Array, Device import sys import math @@ -277,7 +277,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] raise ValueError(f"Unsupported device {device!r}") elif is_jax_array(x): # This import adds to_device to x - import jax.experimental.array_api + import jax.experimental.array_api # noqa: F401 return x.to_device(device, stream=stream) return x.to_device(device, stream=stream) diff --git a/tests/test_common.py b/tests/test_common.py index 730c0903..b84dfdde 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,6 +1,7 @@ -from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, - is_dask_array, is_jax_array, is_array_api_obj, - device, to_device) +from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401 + is_dask_array, is_jax_array) + +from array_api_compat import is_array_api_obj, device, to_device from ._helpers import import_ From 264e6c36b23b63aa6a20cfb3041caae638a77a66 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 8 Feb 2024 13:44:44 -0700 Subject: [PATCH 18/19] Add jax.numpy to the test_array_namespace() This requires using subprocess to test that it works even if the side-effecting jax.experimental.array_api hasn't been imported yet. --- tests/test_array_namespace.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 9fbf5656..21fc31bb 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -1,3 +1,6 @@ +import subprocess +import sys + import numpy as np import pytest import torch @@ -7,8 +10,7 @@ from ._helpers import import_ - -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) @pytest.mark.parametrize("api_version", [None, "2021.12"]) def test_array_namespace(library, api_version): xp = import_(library) @@ -21,9 +23,31 @@ def test_array_namespace(library, api_version): else: if library == "dask.array": assert namespace == array_api_compat.dask.array + elif library == "jax.numpy": + import jax.experimental.array_api + assert namespace == jax.experimental.array_api else: assert namespace == getattr(array_api_compat, library) + # Check that array_namespace works even if jax.experimental.array_api + # hasn't been imported yet (it monkeypatches __array_namespace__ + # onto JAX arrays, but we should support them regardless). The only way to + # do this is to use a subprocess, since we cannot un-import it and another + # test probably already imported it. + if library == "jax.numpy": + code = f"""\ +import sys +import jax.numpy +import array_api_compat +array = jax.numpy.asarray([1.0, 2.0, 3.0]) + +assert 'jax.experimental.array_api' not in sys.modules +namespace = array_api_compat.array_namespace(array, api_version={api_version!r}) + +import jax.experimental.array_api +assert namespace == jax.experimental.array_api +""" + subprocess.run([sys.executable, "-c", code], check=True) def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) From e7aff0f48854d22c8921be99047a7c398220bc36 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 8 Feb 2024 13:45:44 -0700 Subject: [PATCH 19/19] Don't run the jax.numpy array_namespace test in Python 3.8 --- tests/test_array_namespace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 21fc31bb..7aaef971 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -34,7 +34,7 @@ def test_array_namespace(library, api_version): # onto JAX arrays, but we should support them regardless). The only way to # do this is to use a subprocess, since we cannot un-import it and another # test probably already imported it. - if library == "jax.numpy": + if library == "jax.numpy" and sys.version_info >= (3, 9): code = f"""\ import sys import jax.numpy