Skip to content

Commit

Permalink
Check JAX version and invoke __dlpack__ manually for jax pre-0.4.16.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient committed Nov 6, 2024
1 parent 1119827 commit 0beb511
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion dali/python/nvidia/dali/plugin/jax/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@
from nvidia.dali.backend import TensorGPU


_jax_version_pre_0_4_16 = None


def _jax_has_old_dlpack():
global _jax_version_pre_0_4_16
if _jax_version_pre_0_4_16 is not None:
return _jax_version_pre_0_4_16

from packaging.version import Version

_jax_version_pre_0_4_16 = Version(jax.__version__) < Version("0.4.16")
return _jax_version_pre_0_4_16


def _to_jax_array(dali_tensor: TensorGPU) -> jax.Array:
"""Converts input DALI tensor to JAX array.
Expand All @@ -35,7 +49,10 @@ def _to_jax_array(dali_tensor: TensorGPU) -> jax.Array:
jax.Array: JAX array with the same values and backing device as
input DALI tensor.
"""
jax_array = jax.dlpack.from_dlpack(dali_tensor)
if _jax_has_old_dlpack():
jax_array = jax.dlpack.from_dlpack(dali_tensor.__dlpack__(stream=None))
else:
jax_array = jax.dlpack.from_dlpack(dali_tensor)

# For now we need this copy to make sure that underlying memory is available.
# One solution is to implement full DLPack contract in DALI.
Expand Down

0 comments on commit 0beb511

Please sign in to comment.