diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 2d67db679fc..5ac987e3f9b 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1068,6 +1068,7 @@ def from_generator(
keep_in_memory: bool = False,
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
+ split: NamedSplit = Split.TRAIN,
**kwargs,
):
"""Create a Dataset from a generator.
@@ -1090,6 +1091,9 @@ def from_generator(
If `num_proc` is greater than one, then all list values in `gen_kwargs` must be the same length. These values will be split between calls to the generator. The number of shards will be the minimum of the shortest list in `gen_kwargs` and `num_proc`.
+ split ([`NamedSplit`], defaults to `Split.TRAIN`):
+ Split name to be assigned to the dataset.
+
**kwargs (additional keyword arguments):
Keyword arguments to be passed to [`GeneratorConfig`].
@@ -1126,6 +1131,7 @@ def from_generator(
keep_in_memory=keep_in_memory,
gen_kwargs=gen_kwargs,
num_proc=num_proc,
+ split=split,
**kwargs,
).read()
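A minimal usage sketch of the new argument (the generator and split name below are illustrative, not part of this patch):

    from datasets import Dataset, Split

    def gen():
        for i in range(4):
            yield {"col_1": str(i)}

    # `split` defaults to Split.TRAIN, so existing callers are unaffected.
    ds = Dataset.from_generator(gen, split=Split.VALIDATION)
    assert str(ds.split) == "validation"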
diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py
index 2566d5fcdcc..b10609cac23 100644
--- a/src/datasets/io/generator.py
+++ b/src/datasets/io/generator.py
@@ -1,6 +1,6 @@
from typing import Callable, Optional
-from .. import Features
+from .. import Features, NamedSplit, Split
from ..packaged_modules.generator.generator import Generator
from .abc import AbstractDatasetInputStream
@@ -15,6 +15,7 @@ def __init__(
streaming: bool = False,
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
+ split: NamedSplit = Split.TRAIN,
**kwargs,
):
super().__init__(
@@ -30,13 +31,14 @@ def __init__(
features=features,
generator=generator,
gen_kwargs=gen_kwargs,
+ split=split,
**kwargs,
)
def read(self):
# Build iterable dataset
if self.streaming:
- dataset = self.builder.as_streaming_dataset(split="train")
+ dataset = self.builder.as_streaming_dataset(split=self.builder.config.split)
# Build regular (map-style) dataset
else:
download_config = None
@@ -52,6 +54,6 @@ def read(self):
num_proc=self.num_proc,
)
dataset = self.builder.as_dataset(
- split="train", verification_mode=verification_mode, in_memory=self.keep_in_memory
+ split=self.builder.config.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
)
return dataset
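With this change the split value round-trips through the builder config in both branches of read(). A hedged sketch of the non-streaming path (this class is internal plumbing behind from_generator; the cache path is hypothetical):

    from datasets.io.generator import GeneratorDatasetInputStream

    def gen():
        yield {"a": 1}

    # read() now materializes as_dataset(split=self.builder.config.split, ...)
    ds = GeneratorDatasetInputStream(generator=gen, split="test", cache_dir="/tmp/gen_cache").read()
    assert str(ds.split) == "test"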
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 94fd11a1b55..d3e261eed5c 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -19,7 +19,7 @@
from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned, cast_to_python_objects
from .formatting import PythonFormatter, TensorFormatter, get_format_type_from_alias, get_formatter
from .info import DatasetInfo
-from .splits import NamedSplit
+from .splits import NamedSplit, Split
from .table import cast_table_to_features, read_schema_from_file, table_cast
from .utils.logging import get_logger
from .utils.py_utils import Literal
@@ -2083,6 +2083,7 @@ def from_generator(
generator: Callable,
features: Optional[Features] = None,
gen_kwargs: Optional[dict] = None,
+ split: NamedSplit = Split.TRAIN,
) -> "IterableDataset":
"""Create an Iterable Dataset from a generator.
@@ -2095,7 +2096,10 @@ def from_generator(
Keyword arguments to be passed to the `generator` callable.
You can define a sharded iterable dataset by passing the list of shards in `gen_kwargs`.
This can be used to improve shuffling and when iterating over the dataset with multiple workers.
+ split ([`NamedSplit`], defaults to `Split.TRAIN`):
+ Split name to be assigned to the dataset.
+
Returns:
`IterableDataset`
@@ -2126,10 +2130,7 @@ def from_generator(
from .io.generator import GeneratorDatasetInputStream
return GeneratorDatasetInputStream(
- generator=generator,
- features=features,
- gen_kwargs=gen_kwargs,
- streaming=True,
+ generator=generator, features=features, gen_kwargs=gen_kwargs, streaming=True, split=split
).read()
@staticmethod
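The streaming API mirrors the map-style one. A minimal sketch (assumes as_streaming_dataset propagates the name, per the io/generator.py change above):

    from datasets import IterableDataset

    def gen():
        yield {"col_1": "a"}

    ids = IterableDataset.from_generator(gen, split="foo")
    print(ids.split)  # expected: foo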
diff --git a/src/datasets/packaged_modules/generator/generator.py b/src/datasets/packaged_modules/generator/generator.py
index 336942f2edc..8a42ba05aa6 100644
--- a/src/datasets/packaged_modules/generator/generator.py
+++ b/src/datasets/packaged_modules/generator/generator.py
@@ -9,6 +9,7 @@ class GeneratorConfig(datasets.BuilderConfig):
generator: Optional[Callable] = None
gen_kwargs: Optional[dict] = None
features: Optional[datasets.Features] = None
+ split: datasets.NamedSplit = datasets.Split.TRAIN
def __post_init__(self):
super().__post_init__()
@@ -26,7 +27,7 @@ def _info(self):
return datasets.DatasetInfo(features=self.config.features)
def _split_generators(self, dl_manager):
- return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs=self.config.gen_kwargs)]
+ return [datasets.SplitGenerator(name=self.config.split, gen_kwargs=self.config.gen_kwargs)]
def _generate_examples(self, **gen_kwargs):
for idx, ex in enumerate(self.config.generator(**gen_kwargs)):
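For reference, the new config field can be exercised directly (hypothetical values; GeneratorConfig is normally built for you by from_generator):

    import datasets
    from datasets.packaged_modules.generator.generator import GeneratorConfig

    config = GeneratorConfig(generator=lambda: [{"a": 1}], split=datasets.Split.TEST)
    print(config.split)  # "test" -> becomes the SplitGenerator name in _split_generators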
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 6294111d821..01bb71024dc 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -3871,10 +3871,11 @@ def _gen():
return _gen
-def _check_generator_dataset(dataset, expected_features):
+def _check_generator_dataset(dataset, expected_features, split):
assert isinstance(dataset, Dataset)
assert dataset.num_rows == 4
assert dataset.num_columns == 3
+ assert dataset.split == split
assert dataset.column_names == ["col_1", "col_2", "col_3"]
for feature, expected_dtype in expected_features.items():
assert dataset.features[feature].dtype == expected_dtype
@@ -3886,7 +3887,7 @@ def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, t
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
- _check_generator_dataset(dataset, expected_features)
+ _check_generator_dataset(dataset, expected_features, NamedSplit("train"))
@pytest.mark.parametrize(
@@ -3907,7 +3908,23 @@ def test_dataset_from_generator_features(features, data_generator, tmp_path):
Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
)
dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
- _check_generator_dataset(dataset, expected_features)
+ _check_generator_dataset(dataset, expected_features, NamedSplit("train"))
+
+
+@pytest.mark.parametrize(
+ "split",
+ [None, NamedSplit("train"), "train", NamedSplit("foo"), "foo"],
+)
+def test_dataset_from_generator_split(split, data_generator, tmp_path):
+ cache_dir = tmp_path / "cache"
+ default_expected_split = "train"
+ expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+ expected_split = split if split else default_expected_split
+ if split:
+ dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, split=split)
+ else:
+ dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir)
+ _check_generator_dataset(dataset, expected_features, expected_split)
@require_not_windows
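The tests above cover the map-style path only. A hedged sketch of the analogous streaming check (would live in test_iterable_dataset.py with its own fixtures; not part of this patch):

    # assumes `from datasets import IterableDataset, NamedSplit` and a
    # `data_generator` fixture equivalent to the one in test_arrow_dataset.py
    def test_iterable_dataset_from_generator_split(data_generator):
        dataset = IterableDataset.from_generator(data_generator, split=NamedSplit("foo"))
        assert isinstance(dataset, IterableDataset)
        assert dataset.split == NamedSplit("foo")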