From 639227430822feaab733dc7132a0904d3787b3f2 Mon Sep 17 00:00:00 2001 From: jp1924 Date: Sat, 19 Oct 2024 09:55:09 +0000 Subject: [PATCH] Add: with_split --- src/datasets/dataset_dict.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index f92a1a8afda..89e054fabcf 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -784,6 +784,7 @@ def map( function: Optional[Callable] = None, with_indices: bool = False, with_rank: bool = False, + with_split: bool = False, input_columns: Optional[Union[str, List[str]]] = None, batched: bool = False, batch_size: Optional[int] = 1000, @@ -795,7 +796,7 @@ def map( writer_batch_size: Optional[int] = 1000, features: Optional[Features] = None, disable_nullable: bool = False, - fn_kwargs: Optional[dict] = None, + fn_kwargs: dict = {}, num_proc: Optional[int] = None, desc: Optional[str] = None, ) -> "DatasetDict": @@ -882,6 +883,7 @@ def map( self._check_values_type() if cache_file_names is None: cache_file_names = {k: None for k in self} + return DatasetDict( { k: dataset.map( @@ -899,7 +901,7 @@ def map( writer_batch_size=writer_batch_size, features=features, disable_nullable=disable_nullable, - fn_kwargs=fn_kwargs, + fn_kwargs={**fn_kwargs, "split": k} if with_split else fn_kwargs, num_proc=num_proc, desc=desc, )