From f48e089bc3cfdf6f42fcdc71d3ab133fff4528e4 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 12 Jul 2024 10:44:19 +0200 Subject: [PATCH] Mark tests with require_numpy1_on_windows --- tests/features/test_features.py | 3 ++- tests/packaged_modules/test_webdataset.py | 3 ++- tests/test_arrow_dataset.py | 4 ++++ tests/test_dataset_dict.py | 2 ++ tests/test_fingerprint.py | 3 +++ tests/test_formatting.py | 5 +++++ tests/test_iterable_dataset.py | 2 ++ tests/test_py_utils.py | 3 ++- 8 files changed, 22 insertions(+), 3 deletions(-) diff --git a/tests/features/test_features.py b/tests/features/test_features.py index 381e98b8c14..913d8edeebe 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -25,7 +25,7 @@ from datasets.info import DatasetInfo from datasets.utils.py_utils import asdict -from ..utils import require_jax, require_tf, require_torch +from ..utils import require_jax, require_numpy1_on_windows, require_tf, require_torch class FeaturesTest(TestCase): @@ -543,6 +543,7 @@ def test_cast_to_python_objects_pandas_timedelta(self): casted_obj = cast_to_python_objects(pd.DataFrame({"a": [obj]})) self.assertDictEqual(casted_obj, {"a": [expected_obj]}) + @require_numpy1_on_windows @require_torch def test_cast_to_python_objects_torch(self): import torch diff --git a/tests/packaged_modules/test_webdataset.py b/tests/packaged_modules/test_webdataset.py index 6cdd53b6cbf..128f13022fc 100644 --- a/tests/packaged_modules/test_webdataset.py +++ b/tests/packaged_modules/test_webdataset.py @@ -7,7 +7,7 @@ from datasets import Audio, DownloadManager, Features, Image, Sequence, Value from datasets.packaged_modules.webdataset.webdataset import WebDataset -from ..utils import require_librosa, require_pil, require_sndfile, require_torch +from ..utils import require_librosa, require_numpy1_on_windows, require_pil, require_sndfile, require_torch @pytest.fixture @@ -226,6 +226,7 @@ def test_webdataset_with_features(image_wds_file): assert isinstance(decoded["jpg"], PIL.Image.Image) +@require_numpy1_on_windows @require_torch def test_tensor_webdataset(tensor_wds_file): import torch diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index efa7b7ae4c8..3567c54d71f 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -57,6 +57,7 @@ require_dill_gt_0_3_2, require_jax, require_not_windows, + require_numpy1_on_windows, require_pil, require_polars, require_pyspark, @@ -420,6 +421,7 @@ def test_set_format_numpy_multiple_columns(self, in_memory): self.assertIsInstance(dset[0]["col_2"], np.str_) self.assertEqual(dset[0]["col_2"].item(), "a") + @require_numpy_on_windows @require_torch def test_set_format_torch(self, in_memory): import torch @@ -1525,6 +1527,7 @@ def func_return_multi_row_pd_dataframe(x): with self._create_dummy_dataset(in_memory, tmp_dir) as dset: self.assertRaises(ValueError, dset.map, func_return_multi_row_pd_dataframe) + @require_numpy1_on_windows @require_torch def test_map_torch(self, in_memory): import torch @@ -1590,6 +1593,7 @@ def func(example): ) self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3]) + @require_numpy1_on_windows @require_torch def test_map_tensor_batched(self, in_memory): import torch diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py index e6e801087e2..75419c55778 100644 --- a/tests/test_dataset_dict.py +++ b/tests/test_dataset_dict.py @@ -16,6 +16,7 @@ from .utils import ( assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, + require_numpy1_on_windows, require_polars, require_tf, require_torch, @@ -109,6 +110,7 @@ def test_set_format_numpy(self): self.assertEqual(dset_split[0]["col_2"].item(), "a") del dset + @require_numpy1_on_windows @require_torch def test_set_format_torch(self): import torch diff --git a/tests/test_fingerprint.py b/tests/test_fingerprint.py index 5b22e467f1f..0b7a45458bd 100644 --- a/tests/test_fingerprint.py +++ b/tests/test_fingerprint.py @@ -21,6 +21,7 @@ from .utils import ( require_not_windows, + require_numpy1_on_windows, require_regex, require_spacy, require_tiktoken, @@ -303,6 +304,7 @@ def test_hash_tiktoken_encoding(self): self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) + @require_numpy1_on_windows @require_torch def test_hash_torch_tensor(self): import torch @@ -316,6 +318,7 @@ def test_hash_torch_tensor(self): self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) + @require_numpy1_on_windows @require_torch def test_hash_torch_generator(self): import torch diff --git a/tests/test_formatting.py b/tests/test_formatting.py index e190356f216..147822fa8a1 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -21,6 +21,7 @@ from .utils import ( require_jax, require_librosa, + require_numpy1_on_windows, require_pil, require_polars, require_sndfile, @@ -353,6 +354,7 @@ def test_polars_formatter(self): assert pl.Series.eq(batch["a"], pl.Series("a", _COL_A)).all() assert pl.Series.eq(batch["b"], pl.Series("b", _COL_B)).all() + @require_numpy1_on_windows @require_torch def test_torch_formatter(self): import torch @@ -373,6 +375,7 @@ def test_torch_formatter(self): torch.testing.assert_close(batch["c"], torch.tensor(_COL_C, dtype=torch.float32)) assert batch["c"].shape == np.array(_COL_C).shape + @require_numpy1_on_windows @require_torch def test_torch_formatter_torch_tensor_kwargs(self): import torch @@ -389,6 +392,7 @@ def test_torch_formatter_torch_tensor_kwargs(self): self.assertEqual(batch["a"].dtype, torch.float16) self.assertEqual(batch["c"].dtype, torch.float16) + @require_numpy1_on_windows @require_torch @require_pil def test_torch_formatter_image(self): @@ -975,6 +979,7 @@ def test_tf_formatter_sets_default_dtypes(cast_schema, arrow_table): tf.debugging.assert_equal(batch["col_float"], tf.ragged.constant(list_float, dtype=tf.float32)) +@require_numpy1_on_windows @require_torch @pytest.mark.parametrize( "cast_schema", diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 8069c748396..6d21eda3863 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -51,6 +51,7 @@ is_rng_equal, require_dill_gt_0_3_2, require_not_windows, + require_numpy1_on_windows, require_pyspark, require_tf, require_torch, @@ -1279,6 +1280,7 @@ def gen(shard_names): assert dataset.n_shards == len(shard_names) +@require_numpy1_on_windows def test_iterable_dataset_from_file(dataset: IterableDataset, arrow_file: str): with assert_arrow_memory_doesnt_increase(): dataset_from_file = IterableDataset.from_file(arrow_file) diff --git a/tests/test_py_utils.py b/tests/test_py_utils.py index 1b618987392..b768ad54ecd 100644 --- a/tests/test_py_utils.py +++ b/tests/test_py_utils.py @@ -18,7 +18,7 @@ zip_dict, ) -from .utils import require_tf, require_torch +from .utils import require_numpy1_on_windows, require_tf, require_torch def np_sum(x): # picklable for multiprocessing @@ -151,6 +151,7 @@ def gen_random_output(): np.testing.assert_equal(out1, out2) self.assertGreater(np.abs(out1 - out3).sum(), 0) + @require_numpy1_on_windows @require_torch def test_torch(self): import torch