
Commit

add IterableDataset.shard (and rename n_shards -> num_shards)
lhoestq committed Oct 25, 2024
1 parent 8413aac commit b4a98f4
Showing 10 changed files with 223 additions and 137 deletions.
8 changes: 4 additions & 4 deletions docs/source/about_mapstyle_vs_iterable.mdx
@@ -139,12 +139,12 @@ But using a shuffle buffer is not enough to provide a satisfactory shuffling for
```python
# Stream from the internet
my_iterable_dataset = load_dataset("deepmind/code_contests", split="train", streaming=True)
- my_iterable_dataset.n_shards # 39
+ my_iterable_dataset.num_shards # 39

# Stream from local files
data_files = {"train": [f"path/to/data_{i}.csv" for i in range(1024)]}
my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True)
- my_iterable_dataset.n_shards # 1024
+ my_iterable_dataset.num_shards # 1024

# From a generator function
def my_generator(n, sources):
@@ -154,7 +154,7 @@ def my_generator(n, sources):

gen_kwargs = {"n": 10, "sources": [f"path/to/data_{i}" for i in range(1024)]}
my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs=gen_kwargs)
- my_iterable_dataset.n_shards # 1024
+ my_iterable_dataset.num_shards # 1024
```
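
The commit's headline addition, `IterableDataset.shard`, selects a subset of these file shards. A minimal sketch, assuming the new method mirrors `Dataset.shard`'s `(num_shards, index)` signature and groups the underlying file shards contiguously; the shard counts in the comments are illustrative:

```python
from datasets import load_dataset

my_iterable_dataset = load_dataset("deepmind/code_contests", split="train", streaming=True)
my_iterable_dataset.num_shards  # 39

# Keep the first of 4 contiguous groups of file shards.
# Assuming the 39 shards are split as evenly as possible, the first group gets 10.
subset = my_iterable_dataset.shard(num_shards=4, index=0)
subset.num_shards  # 10
```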

## Speed differences
@@ -242,5 +242,5 @@ my_iterable_dataset = my_dataset.to_iterable_dataset()
If you want to shuffle your dataset or [use it with a PyTorch DataLoader](./use_with_pytorch#stream-data), we recommend generating a sharded [`IterableDataset`]:
```python
my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024)
- my_iterable_dataset.n_shards # 1024
+ my_iterable_dataset.num_shards # 1024
```
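
As a follow-up usage sketch, here is how such a sharded iterable dataset is typically shuffled and streamed, assuming the standard `IterableDataset.shuffle` and PyTorch `DataLoader` APIs; the seed, buffer size, and worker count are illustrative:

```python
from torch.utils.data import DataLoader

my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024)

# shuffle() permutes the shard order and keeps a buffer for
# approximate example-level shuffling (values are illustrative).
shuffled = my_iterable_dataset.shuffle(seed=42, buffer_size=10_000)

# 1024 % 4 == 0, so each DataLoader worker streams 256 whole shards.
dataloader = DataLoader(shuffled, batch_size=32, num_workers=4)
```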
4 changes: 2 additions & 2 deletions docs/source/use_with_pytorch.mdx
@@ -216,7 +216,7 @@ If the dataset is split in several shards (i.e. if the dataset consists of multi

```py
>>> my_iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
- >>> my_iterable_dataset.n_shards
+ >>> my_iterable_dataset.num_shards
39
>>> dataloader = DataLoader(my_iterable_dataset, batch_size=32, num_workers=4)
```
@@ -259,7 +259,7 @@ Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of t

For iterable datasets:

- If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+ If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
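
A minimal sketch of this per-node assignment using `split_dataset_by_node`; the rank and world size are illustrative and would normally come from your distributed launcher:

```py
>>> from datasets.distributed import split_dataset_by_node
>>> ds = load_dataset("deepmind/code_contests", split="train", streaming=True)
>>> ds.num_shards
39
>>> # 39 % 3 == 0, so each of the 3 nodes is assigned 13 whole shards.
>>> node_ds = split_dataset_by_node(ds, rank=0, world_size=3)
```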

22 changes: 11 additions & 11 deletions src/datasets/arrow_dataset.py
@@ -4630,40 +4630,40 @@ def shard(
self,
num_shards: int,
index: int,
- contiguous: bool = False,
+ contiguous: bool = True,
keep_in_memory: bool = False,
indices_cache_file_name: Optional[str] = None,
writer_batch_size: Optional[int] = 1000,
) -> "Dataset":
"""Return the `index`-nth shard from dataset split into `num_shards` pieces.
- This shards deterministically. `dset.shard(n, i)` will contain all elements of dset whose
- index mod `n = i`.
+ This shards deterministically. `dataset.shard(n, i)` splits the dataset into contiguous chunks,
+ so it can be easily concatenated back together after processing. If `len(dataset) % n == l`, then the
+ first `l` shards each have length `(len(dataset) // n) + 1`, and the remaining shards have length `(len(dataset) // n)`.
+ `datasets.concatenate_datasets([dset.shard(n, i) for i in range(n)])` returns a dataset with the same order as the original.
- `dset.shard(n, i, contiguous=True)` will instead split dset into contiguous chunks,
- so it can be easily concatenated back together after processing. If `n % i == l`, then the
- first `l` shards will have length `(n // i) + 1`, and the remaining shards will have length `(n // i)`.
- `datasets.concatenate([dset.shard(n, i, contiguous=True) for i in range(n)])` will return
- a dataset with the same order as the original.
+ Note: n should be less than or equal to the number of elements in the dataset `len(dataset)`.
+ On the other hand, `dataset.shard(n, i, contiguous=False)` contains all elements of the dataset whose index mod `n = i`.
Be sure to shard before using any randomizing operator (such as `shuffle`).
It is best if the shard operator is used early in the dataset pipeline.
Args:
num_shards (`int`):
How many shards to split the dataset into.
index (`int`):
Which shard to select and return.
- contiguous: (`bool`, defaults to `False`):
+ contiguous: (`bool`, defaults to `True`):
Whether to select contiguous blocks of indices for shards.
keep_in_memory (`bool`, defaults to `False`):
Keep the dataset in memory instead of writing it to a cache file.
indices_cache_file_name (`str`, *optional*):
Provide the name of a path for the cache file. It is used to store the
indices of each shard instead of the automatically generated cache file name.
writer_batch_size (`int`, defaults to `1000`):
- Number of rows per write operation for the cache file writer.
- This only concerns the indices mapping.
+ Number of indices per write operation for the cache file writer.
This value is a good trade-off between memory usage during the processing, and processing speed.
Higher value makes the processing do fewer lookups, lower value consume less temporary memory while running `map`.
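
To make the contiguous and non-contiguous modes described in the docstring above concrete, a short sketch with a toy in-memory dataset; the values are illustrative:

```python
from datasets import Dataset

ds = Dataset.from_dict({"id": list(range(7))})

# contiguous=True (the new default): contiguous chunks of sizes 3, 2, 2
[ds.shard(num_shards=3, index=i)["id"] for i in range(3)]
# [[0, 1, 2], [3, 4], [5, 6]]

# contiguous=False: shard i keeps the elements whose position mod 3 == i
[ds.shard(num_shards=3, index=i, contiguous=False)["id"] for i in range(3)]
# [[0, 3, 6], [1, 4], [2, 5]]
```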
2 changes: 1 addition & 1 deletion src/datasets/distributed.py
@@ -18,7 +18,7 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> D
For iterable datasets:
- If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+ If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.