diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 2d67db679fc..5ac987e3f9b 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1068,6 +1068,7 @@ def from_generator(
         keep_in_memory: bool = False,
         gen_kwargs: Optional[dict] = None,
         num_proc: Optional[int] = None,
+        split: NamedSplit = Split.TRAIN,
         **kwargs,
     ):
         """Create a Dataset from a generator.
@@ -1090,6 +1091,10 @@ def from_generator(
                 If `num_proc` is greater than one, then all list values in `gen_kwargs` must be the same length. These values will be split between calls to the generator. The number of shards will be the minimum of the shortest list in `gen_kwargs` and `num_proc`.
+            split ([`NamedSplit`], defaults to `Split.TRAIN`):
+                Split name to be assigned to the dataset.
+
+
             **kwargs (additional keyword arguments):
                 Keyword arguments to be passed to :[`GeneratorConfig`].
@@ -1126,6 +1131,7 @@ def from_generator(
             keep_in_memory=keep_in_memory,
             gen_kwargs=gen_kwargs,
             num_proc=num_proc,
+            split=split,
             **kwargs,
         ).read()
diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py
index 2566d5fcdcc..b10609cac23 100644
--- a/src/datasets/io/generator.py
+++ b/src/datasets/io/generator.py
@@ -1,6 +1,6 @@
 from typing import Callable, Optional
 
-from .. import Features
+from .. import Features, NamedSplit, Split
 from ..packaged_modules.generator.generator import Generator
 from .abc import AbstractDatasetInputStream
 
@@ -15,6 +15,7 @@ def __init__(
         streaming: bool = False,
         gen_kwargs: Optional[dict] = None,
         num_proc: Optional[int] = None,
+        split: NamedSplit = Split.TRAIN,
         **kwargs,
     ):
         super().__init__(
@@ -30,13 +31,14 @@ def __init__(
             features=features,
             generator=generator,
             gen_kwargs=gen_kwargs,
+            split=split,
             **kwargs,
         )
 
     def read(self):
         # Build iterable dataset
         if self.streaming:
-            dataset = self.builder.as_streaming_dataset(split="train")
+            dataset = self.builder.as_streaming_dataset(split=self.builder.config.split)
         # Build regular (map-style) dataset
         else:
             download_config = None
@@ -52,6 +54,6 @@ def read(self):
                 num_proc=self.num_proc,
             )
             dataset = self.builder.as_dataset(
-                split="train", verification_mode=verification_mode, in_memory=self.keep_in_memory
+                split=self.builder.config.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
             )
         return dataset
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 94fd11a1b55..d3e261eed5c 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -19,7 +19,7 @@
 from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned, cast_to_python_objects
 from .formatting import PythonFormatter, TensorFormatter, get_format_type_from_alias, get_formatter
 from .info import DatasetInfo
-from .splits import NamedSplit
+from .splits import NamedSplit, Split
 from .table import cast_table_to_features, read_schema_from_file, table_cast
 from .utils.logging import get_logger
 from .utils.py_utils import Literal
@@ -2083,6 +2084,7 @@ def from_generator(
         generator: Callable,
         features: Optional[Features] = None,
         gen_kwargs: Optional[dict] = None,
+        split: NamedSplit = Split.TRAIN,
     ) -> "IterableDataset":
         """Create an Iterable Dataset from a generator.
 
@@ -2095,7 +2096,10 @@ def from_generator(
                 Keyword arguments to be passed to the `generator` callable. You can define a sharded iterable dataset by passing the list of shards in `gen_kwargs`. This can be used to improve shuffling and when iterating over the dataset with multiple workers.
+            split ([`NamedSplit`], defaults to `Split.TRAIN`):
+                Split name to be assigned to the dataset.
+
 
         Returns:
             `IterableDataset`
 
@@ -2126,10 +2130,7 @@ def from_generator(
         from .io.generator import GeneratorDatasetInputStream
 
         return GeneratorDatasetInputStream(
-            generator=generator,
-            features=features,
-            gen_kwargs=gen_kwargs,
-            streaming=True,
+            generator=generator, features=features, gen_kwargs=gen_kwargs, streaming=True, split=split
         ).read()
 
     @staticmethod
diff --git a/src/datasets/packaged_modules/generator/generator.py b/src/datasets/packaged_modules/generator/generator.py
index 336942f2edc..8a42ba05aa6 100644
--- a/src/datasets/packaged_modules/generator/generator.py
+++ b/src/datasets/packaged_modules/generator/generator.py
@@ -9,6 +9,7 @@ class GeneratorConfig(datasets.BuilderConfig):
     generator: Optional[Callable] = None
     gen_kwargs: Optional[dict] = None
     features: Optional[datasets.Features] = None
+    split: datasets.NamedSplit = datasets.Split.TRAIN
 
     def __post_init__(self):
         super().__post_init__()
@@ -26,7 +27,7 @@ def _info(self):
         return datasets.DatasetInfo(features=self.config.features)
 
     def _split_generators(self, dl_manager):
-        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs=self.config.gen_kwargs)]
+        return [datasets.SplitGenerator(name=self.config.split, gen_kwargs=self.config.gen_kwargs)]
 
     def _generate_examples(self, **gen_kwargs):
         for idx, ex in enumerate(self.config.generator(**gen_kwargs)):
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 6294111d821..01bb71024dc 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -3871,10 +3871,11 @@ def _gen():
     return _gen
 
 
-def _check_generator_dataset(dataset, expected_features):
+def _check_generator_dataset(dataset, expected_features, split):
     assert isinstance(dataset, Dataset)
     assert dataset.num_rows == 4
     assert dataset.num_columns == 3
+    assert dataset.split == split
     assert dataset.column_names == ["col_1", "col_2", "col_3"]
     for feature, expected_dtype in expected_features.items():
         assert dataset.features[feature].dtype == expected_dtype
@@ -3886,7 +3887,7 @@ def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, t
     expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
     with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
         dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
-        _check_generator_dataset(dataset, expected_features)
+        _check_generator_dataset(dataset, expected_features, NamedSplit("train"))
 
 
 @pytest.mark.parametrize(
@@ -3907,7 +3908,23 @@ def test_dataset_from_generator_features(features, data_generator, tmp_path):
         Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
     )
     dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
-    _check_generator_dataset(dataset, expected_features)
+    _check_generator_dataset(dataset, expected_features, NamedSplit("train"))
+
+
+@pytest.mark.parametrize(
+    "split",
+    [None, NamedSplit("train"), "train", NamedSplit("foo"), "foo"],
+)
+def test_dataset_from_generator_split(split, data_generator, tmp_path):
+    cache_dir = tmp_path / "cache"
+    default_expected_split = "train"
+    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+    expected_split = split if split else default_expected_split
+    if split:
+        dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, split=split)
+    else:
+        dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir)
+    _check_generator_dataset(dataset, expected_features, expected_split)
 
 
 @require_not_windows