Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic JAX support #84

Merged
merged 20 commits into from
Feb 9, 2024
Merged

Add basic JAX support #84

merged 20 commits into from
Feb 9, 2024

Conversation

asmeurer
Copy link
Member

@asmeurer asmeurer commented Feb 6, 2024

Unlike other modules, JAX array API support is fully in JAX itself in the jax.experimental.array_api submodule, so the only thing that is done here is to add JAX support to the helper functions. This also means that we do not run array-api-tests on JAX.

This also makes the various is_numpy_array, is_cupy_array, etc. functions public, as I noticed someone was using them on GitHub and they seem like they could be useful.

Unlike other libraries, there is no wrapping for JAX. Actual JAX array_api
support is in JAX itself in the jax.experimental.array_api submodule. This
just adds JAX support to the various helper functions. This also means that we
do not run array-api-tests on JAX.

Closes data-apis#83.
@asmeurer asmeurer mentioned this pull request Feb 6, 2024
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jakevdp, I could mostly use your review for the changes in this file.

# jax.numpy is already an array namespace, but requires this
# side-effecting import for __array_namespace__ and some other
# things to be defined.
import jax.experimental.array_api as jnp
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this import go away at some point? Should we guard against that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not true that jax.numpy is already an array namespace. jax.experimental.array_api is an array namespace, and we hope to make jax.numpy an array namespace in the future.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I fixed the comment. The question still remains though? Should I add a guard here like

try:
    import jax.experimental.array_api as jnp
except ImportError:
    import jax.numpy as jnp

for a future JAX version when jax.experimental.array_api gets removed? Or will it never be removed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think that's probably a reasonable way to future-proof this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The downside of that particular line of code is it will also pass through jax.numpy for older JAX versions that don't have jax.experimental.array_api. I like to avoid explicit version checks in this library if I can, but maybe that's the best thing to do here.

Or maybe we can just change the logic to this once JAX starts to remove (deprecates?) the experimental import.

# JAX has .device() as a method, but it is being deprecated so that it
# can become a property, in accordance with the standard. In order for
# this function to not break when JAX makes the flip, we check for
# both here.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this logic seem OK?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks OK, but it will not work with the to_device function as currently defined in jax.experimental.array_api. That expects to be passed the bound method object – it's a hack, and one of the reasons that this is still considered experimental.

Copy link
Contributor

@jakevdp jakevdp Feb 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other problem here is that, in general, JAX arrays can live on multiple devices (in which case arr.device() will error). It's not clear to me from the Array API design how multi-device arrays are meant to be supported.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know the answer to that. I would bring it up on the array API repo. https://github.com/data-apis/array-api/. As far as I know it hasn't really been discussed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like both forms work

>>> import jax.experimental.array_api as xp
>>> x = xp.asarray([1, 2, 3])
>>> x.to_device(x.device)
Array([1, 2, 3], dtype=int32)
>>> x.to_device(x.device())
Array([1, 2, 3], dtype=int32)

import jax
if stream is not None:
raise NotImplementedError
return jax.device_put(x, device)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this helper function for to_device correct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, device_put would be the right method here.


import jax.numpy

return isinstance(x, jax.numpy.ndarray)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isinstance(x, jax.Array) would be more concise.

@asmeurer
Copy link
Member Author

asmeurer commented Feb 6, 2024

Oh we should also probably add a basic test for the device helpers at https://github.com/data-apis/array-api-compat/blob/main/tests/test_common.py

@asmeurer
Copy link
Member Author

asmeurer commented Feb 6, 2024

Apparently we've been supporting "cpu" as a special host device here (#40). This is still being discussed for the standard data-apis/array-api#626.

I'm actually unsure if it's a good idea for us to be supporting that here given that we haven't really agreed about it in the standard. But if we did want to support it for JAX, how would we? I don't see how to actually access jax.CpuDevice.

@jakevdp
Copy link
Contributor

jakevdp commented Feb 6, 2024

jax.devices('cpu')[0] would return the first CPU device (if available).

@asmeurer asmeurer mentioned this pull request Feb 6, 2024
Copy link
Contributor

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks good!


return isinstance(x, jax.numpy.ndarray)
return isinstance(x, jax.Array)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be more guarded? e.g. what if someone has a module named jax.py in their path, it could lead to an AttributeError here and make array_api_compat unusable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know. That sort of thing tends to just break everything anyway. I've never really felt that libraries should protect against that sort of thing.

Anyway, the whole point of this function is to be guarded. It won't import jax unless it's already been imported, because jax is a slow import and unnecessary if a user is using another array library.

Copy link
Contributor

@jakevdp jakevdp Feb 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More precisely, it won't import jax unless something named jax has been imported, and then array_api_compat will irrecoverably error if that jax module isn't the one we're expecting, or is an older version that doesn't define the attributes we reference.

A similar issue exists for every other package name referenced in this module.

My feeling is: it costs virtually nothing to wrap this all in an appropriate try ... except, so why not do so?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still not so sure this is a good idea. Usually if you have that sort of thing it will be an error for a lot of things, not just array_api_compat. My worry here is that guarding isn't as straightforward as it might seem. Wrapping everything in try/except could mean we end up silencing legitimate errors.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, sounds good.

@asmeurer
Copy link
Member Author

asmeurer commented Feb 8, 2024

Removed the "cpu" device logic. It looks like that isn't going to be part of the standard. I opened #86 about removing it for cupy as well.

This requires using subprocess to test that it works even if the
side-effecting jax.experimental.array_api hasn't been imported yet.
@asmeurer asmeurer merged commit 645f9a8 into data-apis:main Feb 9, 2024
9 of 27 checks passed
@adityagoel4512 adityagoel4512 mentioned this pull request Jun 17, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants