diff --git a/docs/source/about_mapstyle_vs_iterable.mdx b/docs/source/about_mapstyle_vs_iterable.mdx index 1e9fa279e11..f794eea5714 100644 --- a/docs/source/about_mapstyle_vs_iterable.mdx +++ b/docs/source/about_mapstyle_vs_iterable.mdx @@ -139,12 +139,12 @@ But using a shuffle buffer is not enough to provide a satisfactory shuffling for ```python # Stream from the internet my_iterable_dataset = load_dataset("deepmind/code_contests", split="train", streaming=True) -my_iterable_dataset.n_shards # 39 +my_iterable_dataset.num_shards # 39 # Stream from local files data_files = {"train": [f"path/to/data_{i}.csv" for i in range(1024)]} my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True) -my_iterable_dataset.n_shards # 1024 +my_iterable_dataset.num_shards # 1024 # From a generator function def my_generator(n, sources): @@ -154,7 +154,7 @@ def my_generator(n, sources): gen_kwargs = {"n": 10, "sources": [f"path/to/data_{i}" for i in range(1024)]} my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs=gen_kwargs) -my_iterable_dataset.n_shards # 1024 +my_iterable_dataset.num_shards # 1024 ``` ## Speed differences @@ -242,5 +242,5 @@ my_iterable_dataset = my_dataset.to_iterable_dataset() If you want to shuffle your dataset or [use it with a PyTorch DataLoader](./use_with_pytorch#stream-data), we recommend generating a sharded [`IterableDataset`]: ```python my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024) -my_iterable_dataset.n_shards # 1024 +my_iterable_dataset.num_shards # 1024 ``` diff --git a/docs/source/use_with_pytorch.mdx b/docs/source/use_with_pytorch.mdx index 7f78d8de05c..375d6facc3e 100644 --- a/docs/source/use_with_pytorch.mdx +++ b/docs/source/use_with_pytorch.mdx @@ -216,7 +216,7 @@ If the dataset is split in several shards (i.e. if the dataset consists of multi ```py >>> my_iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train") ->>> my_iterable_dataset.n_shards +>>> my_iterable_dataset.num_shards 39 >>> dataloader = DataLoader(my_iterable_dataset, batch_size=32, num_workers=4) ``` @@ -259,7 +259,7 @@ Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of t For iterable datasets: -If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`), +If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`), then the shards are evenly assigned across the nodes, which is the most optimized. Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples. diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index eb8fedce996..c2296a5ca99 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -4630,32 +4630,31 @@ def shard( self, num_shards: int, index: int, - contiguous: bool = False, + contiguous: bool = True, keep_in_memory: bool = False, indices_cache_file_name: Optional[str] = None, writer_batch_size: Optional[int] = 1000, ) -> "Dataset": """Return the `index`-nth shard from dataset split into `num_shards` pieces. - This shards deterministically. `dset.shard(n, i)` will contain all elements of dset whose - index mod `n = i`. + This shards deterministically. `dataset.shard(n, i)` splits the dataset into contiguous chunks, + so it can be easily concatenated back together after processing. 
If `len(dataset) % n == l`, then the
+        first `l` shards each have length `(len(dataset) // n) + 1`, and the remaining shards have length `(len(dataset) // n)`.
+        `datasets.concatenate([dset.shard(n, i) for i in range(n)])` returns a dataset with the same order as the original.
-        `dset.shard(n, i, contiguous=True)` will instead split dset into contiguous chunks,
-        so it can be easily concatenated back together after processing. If `n % i == l`, then the
-        first `l` shards will have length `(n // i) + 1`, and the remaining shards will have length `(n // i)`.
-        `datasets.concatenate([dset.shard(n, i, contiguous=True) for i in range(n)])` will return
-        a dataset with the same order as the original.
+        Note: n should be less than or equal to the number of elements in the dataset `len(dataset)`.
+
+        On the other hand, `dataset.shard(n, i, contiguous=False)` contains all elements of the dataset whose index mod `n = i`.

         Be sure to shard before using any randomizing operator (such as `shuffle`).
         It is best if the shard operator is used early in the dataset pipeline.
-
         Args:
             num_shards (`int`):
                 How many shards to split the dataset into.
             index (`int`):
                 Which shard to select and return.
-            contiguous: (`bool`, defaults to `False`):
+            contiguous: (`bool`, defaults to `True`):
                 Whether to select contiguous blocks of indices for shards.
             keep_in_memory (`bool`, defaults to `False`):
                 Keep the dataset in memory instead of writing it to a cache file.
@@ -4663,7 +4662,8 @@ def shard(
                 Provide the name of a path for the cache file. It is used to store the
                 indices of each shard instead of the automatically generated cache file name.
             writer_batch_size (`int`, defaults to `1000`):
-                Number of rows per write operation for the cache file writer.
+                This only concerns the indices mapping.
+                Number of indices per write operation for the cache file writer.
                 This value is a good trade-off between memory usage during the processing, and processing speed.
                 Higher value makes the processing do fewer lookups, lower value consume less temporary memory while running `map`.
diff --git a/src/datasets/distributed.py b/src/datasets/distributed.py
index e036fabaf2c..4697948f342 100644
--- a/src/datasets/distributed.py
+++ b/src/datasets/distributed.py
@@ -18,7 +18,7 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> D
     For iterable datasets:

-    If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+    If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
     then the shards are evenly assigned across the nodes, which is the most optimized.
     Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
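As a quick illustration of the contiguous sharding behaviour documented above, here is a minimal sketch (not part of the patch itself); it only uses the public `Dataset.from_dict`, `Dataset.shard`, and `concatenate_datasets` APIs, and passes `contiguous` explicitly so it does not depend on the changed default:

```python
# Sketch: contiguous shards preserve order, so they concatenate back to the original.
from datasets import Dataset, concatenate_datasets

ds = Dataset.from_dict({"id": list(range(10))})  # 10 rows

# len(ds) % 3 == 1, so the first shard gets 4 rows and the other two get 3 rows each.
shards = [ds.shard(num_shards=3, index=i, contiguous=True) for i in range(3)]
print([len(s) for s in shards])  # [4, 3, 3]

# Concatenating the contiguous shards restores the original row order.
assert concatenate_datasets(shards)["id"] == ds["id"]

# With contiguous=False, shard i instead keeps the rows whose index mod num_shards == i.
print(ds.shard(num_shards=3, index=0, contiguous=False)["id"])  # [0, 3, 6, 9]
```

The same contiguous versus mod-`n` distinction is what the `IterableDataset.shard` method and the worker/node shard assignment introduced below rely on.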
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index d1b54131b61..7381a2ae7d3 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -141,16 +141,23 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamples """ raise NotImplementedError(f"{type(self)} doesn't implement shuffle_data_sources yet") - def shard_data_sources(self, worker_id: int, num_workers: int) -> "_BaseExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "_BaseExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" raise NotImplementedError(f"{type(self)} doesn't implement shard_data_sources yet") - def split_shard_indices_by_worker(self, worker_id: int, num_workers: int) -> List[int]: - return list(range(worker_id, self.n_shards, num_workers)) + def split_shard_indices_by_worker(self, num_shards: int, index: int, contiguous=True) -> List[int]: + if contiguous: + div = self.num_shards // num_shards + mod = self.num_shards % num_shards + start = div * index + min(index, mod) + end = start + div + (1 if index < mod else 0) + return list(range(start, end)) + else: + return list(range(index, self.num_shards, num_shards)) @property - def n_shards(self) -> int: - raise NotImplementedError(f"{type(self)} doesn't implement n_shards yet") + def num_shards(self) -> int: + raise NotImplementedError(f"{type(self)} doesn't implement num_shards yet") def _init_state_dict(self) -> dict: raise NotImplementedError(f"{type(self)} doesn't implement _init_state_dict yet") @@ -187,7 +194,7 @@ def _init_state_dict(self) -> dict: def __iter__(self): shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 - for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards), shard_idx_start, None): + for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 for key_example in islice(self.generate_examples_fn(**gen_kwags), shard_example_idx_start, None): if self._state_dict: @@ -200,15 +207,15 @@ def __iter__(self): def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesIterable": return ShuffledDataSourcesExamplesIterable(self.generate_examples_fn, self.kwargs, generator) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "ExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable": """Keep only the requested shard.""" - gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards) - shard_indices = self.split_shard_indices_by_worker(worker_id, num_workers) + gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) + shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) return ExamplesIterable(self.generate_examples_fn, requested_gen_kwargs) @property - def n_shards(self) -> int: + def num_shards(self) -> int: return _number_of_shards_in_gen_kwargs(self.kwargs) @@ -229,7 +236,7 @@ def __iter__(self): kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 for gen_kwags in islice( - _split_gen_kwargs(kwargs_with_shuffled_shards, 
max_num_jobs=self.n_shards), shard_idx_start, None + _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None ): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 for key_example in islice(self.generate_examples_fn(**gen_kwags), shard_example_idx_start, None): @@ -240,12 +247,12 @@ def __iter__(self): self._state_dict["shard_idx"] += 1 self._state_dict["shard_example_idx"] = 0 - def shard_data_sources(self, worker_id: int, num_workers: int) -> "ExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable": """Keep only the requested shard.""" rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) return ExamplesIterable(self.generate_examples_fn, kwargs_with_shuffled_shards).shard_data_sources( - worker_id, num_workers + num_shards, index, contiguous=contiguous ) @@ -266,7 +273,7 @@ def _init_state_dict(self) -> dict: def __iter__(self): formatter = PythonFormatter() shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 - for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards), shard_idx_start, None): + for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 shard_example_idx = 0 for key, pa_table in self.generate_tables_fn(**gen_kwags): @@ -287,7 +294,7 @@ def __iter__(self): def _iter_arrow(self): shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 - for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards), shard_idx_start, None): + for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 shard_example_idx = 0 for key, pa_table in self.generate_tables_fn(**gen_kwags): @@ -304,15 +311,15 @@ def _iter_arrow(self): def shuffle_data_sources(self, generator: np.random.Generator) -> "ArrowExamplesIterable": return ShuffledDataSourcesArrowExamplesIterable(self.generate_tables_fn, self.kwargs, generator) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "ArrowExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable": """Keep only the requested shard.""" - gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards) - shard_indices = self.split_shard_indices_by_worker(worker_id, num_workers) + gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) + shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) return ArrowExamplesIterable(self.generate_tables_fn, requested_gen_kwargs) @property - def n_shards(self) -> int: + def num_shards(self) -> int: return _number_of_shards_in_gen_kwargs(self.kwargs) @@ -337,7 +344,7 @@ def __iter__(self): formatter = PythonFormatter() shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 for gen_kwags in islice( - _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.n_shards), shard_idx_start, None + _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None ): shard_example_idx_start = 
self._state_dict["shard_example_idx"] if self._state_dict else 0 shard_example_idx = 0 @@ -362,7 +369,7 @@ def _iter_arrow(self): kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 for gen_kwags in islice( - _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.n_shards), shard_idx_start, None + _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None ): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 shard_example_idx = 0 @@ -377,12 +384,12 @@ def _iter_arrow(self): self._state_dict["shard_idx"] += 1 self._state_dict["shard_example_idx"] = 0 - def shard_data_sources(self, worker_id: int, num_workers: int) -> "ArrowExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable": """Keep only the requested shard.""" rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) return ArrowExamplesIterable(self.generate_tables_fn, kwargs_with_shuffled_shards).shard_data_sources( - worker_id, num_workers + num_shards, index, contiguous=contiguous ) @@ -505,14 +512,16 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RebatchedArro self.ex_iterable.shuffle_data_sources(generator), self.batch_size, self.drop_last_batch ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "RebatchedArrowExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "RebatchedArrowExamplesIterable": return RebatchedArrowExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), self.batch_size, self.drop_last_batch + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), + self.batch_size, + self.drop_last_batch, ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class SelectColumnsIterable(_BaseExamplesIterable): @@ -546,12 +555,14 @@ def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]: def shuffle_data_sources(self, generator: np.random.Generator) -> "SelectColumnsIterable": return SelectColumnsIterable(self.ex_iterable.shuffle_data_sources(generator), self.column_names) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "SelectColumnsIterable": - return SelectColumnsIterable(self.ex_iterable.shard_data_sources(worker_id, num_workers), self.column_names) + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SelectColumnsIterable": + return SelectColumnsIterable( + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), self.column_names + ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class StepExamplesIterable(_BaseExamplesIterable): @@ -584,14 +595,16 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "StepExamplesI self.ex_iterable.shuffle_data_sources(generator), step=self.step, offset=self.offset ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "StepExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "StepExamplesIterable": return StepExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), step=self.step, offset=self.offset + 
self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), + step=self.step, + offset=self.offset, ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class CyclingMultiSourcesExamplesIterable(_BaseExamplesIterable): @@ -679,13 +692,15 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "CyclingMultiS return CyclingMultiSourcesExamplesIterable(ex_iterables, self.stopping_strategy) @property - def n_shards(self) -> int: - return min(ex_iterable.n_shards for ex_iterable in self.ex_iterables) + def num_shards(self) -> int: + return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "CyclingMultiSourcesExamplesIterable": + def shard_data_sources( + self, num_shards: int, index: int, contiguous=True + ) -> "CyclingMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return CyclingMultiSourcesExamplesIterable( - [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables], + [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables], stopping_strategy=self.stopping_strategy, ) @@ -748,15 +763,15 @@ def shuffle_data_sources( return VerticallyConcatenatedMultiSourcesExamplesIterable(ex_iterables) @property - def n_shards(self) -> int: - return min(ex_iterable.n_shards for ex_iterable in self.ex_iterables) + def num_shards(self) -> int: + return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables) def shard_data_sources( - self, worker_id: int, num_workers: int + self, num_shards: int, index: int, contiguous=True ) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return VerticallyConcatenatedMultiSourcesExamplesIterable( - [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables] + [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables] ) @@ -829,15 +844,15 @@ def shuffle_data_sources( return self @property - def n_shards(self) -> int: + def num_shards(self) -> int: return 1 def shard_data_sources( - self, worker_id: int, num_workers: int + self, num_shards: int, index: int, contiguous=True ) -> "HorizontallyConcatenatedMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return HorizontallyConcatenatedMultiSourcesExamplesIterable( - [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables] + [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables] ) @@ -907,10 +922,12 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RandomlyCycli stopping_strategy=self.stopping_strategy, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "RandomlyCyclingMultiSourcesExamplesIterable": + def shard_data_sources( + self, num_shards: int, index: int, contiguous=True + ) -> "RandomlyCyclingMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return RandomlyCyclingMultiSourcesExamplesIterable( - [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables], + 
[iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables], self.generator, self.probabilities, self.stopping_strategy, @@ -1161,10 +1178,10 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample formatting=self.formatting, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "MappedExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "MappedExamplesIterable": """Keep only the requested shard.""" return MappedExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), function=self.function, with_indices=self.with_indices, input_columns=self.input_columns, @@ -1177,8 +1194,8 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "MappedExample ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class FilteredExamplesIterable(_BaseExamplesIterable): @@ -1381,10 +1398,10 @@ def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable batch_size=self.batch_size, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "FilteredExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "FilteredExamplesIterable": """Keep only the requested shard.""" return FilteredExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), function=self.function, with_indices=self.with_indices, input_columns=self.input_columns, @@ -1393,8 +1410,8 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "FilteredExamp ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class BufferShuffledExamplesIterable(_BaseExamplesIterable): @@ -1451,17 +1468,17 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffle self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "BufferShuffledExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "BufferShuffledExamplesIterable": """Keep only the requested shard.""" return BufferShuffledExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), buffer_size=self.buffer_size, generator=self.generator, ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class SkipExamplesIterable(_BaseExamplesIterable): @@ -1514,12 +1531,12 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "SkipExamplesI split_when_sharding=self.split_when_sharding, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "SkipExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SkipExamplesIterable": """Keep only the requested shard.""" if self.split_when_sharding: return SkipExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), - n=self.split_number(self.n, num_workers)[worker_id], + 
self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), + n=self.split_number(self.n, num_shards)[index], block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, split_when_sharding=self.split_when_sharding, ) @@ -1527,8 +1544,8 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "SkipExamplesI return self @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class TakeExamplesIterable(_BaseExamplesIterable): @@ -1582,26 +1599,26 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "TakeExamplesI split_when_sharding=self.split_when_sharding, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "TakeExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "TakeExamplesIterable": """Keep only the requested shard.""" if self.split_when_sharding: return TakeExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), - n=self.split_number(self.n, num_workers)[worker_id], + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), + n=self.split_number(self.n, num_shards)[index], block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, split_when_sharding=self.split_when_sharding, ) else: return TakeExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), n=self.n, block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, split_when_sharding=self.split_when_sharding, ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards def _apply_feature_types_on_example( @@ -1690,17 +1707,17 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "TypedExamples token_per_repo_id=self.token_per_repo_id, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "TypedExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "TypedExamplesIterable": """Keep only the requested shard.""" return TypedExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), features=self.features, token_per_repo_id=self.token_per_repo_id, ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards @dataclass @@ -1885,7 +1902,7 @@ def load_state_dict(self, state_dict: dict) -> None: self._starting_state_dict = state_dict def __repr__(self): - return f"IterableDataset({{\n features: {list(self._info.features.keys()) if self._info.features is not None else 'Unknown'},\n n_shards: {self.n_shards}\n}})" + return f"IterableDataset({{\n features: {list(self._info.features.keys()) if self._info.features is not None else 'Unknown'},\n num_shards: {self.num_shards}\n}})" def __getstate__(self): return self.__dict__ @@ -1916,10 +1933,14 @@ def _effective_generator(self): raise ValueError("This dataset is not shuffled") @property - def n_shards(self) -> int: - if self._distributed and self._ex_iterable.n_shards % self._distributed.world_size == 0: - return self._ex_iterable.n_shards // self._distributed.world_size - return self._ex_iterable.n_shards + def num_shards(self) -> int: + if self._distributed and 
self._ex_iterable.num_shards % self._distributed.world_size == 0: + return self._ex_iterable.num_shards // self._distributed.world_size + return self._ex_iterable.num_shards + + @property + def n_shards(self) -> int: # backward compatibility + return self.num_shards def _iter_pytorch(self): ex_iterable = self._prepare_ex_iterable_for_iteration() @@ -1930,24 +1951,26 @@ def _iter_pytorch(self): import torch.utils.data worker_info = torch.utils.data.get_worker_info() - if self._is_main_process() and ex_iterable.n_shards < worker_info.num_workers: + if self._is_main_process() and ex_iterable.num_shards < worker_info.num_workers: logger.warning( - f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={ex_iterable.n_shards}). " - f"Stopping {worker_info.num_workers - ex_iterable.n_shards} dataloader workers." + f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.num_shards={ex_iterable.num_shards}). " + f"Stopping {worker_info.num_workers - ex_iterable.num_shards} dataloader workers." ) logger.info( f"To parallelize data loading, we give each process some shards (or data sources) to process. " - f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={ex_iterable.n_shards}. " - f"To enable more parallelism, please split the dataset in more files than {ex_iterable.n_shards}." + f"Therefore it's unnecessary to have a number of workers greater than dataset.num_shards={ex_iterable.num_shards}. " + f"To enable more parallelism, please split the dataset in more files than {ex_iterable.num_shards}." ) # split workload _log_prefix = f"node#{self._distributed.rank} " if self._distributed else "" shards_indices = ex_iterable.split_shard_indices_by_worker(worker_info.id, worker_info.num_workers) if shards_indices: logger.debug( - f"{_log_prefix}dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{ex_iterable.n_shards} shards." + f"{_log_prefix}dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{ex_iterable.num_shards} shards." + ) + ex_iterable = ex_iterable.shard_data_sources( + num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False ) - ex_iterable = ex_iterable.shard_data_sources(worker_id=worker_info.id, num_workers=worker_info.num_workers) self._state_dict = ex_iterable._init_state_dict() if self._starting_state_dict: ex_iterable.load_state_dict(self._starting_state_dict) @@ -1978,11 +2001,11 @@ def _iter_pytorch(self): ) yield format_dict(example) if format_dict else example logger.debug( - f"{_log_prefix}dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{ex_iterable.n_shards} shards." + f"{_log_prefix}dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{ex_iterable.num_shards} shards." ) else: logger.debug( - f"{_log_prefix}dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({ex_iterable.n_shards}<{worker_info.num_workers})." + f"{_log_prefix}dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({ex_iterable.num_shards}<{worker_info.num_workers})." 
            )

    def _is_main_process(self):
@@ -2012,14 +2035,14 @@ def _prepare_ex_iterable_for_iteration(
         if self._distributed:
             rank = self._distributed.rank
             world_size = self._distributed.world_size
-            if ex_iterable.n_shards % world_size == 0:
+            if ex_iterable.num_shards % world_size == 0:
                 if self._is_main_process():
-                    n_shards_per_node = ex_iterable.n_shards // world_size
-                    plural = "s" if n_shards_per_node > 1 else ""
+                    num_shards_per_node = ex_iterable.num_shards // world_size
+                    plural = "s" if num_shards_per_node > 1 else ""
                     logger.info(
-                        f"Assigning {n_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
+                        f"Assigning {num_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
                     )
-                ex_iterable = ex_iterable.shard_data_sources(rank, world_size)
+                ex_iterable = ex_iterable.shard_data_sources(num_shards=world_size, index=rank, contiguous=False)
             else:
                 if self._is_main_process():
                     logger.info(
@@ -2028,7 +2051,7 @@
                     logger.info(
                         f"It is more optimized to distribute the dataset shards (or data sources) across nodes. "
                         f"You can do that by using a dataset with number of shards that is a factor of world_size={world_size}. "
-                        f"The current dataset has {ex_iterable.n_shards} which is not a factor of {world_size}"
+                        f"The current dataset has {ex_iterable.num_shards} which is not a factor of {world_size}"
                     )
                 ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)

@@ -2635,6 +2658,63 @@ def take(self, n: int) -> "IterableDataset":
             token_per_repo_id=self._token_per_repo_id,
         )

+    def shard(
+        self,
+        num_shards: int,
+        index: int,
+        contiguous: bool = True,
+    ) -> "IterableDataset":
+        """Return the `index`-nth shard from dataset split into `num_shards` pieces.
+
+        This shards deterministically. `dataset.shard(n, i)` splits the dataset into contiguous chunks,
+        so it can be easily concatenated back together after processing. If `dataset.num_shards % n == l`, then the
+        first `l` datasets each have `(dataset.num_shards // n) + 1` shards, and the remaining datasets have `(dataset.num_shards // n)` shards.
+        `datasets.concatenate([dset.shard(n, i) for i in range(n)])` returns a dataset with the same order as the original.
+        In particular, `dataset.shard(dataset.num_shards, i)` returns a dataset with 1 shard.
+
+        Note: n should be less than or equal to the number of shards in the dataset `dataset.num_shards`.
+
+        On the other hand, `dataset.shard(n, i, contiguous=False)` contains all the shards of the dataset whose index mod `n = i`.
+
+        Be sure to shard before using any randomizing operator (such as `shuffle`).
+        It is best if the shard operator is used early in the dataset pipeline.
+
+        Args:
+            num_shards (`int`):
+                How many shards to split the dataset into.
+            index (`int`):
+                Which shard to select and return.
+            contiguous: (`bool`, defaults to `True`):
+                Whether to select contiguous blocks of indices for shards.
+
+        Example:
+
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("amazon_polarity", split="train", streaming=True)
+        >>> ds
+        IterableDataset({
+            features: ['label', 'title', 'content'],
+            num_shards: 4
+        })
+        >>> ds.shard(num_shards=2, index=0)
+        IterableDataset({
+            features: ['label', 'title', 'content'],
+            num_shards: 2
+        })
+        ```
+        """
+        ex_iterable = self._ex_iterable.shard_data_sources(num_shards=num_shards, index=index, contiguous=contiguous)
+        return IterableDataset(
+            ex_iterable=ex_iterable,
+            info=self._info.copy(),
+            split=self._split,
+            formatting=self._formatting,
+            shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
+            token_per_repo_id=self._token_per_repo_id,
+        )
+
     @property
     def column_names(self) -> Optional[List[str]]:
         """Names of the columns in the dataset.
@@ -3079,7 +3159,7 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
     """
     Split an iterable dataset for the node at rank `rank` in a pool of nodes of size `world_size`.

-    If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+    If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
     then the shards are evenly assigned across the nodes, which is the most optimized.
     Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
diff --git a/src/datasets/packaged_modules/spark/spark.py b/src/datasets/packaged_modules/spark/spark.py
index c21cb3dd981..0d06e5db2c1 100644
--- a/src/datasets/packaged_modules/spark/spark.py
+++ b/src/datasets/packaged_modules/spark/spark.py
@@ -105,7 +105,7 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "SparkExamples
         return SparkExamplesIterable(self.df, partition_order=partition_order)

     @property
-    def n_shards(self) -> int:
+    def num_shards(self) -> int:
         return len(self.partition_order)

diff --git a/tests/packaged_modules/test_spark.py b/tests/packaged_modules/test_spark.py
index c91bdd571ea..ce3a13bd794 100644
--- a/tests/packaged_modules/test_spark.py
+++ b/tests/packaged_modules/test_spark.py
@@ -72,7 +72,7 @@ def test_spark_examples_iterable():
     spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
     df = spark.range(10).repartition(1)
     it = SparkExamplesIterable(df)
-    assert it.n_shards == 1
+    assert it.num_shards == 1
     for i, (row_id, row_dict) in enumerate(it):
         assert row_id == f"0_{i}"
         assert row_dict == {"id": i}
@@ -89,7 +89,7 @@ def test_spark_examples_iterable_shuffle():
     expected_row_ids_and_row_dicts = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [2, 1, 0])

     shuffled_it = SparkExamplesIterable(df).shuffle_data_sources(generator_mock)
-    assert shuffled_it.n_shards == 3
+    assert shuffled_it.num_shards == 3
     for i, (row_id, row_dict) in enumerate(shuffled_it):
         expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts[i]
         assert row_id == expected_row_id
         assert row_dict == expected_row_dict
@@ -104,7 +104,7 @@ def test_spark_examples_iterable_shard():

     # Partitions 0 and 2
     shard_it_1 = SparkExamplesIterable(df).shard_data_sources(worker_id=0, num_workers=2)
-    assert shard_it_1.n_shards == 2
+    assert shard_it_1.num_shards == 2
     expected_row_ids_and_row_dicts_1 = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [0, 2])
     for i, (row_id, row_dict) in enumerate(shard_it_1):
         expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts_1[i]
@@ -113,7 +113,7 @@
     # 
Partitions 1 and 3 shard_it_2 = SparkExamplesIterable(df).shard_data_sources(worker_id=1, num_workers=2) - assert shard_it_2.n_shards == 2 + assert shard_it_2.num_shards == 2 expected_row_ids_and_row_dicts_2 = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [1, 3]) for i, (row_id, row_dict) in enumerate(shard_it_2): expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts_2[i] diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index ffa048644e2..1e08862031b 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -2630,10 +2630,12 @@ def test_shard(self, in_memory): tmp_file = os.path.join(tmp_dir, "test.arrow") with dset.select(range(10), indices_cache_file_name=tmp_file) as dset: self.assertEqual(len(dset), 10) - # Shard + # Shard non-contiguous tmp_file_1 = os.path.join(tmp_dir, "test_1.arrow") fingerprint = dset._fingerprint - with dset.shard(num_shards=8, index=1, indices_cache_file_name=tmp_file_1) as dset_sharded: + with dset.shard( + num_shards=8, index=1, contiguous=False, indices_cache_file_name=tmp_file_1 + ) as dset_sharded: self.assertEqual(2, len(dset_sharded)) self.assertEqual(["my_name-train_1", "my_name-train_9"], dset_sharded["filename"]) self.assertDictEqual(dset.features, Features({"filename": Value("string")})) @@ -4268,7 +4270,7 @@ def test_dataset_to_iterable_dataset(dataset: Dataset): assert isinstance(iterable_dataset, IterableDataset) assert list(iterable_dataset) == list(dataset) assert iterable_dataset.features == dataset.features - assert iterable_dataset.n_shards == 3 + assert iterable_dataset.num_shards == 3 with pytest.raises(ValueError): dataset.to_iterable_dataset(num_shards=len(dataset) + 1) with pytest.raises(NotImplementedError): diff --git a/tests/test_distributed.py b/tests/test_distributed.py index b8e0f56b180..65d2130f753 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -46,11 +46,11 @@ def gen(shards): gen_kwargs = {"shards": [f"shard_{shard_idx}.txt" for shard_idx in range(num_shards)]} full_ds = IterableDataset.from_generator(gen, gen_kwargs=gen_kwargs) full_size = len(list(full_ds)) - assert full_ds.n_shards == world_size * shards_per_node + assert full_ds.num_shards == world_size * shards_per_node datasets_per_rank = [ split_dataset_by_node(full_ds, rank=rank, world_size=world_size) for rank in range(world_size) ] - assert [ds.n_shards for ds in datasets_per_rank] == [shards_per_node] * world_size + assert [ds.num_shards for ds in datasets_per_rank] == [shards_per_node] * world_size assert sum(len(list(ds)) for ds in datasets_per_rank) == full_size assert len({tuple(x.values()) for ds in datasets_per_rank for x in ds}) == full_size diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 232652f1fa3..afa6b8db5d0 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1277,7 +1277,7 @@ def gen(shard_names): shard_names = [f"data{shard_idx}.txt" for shard_idx in range(4)] dataset = IterableDataset.from_generator(gen, gen_kwargs={"shard_names": shard_names}) assert isinstance(dataset, IterableDataset) - assert dataset.n_shards == len(shard_names) + assert dataset.num_shards == len(shard_names) @require_numpy1_on_windows @@ -1392,11 +1392,11 @@ def test_iterable_dataset_torch_dataloader_parallel(): @require_torch @pytest.mark.filterwarnings("ignore:This DataLoader will create:UserWarning") -@pytest.mark.parametrize("n_shards, num_workers", [(2, 1), (2, 2), (3, 2), (2, 3)]) -def 
test_sharded_iterable_dataset_torch_dataloader_parallel(n_shards, num_workers): +@pytest.mark.parametrize("num_shards, num_workers", [(2, 1), (2, 2), (3, 2), (2, 3)]) +def test_sharded_iterable_dataset_torch_dataloader_parallel(num_shards, num_workers): from torch.utils.data import DataLoader - ex_iterable = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}.txt" for i in range(n_shards)]}) + ex_iterable = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}.txt" for i in range(num_shards)]}) dataset = IterableDataset(ex_iterable) dataloader = DataLoader(dataset, batch_size=None, num_workers=num_workers) result = list(dataloader) @@ -1686,8 +1686,10 @@ def test_iterable_dataset_take(dataset: IterableDataset, n): @pytest.mark.parametrize("count", [2, 5, 11]) def test_iterable_dataset_skip_or_take_after_shuffle(method, after_shuffle, count): seed = 42 - n, n_shards = 3, 10 - ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(n_shards)]}) + n, num_shards = 3, 10 + ex_iterable = ExamplesIterable( + generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(num_shards)]} + ) dataset = IterableDataset(ex_iterable) shuffled_dataset = dataset if after_shuffle: @@ -1714,9 +1716,11 @@ def test_iterable_dataset_skip_or_take_after_shuffle(method, after_shuffle, coun @pytest.mark.parametrize("after_split_by_node", [False, True]) @pytest.mark.parametrize("count", [2, 5, 11]) def test_iterable_dataset_skip_or_take_after_split_by_node(method, after_split_by_node, count): - n, n_shards = 3, 10 + n, num_shards = 3, 10 rank, world_size = 1, 2 - ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(n_shards)]}) + ex_iterable = ExamplesIterable( + generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(num_shards)]} + ) dataset = IterableDataset(ex_iterable) distributed_dataset = dataset true_distributed_dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) @@ -2114,17 +2118,17 @@ def add_one_numpy(example): assert isinstance(next(dataset.iter(batch_size=3))["id"], list) -@pytest.mark.parametrize("n_shards1, n_shards2, num_workers", [(2, 1, 1), (2, 2, 2), (1, 3, 1), (4, 3, 3)]) -def test_interleave_dataset_with_sharding(n_shards1, n_shards2, num_workers): +@pytest.mark.parametrize("num_shards1, num_shards2, num_workers", [(2, 1, 1), (2, 2, 2), (1, 3, 1), (4, 3, 3)]) +def test_interleave_dataset_with_sharding(num_shards1, num_shards2, num_workers): from torch.utils.data import DataLoader - ex_iterable1 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-1.txt" for i in range(n_shards1)]}) + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-1.txt" for i in range(num_shards1)]}) dataset1 = IterableDataset(ex_iterable1).with_format("torch") - ex_iterable2 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-2.txt" for i in range(n_shards2)]}) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-2.txt" for i in range(num_shards2)]}) dataset2 = IterableDataset(ex_iterable2).with_format("torch") dataset_merged = interleave_datasets([dataset1, dataset2], stopping_strategy="first_exhausted") - assert dataset_merged.n_shards == min(n_shards1, n_shards2) + assert dataset_merged.num_shards == min(num_shards1, num_shards2) dataloader = DataLoader(dataset_merged, batch_size=None, num_workers=num_workers) result = list(dataloader) expected_length = 2 * min(