diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index fd5e6f1ec51..3d0b3ce1cf3 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1431,10 +1431,18 @@ def n_shards(self) -> int: class SkipExamplesIterable(_BaseExamplesIterable): - def __init__(self, ex_iterable: _BaseExamplesIterable, n: int): + def __init__( + self, + ex_iterable: _BaseExamplesIterable, + n: int, + block_sources_order_when_shuffling: bool = True, + split_when_sharding: bool = True, + ): super().__init__() self.ex_iterable = ex_iterable self.n = n + self.block_sources_order_when_shuffling = block_sources_order_when_shuffling + self.split_when_sharding = split_when_sharding # TODO(QL): implement iter_arrow def _init_state_dict(self) -> dict: @@ -1447,9 +1455,38 @@ def __iter__(self): self._state_dict["skipped"] = True yield from islice(self.ex_iterable, ex_iterable_idx_start, None) + @staticmethod + def split_number(num, n): + quotient = num // n + remainder = num % n + result = [quotient] * n + for i in range(remainder): + result[i] += 1 + return result + def shuffle_data_sources(self, generator: np.random.Generator) -> "SkipExamplesIterable": - """Doesn't shuffle the wrapped examples iterable since it would skip examples from other shards instead.""" - return self + """May not shuffle the wrapped examples iterable since it would skip examples from other shards instead.""" + if self.block_sources_order_when_shuffling: + return self + else: + return SkipExamplesIterable( + self.ex_iterable.shuffle_data_sources(generator), + n=self.n, + block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, + split_when_sharding=self.split_when_sharding, + ) + + def shard_data_sources(self, worker_id: int, num_workers: int) -> "SkipExamplesIterable": + """Keep only the requested shard.""" + if self.split_when_sharding: + return SkipExamplesIterable( + self.ex_iterable.shard_data_sources(worker_id, num_workers), + n=self.split_number(self.n, num_workers)[worker_id], + block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, + split_when_sharding=self.split_when_sharding, + ) + else: + return self @property def n_shards(self) -> int: @@ -1457,10 +1494,18 @@ def n_shards(self) -> int: class TakeExamplesIterable(_BaseExamplesIterable): - def __init__(self, ex_iterable: _BaseExamplesIterable, n: int): + def __init__( + self, + ex_iterable: _BaseExamplesIterable, + n: int, + block_sources_order_when_shuffling: bool = True, + split_when_sharding: bool = True, + ): super().__init__() self.ex_iterable = ex_iterable self.n = n + self.block_sources_order_when_shuffling = block_sources_order_when_shuffling + self.split_when_sharding = split_when_sharding # TODO(QL): implement iter_arrow def _init_state_dict(self) -> dict: @@ -1474,10 +1519,6 @@ def __iter__(self): self._state_dict["num_taken"] += 1 yield key_example - def shuffle_data_sources(self, generator: np.random.Generator) -> "TakeExamplesIterable": - """Doesn't shuffle the wrapped examples iterable since it would take examples from other shards instead.""" - return self - @staticmethod def split_number(num, n): quotient = num // n @@ -1487,12 +1528,34 @@ def split_number(num, n): result[i] += 1 return result + def shuffle_data_sources(self, generator: np.random.Generator) -> "TakeExamplesIterable": + """May not shuffle the wrapped examples iterable since it would take examples from other shards instead.""" + if self.block_sources_order_when_shuffling: + return self + else: + return TakeExamplesIterable( + self.ex_iterable.shuffle_data_sources(generator), + n=self.n, + block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, + split_when_sharding=self.split_when_sharding, + ) + def shard_data_sources(self, worker_id: int, num_workers: int) -> "TakeExamplesIterable": """Keep only the requested shard.""" - return TakeExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), - n=self.split_number(self.n, num_workers)[worker_id], - ) + if self.split_when_sharding: + return TakeExamplesIterable( + self.ex_iterable.shard_data_sources(worker_id, num_workers), + n=self.split_number(self.n, num_workers)[worker_id], + block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, + split_when_sharding=self.split_when_sharding, + ) + else: + return TakeExamplesIterable( + self.ex_iterable.shard_data_sources(worker_id, num_workers), + n=self.n, + block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, + split_when_sharding=self.split_when_sharding, + ) @property def n_shards(self) -> int: @@ -2392,7 +2455,7 @@ def shuffle( return IterableDataset( ex_iterable=BufferShuffledExamplesIterable( self._ex_iterable, buffer_size=buffer_size, generator=generator - ).shuffle_data_sources(generator), + ), info=self._info.copy(), split=self._split, formatting=self._formatting, @@ -2432,7 +2495,12 @@ def skip(self, n: int) -> "IterableDataset": 'text': 'if you sometimes like to go to the movies to have fun , wasabi is a good place to start .'}] ``` """ - ex_iterable = SkipExamplesIterable(self._ex_iterable, n) + ex_iterable = SkipExamplesIterable( + self._ex_iterable, + n, + block_sources_order_when_shuffling=self._shuffling is None, + split_when_sharding=self._distributed is None, + ) return IterableDataset( ex_iterable=ex_iterable, info=self._info.copy(), @@ -2464,7 +2532,12 @@ def take(self, n: int) -> "IterableDataset": 'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'}] ``` """ - ex_iterable = TakeExamplesIterable(self._ex_iterable, n) + ex_iterable = TakeExamplesIterable( + self._ex_iterable, + n, + block_sources_order_when_shuffling=self._shuffling is None, + split_when_sharding=self._distributed is None, + ) return IterableDataset( ex_iterable=ex_iterable, info=self._info.copy(), diff --git a/tests/test_inspect.py b/tests/test_inspect.py index 1221a842012..6d711476e94 100644 --- a/tests/test_inspect.py +++ b/tests/test_inspect.py @@ -72,7 +72,7 @@ def test_get_dataset_config_info_error(path, config_name, expected_exception): ("squad", ["plain_text"]), ("hf-internal-testing/dataset_with_script", ["default"]), ("dalle-mini/wit", ["default"]), - ("hf-internal-testing/librispeech_asr_dummy", ["clean", "other"]), + ("hf-internal-testing/librispeech_asr_dummy", ["clean"]), ("hf-internal-testing/audiofolder_no_configs_in_metadata", ["default"]), ("hf-internal-testing/audiofolder_single_config_in_metadata", ["custom"]), ("hf-internal-testing/audiofolder_two_configs_in_metadata", ["v1", "v2"]), @@ -90,7 +90,7 @@ def test_get_dataset_config_names(path, expected): ("squad", "plain_text"), ("hf-internal-testing/dataset_with_script", "default"), ("dalle-mini/wit", "default"), - ("hf-internal-testing/librispeech_asr_dummy", None), + ("hf-internal-testing/librispeech_asr_dummy", "clean"), ("hf-internal-testing/audiofolder_no_configs_in_metadata", "default"), ("hf-internal-testing/audiofolder_single_config_in_metadata", "custom"), ("hf-internal-testing/audiofolder_two_configs_in_metadata", None), diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 4ca143799f7..6d1f84b7d6a 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -10,6 +10,7 @@ from datasets import Dataset, load_dataset from datasets.combine import concatenate_datasets, interleave_datasets +from datasets.distributed import split_dataset_by_node from datasets.features import ( ClassLabel, Features, @@ -1657,17 +1658,60 @@ def test_iterable_dataset_take(dataset: IterableDataset, n): @pytest.mark.parametrize("method", ["skip", "take"]) -def test_iterable_dataset_shuffle_after_skip_or_take(method): +@pytest.mark.parametrize("after_shuffle", [False, True]) +@pytest.mark.parametrize("count", [2, 5, 11]) +def test_iterable_dataset_skip_or_take_after_shuffle(method, after_shuffle, count): seed = 42 n, n_shards = 3, 10 - count = 7 ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(n_shards)]}) dataset = IterableDataset(ex_iterable) - dataset = dataset.skip(n) if method == "skip" else dataset.take(count) - shuffled_dataset = dataset.shuffle(seed, buffer_size=DEFAULT_N_EXAMPLES) - # shuffling a skip/take dataset should keep the same examples and don't shuffle the shards - key = lambda x: f"{x['filepath']}_{x['id']}" # noqa: E731 - assert sorted(dataset, key=key) == sorted(shuffled_dataset, key=key) + shuffled_dataset = dataset + if after_shuffle: + shuffled_dataset = shuffled_dataset.shuffle(seed, buffer_size=DEFAULT_N_EXAMPLES) + shuffled_dataset = shuffled_dataset.skip(count) if method == "skip" else shuffled_dataset.take(count) + # skip/take a shuffled dataset should not keep the same examples and shuffle the shards + key = lambda x: f"{x['filepath']}_{x['id']}" # noqa: E731 + assert (len(list(dataset)) - count if method == "skip" else count) == len(list(shuffled_dataset)) + assert sorted(list(dataset)[count:] if method == "skip" else list(dataset)[:count], key=key) != sorted( + shuffled_dataset, key=key + ) + else: + shuffled_dataset = shuffled_dataset.skip(count) if method == "skip" else shuffled_dataset.take(count) + shuffled_dataset = shuffled_dataset.shuffle(seed, buffer_size=DEFAULT_N_EXAMPLES) + # shuffling a skip/take dataset should keep the same examples and don't shuffle the shards + key = lambda x: f"{x['filepath']}_{x['id']}" # noqa: E731 + assert (len(list(dataset)) - count if method == "skip" else count) == len(list(shuffled_dataset)) + assert sorted(list(dataset)[count:] if method == "skip" else list(dataset)[:count], key=key) == sorted( + shuffled_dataset, key=key + ) + + +@pytest.mark.parametrize("method", ["skip", "take"]) +@pytest.mark.parametrize("after_split_by_node", [False, True]) +@pytest.mark.parametrize("count", [2, 5, 11]) +def test_iterable_dataset_skip_or_take_after_split_by_node(method, after_split_by_node, count): + n, n_shards = 3, 10 + rank, world_size = 1, 2 + ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(n_shards)]}) + dataset = IterableDataset(ex_iterable) + distributed_dataset = dataset + true_distributed_dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) + if after_split_by_node: + distributed_dataset = split_dataset_by_node(distributed_dataset, rank=rank, world_size=world_size) + distributed_dataset = distributed_dataset.skip(count) if method == "skip" else distributed_dataset.take(count) + assert ( + list(true_distributed_dataset)[count:] + if method == "skip" + else list(true_distributed_dataset)[:count] == list(distributed_dataset) + ) + else: + distributed_dataset = distributed_dataset.skip(count) if method == "skip" else distributed_dataset.take(count) + distributed_dataset = split_dataset_by_node(distributed_dataset, rank=rank, world_size=world_size) + assert len( + list(true_distributed_dataset)[count // world_size :] + if method == "skip" + else list(true_distributed_dataset)[: count // world_size] + ) == len(list(distributed_dataset)) def test_iterable_dataset_add_column(dataset_with_several_columns: IterableDataset):