Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add split argument to Generator #7015

Merged
merged 11 commits into from
Jul 26, 2024
6 changes: 6 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,7 @@ def from_generator(
keep_in_memory: bool = False,
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
split: NamedSplit = Split.TRAIN,
**kwargs,
):
"""Create a Dataset from a generator.
Expand All @@ -1090,6 +1091,10 @@ def from_generator(
If `num_proc` is greater than one, then all list values in `gen_kwargs` must be the same length. These values will be split between calls to the generator. The number of shards will be the minimum of the shortest list in `gen_kwargs` and `num_proc`.

<Added version="2.7.0"/>
split ([`NamedSplit`], defaults to `Split.TRAIN`):
Split name to be assigned to the dataset.

<Added version="2.21.0"/>
**kwargs (additional keyword arguments):
Keyword arguments to be passed to :[`GeneratorConfig`].

Expand Down Expand Up @@ -1126,6 +1131,7 @@ def from_generator(
keep_in_memory=keep_in_memory,
gen_kwargs=gen_kwargs,
num_proc=num_proc,
split=split,
**kwargs,
).read()

Expand Down
8 changes: 5 additions & 3 deletions src/datasets/io/generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, Optional

from .. import Features
from .. import Features, NamedSplit, Split
from ..packaged_modules.generator.generator import Generator
from .abc import AbstractDatasetInputStream

Expand All @@ -15,6 +15,7 @@ def __init__(
streaming: bool = False,
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
split: NamedSplit = Split.TRAIN,
**kwargs,
):
super().__init__(
Expand All @@ -30,13 +31,14 @@ def __init__(
features=features,
generator=generator,
gen_kwargs=gen_kwargs,
split=split,
**kwargs,
)

def read(self):
# Build iterable dataset
if self.streaming:
dataset = self.builder.as_streaming_dataset(split="train")
dataset = self.builder.as_streaming_dataset(split=self.builder.config.split)
# Build regular (map-style) dataset
else:
download_config = None
Expand All @@ -52,6 +54,6 @@ def read(self):
num_proc=self.num_proc,
)
dataset = self.builder.as_dataset(
split="train", verification_mode=verification_mode, in_memory=self.keep_in_memory
split=self.builder.config.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
)
return dataset
11 changes: 6 additions & 5 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned, cast_to_python_objects
from .formatting import PythonFormatter, TensorFormatter, get_format_type_from_alias, get_formatter
from .info import DatasetInfo
from .splits import NamedSplit
from .splits import NamedSplit, Split
from .table import cast_table_to_features, read_schema_from_file, table_cast
from .utils.logging import get_logger
from .utils.py_utils import Literal
Expand Down Expand Up @@ -2083,6 +2083,7 @@ def from_generator(
generator: Callable,
features: Optional[Features] = None,
gen_kwargs: Optional[dict] = None,
split: NamedSplit = Split.TRAIN,
) -> "IterableDataset":
"""Create an Iterable Dataset from a generator.

Expand All @@ -2095,7 +2096,10 @@ def from_generator(
Keyword arguments to be passed to the `generator` callable.
You can define a sharded iterable dataset by passing the list of shards in `gen_kwargs`.
This can be used to improve shuffling and when iterating over the dataset with multiple workers.
split ([`NamedSplit`], defaults to `Split.TRAIN`):
Split name to be assigned to the dataset.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comments as before.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also added <Added version="2.21.0"/> please cross-check


<Added version="2.21.0"/>
Returns:
`IterableDataset`

Expand Down Expand Up @@ -2126,10 +2130,7 @@ def from_generator(
from .io.generator import GeneratorDatasetInputStream

return GeneratorDatasetInputStream(
generator=generator,
features=features,
gen_kwargs=gen_kwargs,
streaming=True,
generator=generator, features=features, gen_kwargs=gen_kwargs, streaming=True, split=split
).read()

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion src/datasets/packaged_modules/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class GeneratorConfig(datasets.BuilderConfig):
generator: Optional[Callable] = None
gen_kwargs: Optional[dict] = None
features: Optional[datasets.Features] = None
split: datasets.NamedSplit = datasets.Split.TRAIN

def __post_init__(self):
super().__post_init__()
Expand All @@ -26,7 +27,7 @@ def _info(self):
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs=self.config.gen_kwargs)]
return [datasets.SplitGenerator(name=self.config.split, gen_kwargs=self.config.gen_kwargs)]

def _generate_examples(self, **gen_kwargs):
for idx, ex in enumerate(self.config.generator(**gen_kwargs)):
Expand Down
23 changes: 20 additions & 3 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3871,10 +3871,11 @@ def _gen():
return _gen


def _check_generator_dataset(dataset, expected_features):
def _check_generator_dataset(dataset, expected_features, split):
assert isinstance(dataset, Dataset)
assert dataset.num_rows == 4
assert dataset.num_columns == 3
assert dataset.split == split
assert dataset.column_names == ["col_1", "col_2", "col_3"]
for feature, expected_dtype in expected_features.items():
assert dataset.features[feature].dtype == expected_dtype
Expand All @@ -3886,7 +3887,7 @@ def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, t
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
_check_generator_dataset(dataset, expected_features)
_check_generator_dataset(dataset, expected_features, NamedSplit("train"))


@pytest.mark.parametrize(
Expand All @@ -3907,7 +3908,23 @@ def test_dataset_from_generator_features(features, data_generator, tmp_path):
Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
)
dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
_check_generator_dataset(dataset, expected_features)
_check_generator_dataset(dataset, expected_features, NamedSplit("train"))


@pytest.mark.parametrize(
"split",
[None, NamedSplit("train"), "train", NamedSplit("foo"), "foo"],
)
def test_dataset_from_generator_split(split, data_generator, tmp_path):
cache_dir = tmp_path / "cache"
default_expected_split = "train"
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
expected_split = split if split else default_expected_split
if split:
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, split=split)
else:
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir)
_check_generator_dataset(dataset, expected_features, expected_split)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a specific test_dataset_from_generator_split with a parametrized split values, such as not passing any value, passing NamedSplit("train"), passing literal "train", passing other NamedSplit, etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_dataset_from_generator_split added, still i have impacted _check_generator_dataset to share the same generic check everywhere


@require_not_windows
Expand Down
Loading