Add IterableDataset.shard() (#7252)
* add IterableDataset.shard (and rename n_shards -> num_shards)

* docs

* add test

* fix tests

* again

* again

* again

* minor
lhoestq authored Oct 25, 2024
1 parent 8413aac commit 65f6eb5
Showing 13 changed files with 283 additions and 144 deletions.
8 changes: 4 additions & 4 deletions docs/source/about_mapstyle_vs_iterable.mdx
@@ -139,12 +139,12 @@ But using a shuffle buffer is not enough to provide a satisfactory shuffling for
```python
# Stream from the internet
my_iterable_dataset = load_dataset("deepmind/code_contests", split="train", streaming=True)
-my_iterable_dataset.n_shards # 39
+my_iterable_dataset.num_shards # 39

# Stream from local files
data_files = {"train": [f"path/to/data_{i}.csv" for i in range(1024)]}
my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True)
-my_iterable_dataset.n_shards # 1024
+my_iterable_dataset.num_shards # 1024

# From a generator function
def my_generator(n, sources):
@@ -154,7 +154,7 @@ def my_generator(n, sources):

gen_kwargs = {"n": 10, "sources": [f"path/to/data_{i}" for i in range(1024)]}
my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs=gen_kwargs)
-my_iterable_dataset.n_shards # 1024
+my_iterable_dataset.num_shards # 1024
```
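With many shards available, `shuffle` can randomize the shard order in addition to using an approximate shuffle buffer. A minimal sketch (the seed and buffer size are illustrative, not recommendations):

```python
# shuffles the order of the 1024 shards and keeps a 10k-example buffer for approximate shuffling
shuffled_dataset = my_iterable_dataset.shuffle(seed=42, buffer_size=10_000)
```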

## Speed differences
@@ -242,5 +242,5 @@ my_iterable_dataset = my_dataset.to_iterable_dataset()
If you want to shuffle your dataset or [use it with a PyTorch DataLoader](./use_with_pytorch#stream-data), we recommend generating a sharded [`IterableDataset`]:
```python
my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024)
-my_iterable_dataset.n_shards # 1024
+my_iterable_dataset.num_shards # 1024
```
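If you then only need one contiguous slice of the converted dataset (for a quick sanity check, say), the [`IterableDataset.shard`] method added in this commit can select it; the chunk count below is only an example:

```python
# a sketch: keep the first 1/64 of the data, i.e. 1024 / 64 = 16 of the underlying shards
subset = my_iterable_dataset.shard(num_shards=64, index=0)
subset.num_shards  # 16
```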
1 change: 1 addition & 0 deletions docs/source/package_reference/main_classes.mdx
@@ -171,6 +171,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
- batch
- skip
- take
- shard
- load_state_dict
- state_dict
- info
29 changes: 29 additions & 0 deletions docs/source/stream.mdx
@@ -136,6 +136,35 @@ You can split your dataset one of two ways:

<a id='interleave_datasets'></a>


### Shard

🤗 Datasets supports sharding to divide a very large dataset into a predefined number of chunks. Specify the `num_shards` parameter in [`~IterableDataset.shard`] to determine the number of shards to split the dataset into. You'll also need to provide the shard you want to return with the `index` parameter.

For example, the [amazon_polarity](https://huggingface.co/datasets/amazon_polarity) dataset has 4 shards (in this case they are 4 Parquet files):

```py
>>> from datasets import load_dataset
>>> dataset = load_dataset("amazon_polarity", split="train", streaming=True)
>>> print(dataset)
IterableDataset({
features: ['label', 'title', 'content'],
num_shards: 4
})
```

After sharding the dataset into two chunks, the first chunk keeps only 2 of the original 4 shards:

```py
>>> dataset.shard(num_shards=2, index=0)
IterableDataset({
features: ['label', 'title', 'content'],
num_shards: 2
})
```
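One way to use this, sketched below with a hypothetical worker count of 4, is to hand each parallel worker its own piece of the stream:

```py
>>> # each piece is itself an IterableDataset backed by 1 of the 4 Parquet shards
>>> pieces = [dataset.shard(num_shards=4, index=i) for i in range(4)]
>>> [piece.num_shards for piece in pieces]
[1, 1, 1, 1]
```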

If your dataset has `dataset.num_shards==1`, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead.
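For example, a minimal sketch assuming a single-shard dataset that you want to split into a 500-example head and the remainder:

```py
>>> first_part = dataset.take(500)   # first 500 examples
>>> rest = dataset.skip(500)         # everything after the first 500 examples
```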

## Interleave

[`interleave_datasets`] can combine an [`IterableDataset`] with other datasets. The combined dataset returns alternating examples from each of the original datasets.
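A minimal sketch, reusing the amazon_polarity splits from above purely for illustration (any datasets with compatible features work):

```py
>>> from datasets import interleave_datasets, load_dataset
>>> d1 = load_dataset("amazon_polarity", split="train", streaming=True)
>>> d2 = load_dataset("amazon_polarity", split="test", streaming=True)
>>> combined = interleave_datasets([d1, d2])  # alternates: d1[0], d2[0], d1[1], d2[1], ...
```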
4 changes: 2 additions & 2 deletions docs/source/use_with_pytorch.mdx
@@ -216,7 +216,7 @@ If the dataset is split in several shards (i.e. if the dataset consists of multi

```py
>>> my_iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
->>> my_iterable_dataset.n_shards
+>>> my_iterable_dataset.num_shards
39
>>> dataloader = DataLoader(my_iterable_dataset, batch_size=32, num_workers=4)
```
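Since 39 shards do not divide evenly across 4 workers, some workers receive one more shard than others. A hedged sketch of picking a worker count that divides the shard count (the cap of 8 workers is an arbitrary assumption):

```py
>>> num_shards = my_iterable_dataset.num_shards  # 39
>>> num_workers = max(d for d in range(1, 9) if num_shards % d == 0)  # 3, because 39 = 3 * 13
>>> dataloader = DataLoader(my_iterable_dataset, batch_size=32, num_workers=num_workers)
```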
@@ -259,7 +259,7 @@ Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of t

For iterable datasets:

-If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
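A minimal sketch of the iterable case (a world size of 3 is chosen here only because it divides the 39 shards of the example above):

```py
>>> from datasets import load_dataset
>>> from datasets.distributed import split_dataset_by_node
>>> ds = load_dataset("deepmind/code_contests", streaming=True, split="train")
>>> ds.num_shards  # 39
>>> node_ds = split_dataset_by_node(ds, rank=0, world_size=3)  # 39 % 3 == 0, so this node reads 13 shards
```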

24 changes: 12 additions & 12 deletions src/datasets/arrow_dataset.py
@@ -303,7 +303,7 @@ def _get_output_signature(
tf_dtype = tf.float32
np_dtype = np.float32
elif np_arrays[0].dtype.kind == "U": # Unicode strings
-np_dtype = np.unicode_
+np_dtype = np.str_
tf_dtype = tf.string
else:
raise RuntimeError(
@@ -4630,40 +4630,40 @@ def shard(
self,
num_shards: int,
index: int,
-contiguous: bool = False,
+contiguous: bool = True,
keep_in_memory: bool = False,
indices_cache_file_name: Optional[str] = None,
writer_batch_size: Optional[int] = 1000,
) -> "Dataset":
"""Return the `index`-nth shard from dataset split into `num_shards` pieces.
-This shards deterministically. `dset.shard(n, i)` will contain all elements of dset whose
-index mod `n = i`.
+This shards deterministically. `dataset.shard(n, i)` splits the dataset into contiguous chunks,
+so it can be easily concatenated back together after processing. If `len(dataset) % n == l`, then the
+first `l` shards each have length `(len(dataset) // n) + 1`, and the remaining shards have length `(len(dataset) // n)`.
+`datasets.concatenate_datasets([dset.shard(n, i) for i in range(n)])` returns a dataset with the same order as the original.
-`dset.shard(n, i, contiguous=True)` will instead split dset into contiguous chunks,
-so it can be easily concatenated back together after processing. If `n % i == l`, then the
-first `l` shards will have length `(n // i) + 1`, and the remaining shards will have length `(n // i)`.
-`datasets.concatenate([dset.shard(n, i, contiguous=True) for i in range(n)])` will return
-a dataset with the same order as the original.
+Note: n should be less than or equal to the number of elements in the dataset `len(dataset)`.
+On the other hand, `dataset.shard(n, i, contiguous=False)` contains all elements of the dataset whose index mod `n = i`.
Be sure to shard before using any randomizing operator (such as `shuffle`).
It is best if the shard operator is used early in the dataset pipeline.
Args:
num_shards (`int`):
How many shards to split the dataset into.
index (`int`):
Which shard to select and return.
-contiguous: (`bool`, defaults to `False`):
+contiguous: (`bool`, defaults to `True`):
Whether to select contiguous blocks of indices for shards.
keep_in_memory (`bool`, defaults to `False`):
Keep the dataset in memory instead of writing it to a cache file.
indices_cache_file_name (`str`, *optional*):
Provide the name of a path for the cache file. It is used to store the
indices of each shard instead of the automatically generated cache file name.
writer_batch_size (`int`, defaults to `1000`):
-Number of rows per write operation for the cache file writer.
-This only concerns the indices mapping.
+Number of indices per write operation for the cache file writer.
This value is a good trade-off between memory usage during the processing, and processing speed.
Higher value makes the processing do fewer lookups, lower value consume less temporary memory while running `map`.
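A small sketch of the two behaviours described in the docstring above, on a toy in-memory dataset (purely illustrative):

```python
from datasets import Dataset, concatenate_datasets

ds = Dataset.from_dict({"i": list(range(10))})

# contiguous=True (the new default): first shard gets rows 0-3, since 10 % 3 == 1
ds.shard(num_shards=3, index=0)["i"]                    # [0, 1, 2, 3]

# contiguous=False: every row whose index % 3 == 0
ds.shard(num_shards=3, index=0, contiguous=False)["i"]  # [0, 3, 6, 9]

# contiguous shards concatenate back in the original order
assert concatenate_datasets([ds.shard(3, i) for i in range(3)])["i"] == ds["i"]
```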
2 changes: 1 addition & 1 deletion src/datasets/distributed.py
@@ -18,7 +18,7 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> D
For iterable datasets:
-If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.