Add batching to IterableDataset (#7054)
* feat: add `.batch()` to `IterableDataset` and introduce new `BatchedExamplesIterable`

* style: formatting...

* refactor: implement feedback to use .map()

* test: add tests for new `batch()` method

* style: formatting...

* fix: remove type hints in `batch_fn()` to fix failing CI

* docs: add section "Batching data in IterableDataset" to "Differences between Dataset and IterableDataset"

* refactor: apply feedback

* docs nit

---------

Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
2 people authored and albertvillanova committed Aug 13, 2024
1 parent 3a12370 commit 0689b4e
Showing 5 changed files with 110 additions and 4 deletions.
4 changes: 0 additions & 4 deletions docs/source/about_mapstyle_vs_iterable.mdx
@@ -205,10 +205,6 @@ for epoch in range(n_epochs):
pass
```

## Checkpoint and resuming differences

If your training loop stops, you may want to resume training from where it left off. To do so you can save a checkpoint of your model and optimizers, as well as your data loader.

To restart the iteration of a map-style dataset, you can simply skip the first examples:

```python
1 change: 1 addition & 0 deletions docs/source/package_reference/main_classes.mdx
@@ -170,6 +170,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
- rename_column
- filter
- shuffle
- batch
- skip
- take
- load_state_dict
38 changes: 38 additions & 0 deletions docs/source/stream.mdx
@@ -318,6 +318,44 @@ You can filter rows in the dataset based on a predicate function using [`Dataset
{'id': 4, 'text': 'Are you looking for Number the Stars (Essential Modern Classics)? Normally, ...'}]
```

## Batch

The `batch` method transforms your `IterableDataset` into an iterable of batches. This is particularly useful when you want to work with batches in your training loop or when using frameworks that expect batched inputs.

<Tip>

The `map` function also has a batched mode that applies a function to batches of data, which is discussed in the [Map section](#map) above. The `batch` method described here is different: it does not apply a user function and simply provides a direct way to create batches from your dataset.

</Tip>
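
To make the distinction concrete, here is a small sketch contrasting the two approaches. It assumes `dataset` is the streaming dataset loaded in the example below; the `text` column and the `add_length` function are hypothetical, for illustration only:

```python
# `map(..., batched=True)` applies a function to batches of examples,
# but the resulting dataset still yields one example at a time.
def add_length(batch):
    return {"n_chars": [len(t) for t in batch["text"]]}

mapped_dataset = dataset.map(add_length, batched=True, batch_size=32)

# `batch(...)` applies no user function: it yields the batches themselves.
batched_dataset = dataset.batch(batch_size=32)
```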

You can use the `batch` method like this:

```python
from datasets import load_dataset

# Load a dataset in streaming mode
dataset = load_dataset("some_dataset", split="train", streaming=True)

# Create batches of 32 samples
batched_dataset = dataset.batch(batch_size=32)

# Iterate over the batched dataset
for batch in batched_dataset:
    print(batch)
    break
```

In this example, `batched_dataset` is still an [`IterableDataset`], but each item yielded is now a batch of 32 samples instead of a single sample.
This batching is done on-the-fly as you iterate over the dataset, preserving the memory-efficient nature of [`IterableDataset`].
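
Concretely, each yielded batch is a dict mapping column names to lists of values, not a list of individual examples. A minimal sketch, assuming the dataset has `id` and `text` columns:

```python
for batch in batched_dataset:
    print(batch.keys())      # dict_keys(['id', 'text'])
    print(len(batch["id"]))  # 32, or fewer for a final partial batch
    break
```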

The `batch` method also provides a `drop_last_batch` parameter.
When set to `True`, it will discard the last batch if it's smaller than the specified `batch_size`.
This can be useful in scenarios where your downstream processing requires all batches to be of the same size:

```python
batched_dataset = dataset.batch(batch_size=32, drop_last_batch=True)
```
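
As a self-contained illustration, here is a sketch with a toy 10-example dataset (mirroring the test added in `tests/test_iterable_dataset.py` below) showing how `drop_last_batch` changes the number of batches:

```python
from datasets import IterableDataset

def gen():
    for i in range(10):
        yield {"id": i, "text": f"Text {i}"}

ds = IterableDataset.from_generator(gen)

print(len(list(ds.batch(batch_size=3))))                        # 4 batches: 3 + 3 + 3 + 1
print(len(list(ds.batch(batch_size=3, drop_last_batch=True))))  # 3 batches: the partial batch is dropped
```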

## Stream in a training loop

[`IterableDataset`] can be integrated into a training loop. First, shuffle the dataset:
20 changes: 20 additions & 0 deletions src/datasets/iterable_dataset.py
@@ -2885,6 +2885,26 @@ def _resolve_features(self):
token_per_repo_id=self._token_per_repo_id,
)

def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableDataset":
    """
    Group samples from the dataset into batches.

    Args:
        batch_size (`int`): The number of samples in each batch.
        drop_last_batch (`bool`, defaults to `False`): Whether to drop the last incomplete batch.

    Example:
    ```py
    >>> ds = load_dataset("some_dataset", streaming=True)
    >>> batched_ds = ds.batch(batch_size=32)
    ```
    """

    def batch_fn(unbatched):
        return {k: [v] for k, v in unbatched.items()}

    return self.map(batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch)


def _concatenate_iterable_datasets(
dsets: List[IterableDataset],
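To see why this one-liner works, here is a sketch of what `batch_fn` receives and returns for a single batch. With `batched=True`, `map` passes the function a dict of columns mapped to lists of up to `batch_size` values; wrapping each list in an outer list of length 1 makes `map` emit one example per batch whose column values are the original lists:

```python
# A sketch of one step of map(batch_fn, batched=True, batch_size=3).
unbatched = {"id": [0, 1, 2], "text": ["Text 0", "Text 1", "Text 2"]}

def batch_fn(unbatched):
    return {k: [v] for k, v in unbatched.items()}

print(batch_fn(unbatched))
# {'id': [[0, 1, 2]], 'text': [['Text 0', 'Text 1', 'Text 2']]}
```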
51 changes: 51 additions & 0 deletions tests/test_iterable_dataset.py
@@ -2176,3 +2176,54 @@ def test_resume_dataloader(dataset: IterableDataset):
dl = StatefulDataLoader(dataset)
dl.load_state_dict(state_dict)
assert remaining == list(dl)


def test_iterable_dataset_batch():
    # Create a simple IterableDataset
    data = [{"id": i, "text": f"Text {i}"} for i in range(10)]
    ds = IterableDataset.from_generator(lambda: (x for x in data))

    # Test with batch_size=3, drop_last_batch=False
    batched_ds = ds.batch(batch_size=3, drop_last_batch=False)
    batches = list(batched_ds)

    assert len(batches) == 4  # 3 full batches and 1 partial batch
    for i, batch in enumerate(batches[:3]):  # Check full batches
        assert len(batch["id"]) == 3
        assert len(batch["text"]) == 3
        assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
        assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]

    # Check last partial batch
    assert len(batches[3]["id"]) == 1
    assert len(batches[3]["text"]) == 1
    assert batches[3]["id"] == [9]
    assert batches[3]["text"] == ["Text 9"]

    # Test with batch_size=3, drop_last_batch=True
    batched_ds = ds.batch(batch_size=3, drop_last_batch=True)
    batches = list(batched_ds)

    assert len(batches) == 3  # Only full batches
    for i, batch in enumerate(batches):
        assert len(batch["id"]) == 3
        assert len(batch["text"]) == 3
        assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
        assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]

    # Test with batch_size=4 (doesn't evenly divide dataset size)
    batched_ds = ds.batch(batch_size=4, drop_last_batch=False)
    batches = list(batched_ds)

    assert len(batches) == 3  # 2 full batches and 1 partial batch
    for i, batch in enumerate(batches[:2]):  # Check full batches
        assert len(batch["id"]) == 4
        assert len(batch["text"]) == 4
        assert batch["id"] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
        assert batch["text"] == [f"Text {4*i}", f"Text {4*i+1}", f"Text {4*i+2}", f"Text {4*i+3}"]

    # Check last partial batch
    assert len(batches[2]["id"]) == 2
    assert len(batches[2]["text"]) == 2
    assert batches[2]["id"] == [8, 9]
    assert batches[2]["text"] == ["Text 8", "Text 9"]
