Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Oct 25, 2024
1 parent d700230 commit f14caf6
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/test_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,6 +1681,27 @@ def test_iterable_dataset_take(dataset: IterableDataset, n):
assert list(take_dataset) == list(dataset)[:n]


def test_iterable_dataset_shard():
num_examples = 20
num_shards = 5
dataset = Dataset.from_dict({"a": range(num_examples)}).to_iterable_dataset(num_shards=num_shards)
assert sum(dataset.shard(num_shards, i).num_shards for i in range(num_shards)) == dataset.num_shards
assert list(concatenate_datasets([dataset.shard(num_shards, i) for i in range(num_shards)])) == list(dataset)
num_shards = 2
assert sum(dataset.shard(num_shards, i).num_shards for i in range(num_shards)) == dataset.num_shards
assert list(concatenate_datasets([dataset.shard(num_shards, i) for i in range(num_shards)])) == list(dataset)
assert (
sum(dataset.shard(num_shards, i, contiguous=False).num_shards for i in range(num_shards)) == dataset.num_shards
)
assert list(
concatenate_datasets([dataset.shard(num_shards, i, contiguous=False) for i in range(num_shards)])
) != list(dataset)
assert sorted(
concatenate_datasets([dataset.shard(num_shards, i, contiguous=False) for i in range(num_shards)]),
key=lambda x: x["a"],
) == list(dataset)


@pytest.mark.parametrize("method", ["skip", "take"])
@pytest.mark.parametrize("after_shuffle", [False, True])
@pytest.mark.parametrize("count", [2, 5, 11])
Expand Down

0 comments on commit f14caf6

Please sign in to comment.