Skip to content

Commit

Permalink
Merge pull request #165 from lucascolley/164
Browse files Browse the repository at this point in the history
BUG: fix `array_namespace` for NumPy scalars
  • Loading branch information
asmeurer authored Jul 29, 2024
2 parents ff87838 + fd3462c commit d57c671
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
4 changes: 2 additions & 2 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ def your_function(x, y):
else:
# numpy 2.0 has __array_namespace__ and is fully array API
# compatible.
if hasattr(x, '__array_namespace__'):
namespaces.add(x.__array_namespace__(api_version=api_version))
if hasattr(np.empty(0), '__array_namespace__'):
namespaces.add(np.empty(0).__array_namespace__(api_version=api_version))
else:
namespaces.add(numpy_namespace)
elif is_cupy_array(x):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ def test_array_namespace(library, api_version, use_compat):
else:
assert namespace == getattr(array_api_compat, library)

if library == "numpy":
# check that the same namespace is returned for NumPy scalars
scalar_namespace = array_api_compat.array_namespace(
xp.float64(0.0), api_version=api_version, use_compat=use_compat
)
assert scalar_namespace == namespace

# Check that array_namespace works even if jax.experimental.array_api
# hasn't been imported yet (it monkeypatches __array_namespace__
# onto JAX arrays, but we should support them regardless). The only way to
Expand Down

0 comments on commit d57c671

Please sign in to comment.