Skip to content

Commit

Permalink
Global step for merging datasets (#9)
Browse files Browse the repository at this point in the history
Step uses the `inputs` key as the datasets to be combined, `output` as
the new dataset name and can optionally shuffle the resulting dataset.
  • Loading branch information
danielfleischer authored Aug 19, 2024
1 parent 22d61e5 commit c65e202
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/reference/processing/global_steps/aggregation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: ragfoundry.processing.global_steps.aggregation
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ nav:
- Prompt Creation: "reference/processing/local_steps/prompter.md"
- RAFT: "reference/processing/local_steps/raft.md"
- Global Steps:
- Aggregation and merging: "reference/processing/global_steps/aggregation.md"
- Sampling and Fewshot: "reference/processing/global_steps/sampling.md"
- Output: "reference/processing/global_steps/output.md"
- Answer Processors:
Expand Down
31 changes: 31 additions & 0 deletions ragfoundry/processing/global_steps/aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from datasets import concatenate_datasets

from ..step import GlobalStep


class MergeDatasets(GlobalStep):
"""
Step for merging datasets.
Merge is done using concatenation. Optional shuffling by providing a seed.
"""

def __init__(self, output, shuffle=None, **kwargs):
"""
Args:
output (str): Name of the output dataset. Should be unique.
shuffle (int, optional): seed for shuffling. Default is None.
"""
super().__init__(**kwargs)
self.output = output
self.shuffle = shuffle
self.completed = False
self.cache_step = False

def process(self, dataset_name, datasets, **kwargs):
if not self.completed:
data = concatenate_datasets([datasets[name] for name in self.inputs])
if self.shuffle:
data = data.shuffle(self.shuffle)
datasets[self.output] = data
self.completed = True

0 comments on commit c65e202

Please sign in to comment.