Skip to content

Commit

Permalink
Mark tests with require_numpy1_on_windows
Browse files Browse the repository at this point in the history
  • Loading branch information
albertvillanova committed Jul 12, 2024
1 parent c7fb3d9 commit f48e089
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 3 deletions.
3 changes: 2 additions & 1 deletion tests/features/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from datasets.info import DatasetInfo
from datasets.utils.py_utils import asdict

from ..utils import require_jax, require_tf, require_torch
from ..utils import require_jax, require_numpy1_on_windows, require_tf, require_torch


class FeaturesTest(TestCase):
Expand Down Expand Up @@ -543,6 +543,7 @@ def test_cast_to_python_objects_pandas_timedelta(self):
casted_obj = cast_to_python_objects(pd.DataFrame({"a": [obj]}))
self.assertDictEqual(casted_obj, {"a": [expected_obj]})

@require_numpy1_on_windows
@require_torch
def test_cast_to_python_objects_torch(self):
import torch
Expand Down
3 changes: 2 additions & 1 deletion tests/packaged_modules/test_webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datasets import Audio, DownloadManager, Features, Image, Sequence, Value
from datasets.packaged_modules.webdataset.webdataset import WebDataset

from ..utils import require_librosa, require_pil, require_sndfile, require_torch
from ..utils import require_librosa, require_numpy1_on_windows, require_pil, require_sndfile, require_torch


@pytest.fixture
Expand Down Expand Up @@ -226,6 +226,7 @@ def test_webdataset_with_features(image_wds_file):
assert isinstance(decoded["jpg"], PIL.Image.Image)


@require_numpy1_on_windows
@require_torch
def test_tensor_webdataset(tensor_wds_file):
import torch
Expand Down
4 changes: 4 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
require_dill_gt_0_3_2,
require_jax,
require_not_windows,
require_numpy1_on_windows,
require_pil,
require_polars,
require_pyspark,
Expand Down Expand Up @@ -420,6 +421,7 @@ def test_set_format_numpy_multiple_columns(self, in_memory):
self.assertIsInstance(dset[0]["col_2"], np.str_)
self.assertEqual(dset[0]["col_2"].item(), "a")

@require_numpy_on_windows
@require_torch
def test_set_format_torch(self, in_memory):
import torch
Expand Down Expand Up @@ -1525,6 +1527,7 @@ def func_return_multi_row_pd_dataframe(x):
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
self.assertRaises(ValueError, dset.map, func_return_multi_row_pd_dataframe)

@require_numpy1_on_windows
@require_torch
def test_map_torch(self, in_memory):
import torch
Expand Down Expand Up @@ -1590,6 +1593,7 @@ def func(example):
)
self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])

@require_numpy1_on_windows
@require_torch
def test_map_tensor_batched(self, in_memory):
import torch
Expand Down
2 changes: 2 additions & 0 deletions tests/test_dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .utils import (
assert_arrow_memory_doesnt_increase,
assert_arrow_memory_increases,
require_numpy1_on_windows,
require_polars,
require_tf,
require_torch,
Expand Down Expand Up @@ -109,6 +110,7 @@ def test_set_format_numpy(self):
self.assertEqual(dset_split[0]["col_2"].item(), "a")
del dset

@require_numpy1_on_windows
@require_torch
def test_set_format_torch(self):
import torch
Expand Down
3 changes: 3 additions & 0 deletions tests/test_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from .utils import (
require_not_windows,
require_numpy1_on_windows,
require_regex,
require_spacy,
require_tiktoken,
Expand Down Expand Up @@ -303,6 +304,7 @@ def test_hash_tiktoken_encoding(self):
self.assertEqual(hash1, hash3)
self.assertNotEqual(hash1, hash2)

@require_numpy1_on_windows
@require_torch
def test_hash_torch_tensor(self):
import torch
Expand All @@ -316,6 +318,7 @@ def test_hash_torch_tensor(self):
self.assertEqual(hash1, hash3)
self.assertNotEqual(hash1, hash2)

@require_numpy1_on_windows
@require_torch
def test_hash_torch_generator(self):
import torch
Expand Down
5 changes: 5 additions & 0 deletions tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .utils import (
require_jax,
require_librosa,
require_numpy1_on_windows,
require_pil,
require_polars,
require_sndfile,
Expand Down Expand Up @@ -353,6 +354,7 @@ def test_polars_formatter(self):
assert pl.Series.eq(batch["a"], pl.Series("a", _COL_A)).all()
assert pl.Series.eq(batch["b"], pl.Series("b", _COL_B)).all()

@require_numpy1_on_windows
@require_torch
def test_torch_formatter(self):
import torch
Expand All @@ -373,6 +375,7 @@ def test_torch_formatter(self):
torch.testing.assert_close(batch["c"], torch.tensor(_COL_C, dtype=torch.float32))
assert batch["c"].shape == np.array(_COL_C).shape

@require_numpy1_on_windows
@require_torch
def test_torch_formatter_torch_tensor_kwargs(self):
import torch
Expand All @@ -389,6 +392,7 @@ def test_torch_formatter_torch_tensor_kwargs(self):
self.assertEqual(batch["a"].dtype, torch.float16)
self.assertEqual(batch["c"].dtype, torch.float16)

@require_numpy1_on_windows
@require_torch
@require_pil
def test_torch_formatter_image(self):
Expand Down Expand Up @@ -975,6 +979,7 @@ def test_tf_formatter_sets_default_dtypes(cast_schema, arrow_table):
tf.debugging.assert_equal(batch["col_float"], tf.ragged.constant(list_float, dtype=tf.float32))


@require_numpy1_on_windows
@require_torch
@pytest.mark.parametrize(
"cast_schema",
Expand Down
2 changes: 2 additions & 0 deletions tests/test_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
is_rng_equal,
require_dill_gt_0_3_2,
require_not_windows,
require_numpy1_on_windows,
require_pyspark,
require_tf,
require_torch,
Expand Down Expand Up @@ -1279,6 +1280,7 @@ def gen(shard_names):
assert dataset.n_shards == len(shard_names)


@require_numpy1_on_windows
def test_iterable_dataset_from_file(dataset: IterableDataset, arrow_file: str):
with assert_arrow_memory_doesnt_increase():
dataset_from_file = IterableDataset.from_file(arrow_file)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
zip_dict,
)

from .utils import require_tf, require_torch
from .utils import require_numpy1_on_windows, require_tf, require_torch


def np_sum(x): # picklable for multiprocessing
Expand Down Expand Up @@ -151,6 +151,7 @@ def gen_random_output():
np.testing.assert_equal(out1, out2)
self.assertGreater(np.abs(out1 - out3).sum(), 0)

@require_numpy1_on_windows
@require_torch
def test_torch(self):
import torch
Expand Down

0 comments on commit f48e089

Please sign in to comment.