diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 60848b60..2467793c 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -276,7 +276,7 @@ def is_numpy_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'numpy', _compat_module_name + '.numpy'} + return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} def is_cupy_namespace(xp) -> bool: """ @@ -296,7 +296,7 @@ def is_cupy_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'cupy', _compat_module_name + '.cupy'} + return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} def is_torch_namespace(xp) -> bool: """ @@ -316,7 +316,7 @@ def is_torch_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'torch', _compat_module_name + '.torch'} + return xp.__name__ in {'torch', _compat_module_name() + '.torch'} def is_ndonnx_namespace(xp): @@ -355,7 +355,7 @@ def is_dask_namespace(xp): is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'dask.array', _compat_module_name + '.dask.array'} + return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} def is_jax_namespace(xp): """ diff --git a/tests/test_common.py b/tests/test_common.py index 294a112a..e1cfa9eb 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,5 +1,9 @@ -from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401 - is_dask_array, is_jax_array, is_pydata_sparse_array) +from array_api_compat import ( # noqa: F401 + is_numpy_array, is_cupy_array, is_torch_array, + is_dask_array, is_jax_array, is_pydata_sparse_array, + is_numpy_namespace, is_cupy_namespace, is_torch_namespace, + is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, +) from array_api_compat import is_array_api_obj, device, to_device @@ -10,7 +14,7 @@ import array from numpy.testing import assert_allclose -is_functions = { +is_array_functions = { 'numpy': 'is_numpy_array', 'cupy': 'is_cupy_array', 'torch': 'is_torch_array', @@ -19,18 +23,38 @@ 'sparse': 'is_pydata_sparse_array', } -@pytest.mark.parametrize('library', is_functions.keys()) -@pytest.mark.parametrize('func', is_functions.values()) +is_namespace_functions = { + 'numpy': 'is_numpy_namespace', + 'cupy': 'is_cupy_namespace', + 'torch': 'is_torch_namespace', + 'dask.array': 'is_dask_namespace', + 'jax.numpy': 'is_jax_namespace', + 'sparse': 'is_pydata_sparse_namespace', +} + + +@pytest.mark.parametrize('library', is_array_functions.keys()) +@pytest.mark.parametrize('func', is_array_functions.values()) def test_is_xp_array(library, func): lib = import_(library) is_func = globals()[func] x = lib.asarray([1, 2, 3]) - assert is_func(x) == (func == is_functions[library]) + assert is_func(x) == (func == is_array_functions[library]) assert is_array_api_obj(x) + +@pytest.mark.parametrize('library', is_namespace_functions.keys()) +@pytest.mark.parametrize('func', is_namespace_functions.values()) +def test_is_xp_namespace(library, func): + lib = import_(library) + is_func = globals()[func] + + assert is_func(lib) == (func == is_namespace_functions[library]) + + @pytest.mark.parametrize("library", all_libraries) def test_device(library): xp = import_(library, wrapper=True) @@ -64,8 +88,8 @@ def test_to_device_host(library): assert_allclose(x, expected) -@pytest.mark.parametrize("target_library", is_functions.keys()) -@pytest.mark.parametrize("source_library", is_functions.keys()) +@pytest.mark.parametrize("target_library", is_array_functions.keys()) +@pytest.mark.parametrize("source_library", is_array_functions.keys()) def test_asarray_cross_library(source_library, target_library, request): if source_library == "dask.array" and target_library == "torch": # Allow rest of test to execute instead of immediately xfailing @@ -81,7 +105,7 @@ def test_asarray_cross_library(source_library, target_library, request): pytest.skip(reason="`sparse` does not allow implicit densification") src_lib = import_(source_library, wrapper=True) tgt_lib = import_(target_library, wrapper=True) - is_tgt_type = globals()[is_functions[target_library]] + is_tgt_type = globals()[is_array_functions[target_library]] a = src_lib.asarray([1, 2, 3]) b = tgt_lib.asarray(a) @@ -96,7 +120,7 @@ def test_asarray_copy(library): # should be able to delete this. xp = import_(library, wrapper=True) asarray = xp.asarray - is_lib_func = globals()[is_functions[library]] + is_lib_func = globals()[is_array_functions[library]] all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') : diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 66fc6984..70083b49 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -20,6 +20,7 @@ def test_vendoring_torch(): uses_torch._test_torch() + def test_vendoring_dask(): from vendor_test import uses_dask uses_dask._test_dask() diff --git a/vendor_test/uses_cupy.py b/vendor_test/uses_cupy.py index 97f710b9..e3bbdebe 100644 --- a/vendor_test/uses_cupy.py +++ b/vendor_test/uses_cupy.py @@ -1,6 +1,10 @@ # Basic test that vendoring works -from .vendored._compat import cupy as cp_compat +from .vendored._compat import ( + cupy as cp_compat, + is_cupy_array, + is_cupy_namespace, +) import cupy as cp @@ -16,3 +20,6 @@ def _test_cupy(): assert isinstance(res, cp.ndarray) cp.testing.assert_allclose(res, [1., 2., 9.]) + + assert is_cupy_array(res) + assert is_cupy_namespace(cp) and is_cupy_namespace(cp_compat) diff --git a/vendor_test/uses_dask.py b/vendor_test/uses_dask.py index 65a00916..44fa8f2f 100644 --- a/vendor_test/uses_dask.py +++ b/vendor_test/uses_dask.py @@ -1,6 +1,7 @@ # Basic test that vendoring works from .vendored._compat.dask import array as dask_compat +from .vendored._compat import is_dask_array, is_dask_namespace import dask.array as da import numpy as np @@ -17,3 +18,6 @@ def _test_dask(): assert isinstance(res, da.Array) np.testing.assert_allclose(res, [1., 2., 9.]) + + assert is_dask_array(res) + assert is_dask_namespace(da) and is_dask_namespace(dask_compat) diff --git a/vendor_test/uses_numpy.py b/vendor_test/uses_numpy.py index 96f2c5ff..d7a68248 100644 --- a/vendor_test/uses_numpy.py +++ b/vendor_test/uses_numpy.py @@ -1,6 +1,11 @@ # Basic test that vendoring works -from .vendored._compat import numpy as np_compat +from .vendored._compat import ( + is_numpy_array, + is_numpy_namespace, + numpy as np_compat, +) + import numpy as np @@ -16,3 +21,6 @@ def _test_numpy(): assert isinstance(res, np.ndarray) np.testing.assert_allclose(res, [1., 2., 9.]) + + assert is_numpy_array(res) + assert is_numpy_namespace(np) and is_numpy_namespace(np_compat) diff --git a/vendor_test/uses_torch.py b/vendor_test/uses_torch.py index b828ad33..5804aaff 100644 --- a/vendor_test/uses_torch.py +++ b/vendor_test/uses_torch.py @@ -1,6 +1,10 @@ # Basic test that vendoring works -from .vendored._compat import torch as torch_compat +from .vendored._compat import ( + is_torch_array, + is_torch_namespace, + torch as torch_compat, +) import torch @@ -20,3 +24,7 @@ def _test_torch(): assert isinstance(res, torch.Tensor) torch.testing.assert_allclose(res, [[1., 2., 3.]]) + + assert is_torch_array(res) + assert is_torch_namespace(torch) and is_torch_namespace(torch_compat) +