Skip to content

Commit

Permalink
Revert JAX test change.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient committed Oct 23, 2024
1 parent 1e4ef0a commit 9052149
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions dali/test/python/jax_plugin/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_dali_tensor_gpu_to_jax_array(dtype, shape, value):
assert jax.numpy.array_equal(jax_array, jax.numpy.full(shape, value, dtype))

# Make sure JAX array is backed by the GPU
assert jax_array.device == jax.devices()[0]
assert jax_array.device() == jax.devices()[0]


def test_dali_sequential_tensors_to_jax_array():
Expand Down

0 comments on commit 9052149

Please sign in to comment.