From d32f71ad7253290535ddfbeed5c86757553ee1dd Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Sat, 21 Sep 2024 17:16:22 +0200 Subject: [PATCH 1/2] fix zero proba interleave_datasets --- src/datasets/iterable_dataset.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 5f5c49f1556..cd7e1eb19cb 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -821,6 +821,11 @@ def __init__( probabilities: Optional[List[float]] = None, stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted", ): + if probabilities is not None: + ex_iterables = [ + ex_iterable for ex_iterable, probability in zip(ex_iterables, probabilities) if probability > 0 + ] + probabilities = [probability for probability in probabilities if probability > 0] super().__init__(ex_iterables, stopping_strategy) self.generator = deepcopy(generator) self.probabilities = probabilities From 05365e49a4dbe5e77cf545dcb69fbb31910022c9 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Sat, 21 Sep 2024 17:18:53 +0200 Subject: [PATCH 2/2] for map-style too --- src/datasets/arrow_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index e7ac1a34665..1840c090900 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -6142,6 +6142,9 @@ def _interleave_map_style_datasets( raise ValueError( f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}" ) + if probabilities is not None: + datasets = [dataset for dataset, proability in zip(datasets, probabilities) if proability > 0] + probabilities = [probability for probability in probabilities if probability > 0] # To interleave the datasets, we concatenate them and then we re-order the indices concatenated_datasets = _concatenate_map_style_datasets(datasets, info=info, split=split)