Skip to content

Commit

Permalink
add split argument to Generator (#7015)
Browse files Browse the repository at this point in the history
* add split argument to Generator, from_generator, AbstractDatasetInputStream, GeneratorDatasetInputStream

* split generator review feedback

* import Split

* tag added version in iterable_dataset, rollback change in _concatenate_iterable_datasets

* rm useless Generator __init__

* docstring formatting

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>

* format docstring

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>

* fix test_dataset_from_generator_split[None]

---------

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
  • Loading branch information
piercus and albertvillanova committed Aug 14, 2024
1 parent 7e79fe2 commit 786998a
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 12 deletions.
6 changes: 6 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,7 @@ def from_generator(
keep_in_memory: bool = False,
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
split: NamedSplit = Split.TRAIN,
**kwargs,
):
"""Create a Dataset from a generator.
Expand All @@ -1090,6 +1091,10 @@ def from_generator(
If `num_proc` is greater than one, then all list values in `gen_kwargs` must be the same length. These values will be split between calls to the generator. The number of shards will be the minimum of the shortest list in `gen_kwargs` and `num_proc`.
<Added version="2.7.0"/>
split ([`NamedSplit`], defaults to `Split.TRAIN`):
Split name to be assigned to the dataset.
<Added version="2.21.0"/>
**kwargs (additional keyword arguments):
            Keyword arguments to be passed to [`GeneratorConfig`].
Expand Down Expand Up @@ -1126,6 +1131,7 @@ def from_generator(
keep_in_memory=keep_in_memory,
gen_kwargs=gen_kwargs,
num_proc=num_proc,
split=split,
**kwargs,
).read()

Expand Down
8 changes: 5 additions & 3 deletions src/datasets/io/generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, Optional

from .. import Features
from .. import Features, NamedSplit, Split
from ..packaged_modules.generator.generator import Generator
from .abc import AbstractDatasetInputStream

Expand All @@ -15,6 +15,7 @@ def __init__(
streaming: bool = False,
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
split: NamedSplit = Split.TRAIN,
**kwargs,
):
super().__init__(
Expand All @@ -30,13 +31,14 @@ def __init__(
features=features,
generator=generator,
gen_kwargs=gen_kwargs,
split=split,
**kwargs,
)

def read(self):
# Build iterable dataset
if self.streaming:
dataset = self.builder.as_streaming_dataset(split="train")
dataset = self.builder.as_streaming_dataset(split=self.builder.config.split)
# Build regular (map-style) dataset
else:
download_config = None
Expand All @@ -52,6 +54,6 @@ def read(self):
num_proc=self.num_proc,
)
dataset = self.builder.as_dataset(
split="train", verification_mode=verification_mode, in_memory=self.keep_in_memory
split=self.builder.config.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
)
return dataset
11 changes: 6 additions & 5 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned, cast_to_python_objects
from .formatting import PythonFormatter, TensorFormatter, get_format_type_from_alias, get_formatter
from .info import DatasetInfo
from .splits import NamedSplit
from .splits import NamedSplit, Split
from .table import cast_table_to_features, read_schema_from_file, table_cast
from .utils.logging import get_logger
from .utils.py_utils import Literal
Expand Down Expand Up @@ -2083,6 +2083,7 @@ def from_generator(
generator: Callable,
features: Optional[Features] = None,
gen_kwargs: Optional[dict] = None,
split: NamedSplit = Split.TRAIN,
) -> "IterableDataset":
"""Create an Iterable Dataset from a generator.
Expand All @@ -2095,7 +2096,10 @@ def from_generator(
Keyword arguments to be passed to the `generator` callable.
You can define a sharded iterable dataset by passing the list of shards in `gen_kwargs`.
This can be used to improve shuffling and when iterating over the dataset with multiple workers.
split ([`NamedSplit`], defaults to `Split.TRAIN`):
Split name to be assigned to the dataset.
<Added version="2.21.0"/>
Returns:
`IterableDataset`
Expand Down Expand Up @@ -2126,10 +2130,7 @@ def from_generator(
from .io.generator import GeneratorDatasetInputStream

return GeneratorDatasetInputStream(
generator=generator,
features=features,
gen_kwargs=gen_kwargs,
streaming=True,
generator=generator, features=features, gen_kwargs=gen_kwargs, streaming=True, split=split
).read()

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion src/datasets/packaged_modules/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class GeneratorConfig(datasets.BuilderConfig):
generator: Optional[Callable] = None
gen_kwargs: Optional[dict] = None
features: Optional[datasets.Features] = None
split: datasets.NamedSplit = datasets.Split.TRAIN

def __post_init__(self):
super().__post_init__()
Expand All @@ -26,7 +27,7 @@ def _info(self):
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs=self.config.gen_kwargs)]
return [datasets.SplitGenerator(name=self.config.split, gen_kwargs=self.config.gen_kwargs)]

def _generate_examples(self, **gen_kwargs):
for idx, ex in enumerate(self.config.generator(**gen_kwargs)):
Expand Down
23 changes: 20 additions & 3 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3871,10 +3871,11 @@ def _gen():
return _gen


def _check_generator_dataset(dataset, expected_features):
def _check_generator_dataset(dataset, expected_features, split):
    """Validate a generator-built dataset: type, shape, split, columns, dtypes.

    Args:
        dataset: the object returned by `Dataset.from_generator`.
        expected_features: mapping of column name -> expected dtype string.
        split: expected value of `dataset.split` (str or `NamedSplit`).
    """
    assert isinstance(dataset, Dataset)
    # The shared data_generator fixture always yields 4 rows over 3 columns.
    assert dataset.num_rows == 4
    assert dataset.num_columns == 3
    assert dataset.split == split
    assert dataset.column_names == ["col_1", "col_2", "col_3"]
    for column, dtype in expected_features.items():
        assert dataset.features[column].dtype == dtype
Expand All @@ -3886,7 +3887,7 @@ def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, t
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
_check_generator_dataset(dataset, expected_features)
_check_generator_dataset(dataset, expected_features, NamedSplit("train"))


@pytest.mark.parametrize(
Expand All @@ -3907,7 +3908,23 @@ def test_dataset_from_generator_features(features, data_generator, tmp_path):
Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
)
dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
_check_generator_dataset(dataset, expected_features)
_check_generator_dataset(dataset, expected_features, NamedSplit("train"))


@pytest.mark.parametrize(
    "split",
    [None, NamedSplit("train"), "train", NamedSplit("foo"), "foo"],
)
def test_dataset_from_generator_split(split, data_generator, tmp_path):
    """An explicit `split` propagates to the dataset; omitting it defaults to "train"."""
    cache_dir = tmp_path / "cache"
    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
    if split:
        dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, split=split)
        expected_split = split
    else:
        # No split argument: Dataset.from_generator should assign the train split.
        dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir)
        expected_split = "train"
    _check_generator_dataset(dataset, expected_features, expected_split)


@require_not_windows
Expand Down

0 comments on commit 786998a

Please sign in to comment.