Skip to content

Commit

Permalink
Fix incorrect rank value in data splitting (#6994)
Browse files Browse the repository at this point in the history
* Fix incorrect rank value in data splitting (#6990)

* Add tests for splitting distributed datasets

* make style

---------

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
  • Loading branch information
yzhangcs and lhoestq authored Jun 25, 2024
1 parent 1e1d313 commit 637246b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3013,8 +3013,8 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
[`IterableDataset`]: The iterable dataset to be used on the node at rank `rank`.
"""
if dataset._distributed:
world_size = world_size * dataset._distributed.world_size
rank = world_size * dataset._distributed.rank + rank
world_size = world_size * dataset._distributed.world_size
distributed = DistributedConfig(rank=rank, world_size=world_size)
return IterableDataset(
ex_iterable=dataset._ex_iterable,
Expand Down
20 changes: 20 additions & 0 deletions tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,26 @@ def gen(shards):
assert len({tuple(x.values()) for ds in datasets_per_rank for x in ds}) == full_size


def test_split_dataset_by_node_iterable_distributed():
def gen():
return ({"i": i} for i in range(100))

world_size = 3
num_workers = 3
full_ds = IterableDataset.from_generator(gen)
full_size = len(list(full_ds))
datasets_per_rank = [
split_dataset_by_node(full_ds, rank=rank, world_size=world_size) for rank in range(world_size)
]
datasets_per_rank_per_worker = [
split_dataset_by_node(ds, rank=worker, world_size=num_workers)
for ds in datasets_per_rank
for worker in range(num_workers)
]
assert sum(len(list(ds)) for ds in datasets_per_rank_per_worker) == full_size
assert len({tuple(x.values()) for ds in datasets_per_rank_per_worker for x in ds}) == full_size


def test_distributed_shuffle_iterable():
def gen():
return ({"i": i} for i in range(17))
Expand Down

0 comments on commit 637246b

Please sign in to comment.