-
Notifications
You must be signed in to change notification settings - Fork 2.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add split argument to Generator #7015
Changes from all commits
3524459
eef7c96
bdd9662
5512e3f
6f1c18b
d74a862
7e50f23
96b9e37
b912261
9480dae
5962e2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3871,10 +3871,11 @@ def _gen(): | |
return _gen | ||
|
||
|
||
def _check_generator_dataset(dataset, expected_features): | ||
def _check_generator_dataset(dataset, expected_features, split): | ||
assert isinstance(dataset, Dataset) | ||
assert dataset.num_rows == 4 | ||
assert dataset.num_columns == 3 | ||
assert dataset.split == split | ||
assert dataset.column_names == ["col_1", "col_2", "col_3"] | ||
for feature, expected_dtype in expected_features.items(): | ||
assert dataset.features[feature].dtype == expected_dtype | ||
|
@@ -3886,7 +3887,7 @@ def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, t | |
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} | ||
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): | ||
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory) | ||
_check_generator_dataset(dataset, expected_features) | ||
_check_generator_dataset(dataset, expected_features, NamedSplit("train")) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
|
@@ -3907,7 +3908,23 @@ def test_dataset_from_generator_features(features, data_generator, tmp_path): | |
Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None | ||
) | ||
dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir) | ||
_check_generator_dataset(dataset, expected_features) | ||
_check_generator_dataset(dataset, expected_features, NamedSplit("train")) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"split", | ||
[None, NamedSplit("train"), "train", NamedSplit("foo"), "foo"], | ||
) | ||
def test_dataset_from_generator_split(split, data_generator, tmp_path): | ||
cache_dir = tmp_path / "cache" | ||
default_expected_split = "train" | ||
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} | ||
expected_split = split if split else default_expected_split | ||
if split: | ||
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, split=split) | ||
else: | ||
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir) | ||
_check_generator_dataset(dataset, expected_features, expected_split) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would add a specific There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
@require_not_windows | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comments as before.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also added
<Added version="2.21.0"/>
please cross-check