From fd3462cffff9bde099af9d1c6b2fc1629595dc36 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sat, 27 Jul 2024 13:32:37 +0000 Subject: [PATCH] BUG: fix `array_namespace` for NumPy scalars --- array_api_compat/common/_helpers.py | 4 ++-- tests/test_array_namespace.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index b55b16e2..93a50d87 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -342,8 +342,8 @@ def your_function(x, y): else: # numpy 2.0 has __array_namespace__ and is fully array API # compatible. - if hasattr(x, '__array_namespace__'): - namespaces.add(x.__array_namespace__(api_version=api_version)) + if hasattr(np.empty(0), '__array_namespace__'): + namespaces.add(np.empty(0).__array_namespace__(api_version=api_version)) else: namespaces.add(numpy_namespace) elif is_cupy_array(x): diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 1f83a473..8707b05a 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -36,6 +36,13 @@ def test_array_namespace(library, api_version, use_compat): else: assert namespace == getattr(array_api_compat, library) + if library == "numpy": + # check that the same namespace is returned for NumPy scalars + scalar_namespace = array_api_compat.array_namespace( + xp.float64(0.0), api_version=api_version, use_compat=use_compat + ) + assert scalar_namespace == namespace + # Check that array_namespace works even if jax.experimental.array_api # hasn't been imported yet (it monkeypatches __array_namespace__ # onto JAX arrays, but we should support them regardless). The only way to