From 9052149dbdd5ab43eadf82964a5091f6f2321a00 Mon Sep 17 00:00:00 2001 From: Michal Zientkiewicz Date: Wed, 23 Oct 2024 15:36:54 +0200 Subject: [PATCH] Revert JAX test change. Signed-off-by: Michal Zientkiewicz --- dali/test/python/jax_plugin/test_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dali/test/python/jax_plugin/test_integration.py b/dali/test/python/jax_plugin/test_integration.py index dd1fe3de98d..8eff162ebb4 100644 --- a/dali/test/python/jax_plugin/test_integration.py +++ b/dali/test/python/jax_plugin/test_integration.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ def test_dali_tensor_gpu_to_jax_array(dtype, shape, value): assert jax.numpy.array_equal(jax_array, jax.numpy.full(shape, value, dtype)) # Make sure JAX array is backed by the GPU - assert jax_array.device == jax.devices()[0] + assert jax_array.device() == jax.devices()[0] def test_dali_sequential_tensors_to_jax_array():