Filter Large Dataset Entry by Entry #7128

Open
QiyaoWei opened this issue Aug 27, 2024 · 4 comments
Labels
enhancement New feature or request

Comments

@QiyaoWei

Feature request

I am not sure if this is a new feature, but I wanted to post this problem here, and hear if others have ways of optimizing and speeding up this process.

Let's say I have a really large dataset that I cannot load into memory. At this point, I am only aware of streaming=True to load the dataset. Now, the dataset consists of many tables. Ideally, I would want to have some simple filtering criterion, such that I only see the "good" tables. Here is an example of what the code might look like:

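from itertools import islice

from datasets import load_dataset

# Assumes a "train" split; without split=..., load_dataset returns a dict of
# splits, and islice would then iterate over split names instead of examples
dataset = load_dataset(
    "really-large-dataset",
    split="train",
    streaming=True,
)
# And let's say we process the dataset bit by bit because we want intermediate results
dataset = islice(dataset, 10000)

# Define a function to filter the data
def filter_function(table):
    # some_condition is a placeholder for the actual criterion on `table`
    return bool(some_condition)

# Use the filter function on your dataset (lazy: nothing is read until iteration)
filtered_dataset = (ex for ex in dataset if filter_function(ex))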

And then I work on the processed dataset, which would be orders of magnitude faster than working on the original. I would love to hear whether the problem setup and solution make sense to people, and whether anyone has suggestions!
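The "bit by bit" processing described above can be sketched with the standard library alone. This is a minimal illustration, not part of the original post: `filter_in_chunks` and `fake_stream` are hypothetical names, and the generator of dicts only stands in for a streamed dataset.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def filter_in_chunks(
    examples: Iterable[T],
    keep: Callable[[T], bool],
    chunk_size: int,
) -> Iterator[List[T]]:
    """Yield the kept examples of each successive chunk of the input."""
    iterator = iter(examples)
    while True:
        # Pull at most chunk_size examples; only this chunk is held in memory.
        chunk = list(islice(iterator, chunk_size))
        if not chunk:
            return
        yield [example for example in chunk if keep(example)]


# Hypothetical stand-in for a streamed dataset: a generator of small dicts.
fake_stream = ({"n_rows": n} for n in range(10))

chunks = list(
    filter_in_chunks(fake_stream, lambda ex: ex["n_rows"] % 3 == 0, chunk_size=4)
)
```

Each yielded list is an intermediate result that can be inspected or saved before the next chunk is read, so memory use is bounded by `chunk_size` rather than by the size of the dataset.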

Motivation

See description above

Your contribution

Happy to make a PR if this is a new feature

QiyaoWei added the enhancement (New feature or request) label Aug 27, 2024
@lhoestq
Member

lhoestq commented Sep 20, 2024

Hi! You can do

filtered_dataset = dataset.filter(filter_function)

or, on a subset:

filtered_subset = dataset.select(range(10_000)).filter(filter_function)

@jveitchmichaelis

jveitchmichaelis commented Sep 21, 2024

Jumping on this as it seems relevant - when I use the filter method, it often results in an OOM (or at least unacceptably high memory usage).

For example, in this notebook we load an object detection dataset from HF. Suppose I want to filter it so that I only keep images that contain a single annotation class. Each row has a JSON field that contains MS-COCO annotations for the image, so we could load that field and filter on it.

The test dataset is only about 440 images, probably less than 1GB, but running the following filter crashes the VM (over 12 GB RAM):

import json
def filter_single_class(example, target_class_id):
  """Filters examples based on whether they contain annotations from a single class.

  Args:
    example: A dictionary representing a single example from the dataset.
    target_class_id: The target class ID to filter for.

  Returns:
    True if the example contains only annotations from the target class, False otherwise.
  """
  if not example['coco_annotations']:
    return False

  annotation_category_ids = {annotation['category_id'] for annotation in json.loads(example['coco_annotations'])}

  return len(annotation_category_ids) == 1 and target_class_id in annotation_category_ids

target_class_id = 1 
filtered_dataset = dataset['test'].filter(lambda example: filter_single_class(example, target_class_id))

Iterating over the dataset works fine:

filtered_dataset = []
for example in dataset['test']:
  if filter_single_class(example, target_class_id):
    filtered_dataset.append(example)

It would be great if there was guidance in the documentation on how to use filters efficiently, or if this is some performance bug that could be addressed. At the very least I would expect a filter operation to use at most 2x the footprint of the dataset plus some overhead for the lambda (i.e. the worst case would be a duplicate copy with all entries retained). Even if the operation is parallelised, each thread/worker should only take a subset of the dataset, so I'm not sure where this ballooning in memory usage comes from.

From some other comments there seems to be a workaround with writer_batch_size or caching to file, but in the docs at least, keep_in_memory defaults to False.

@lhoestq
Member

lhoestq commented Sep 23, 2024

You can try passing input_columns=["coco_annotations"] to only load this column instead of all the columns. In that case, your function should take coco_annotations as input instead of example.
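A minimal sketch of what the adapted function might look like, assuming the column holds a JSON string as in the earlier comment. The name `has_single_class` and the sample `rows` are hypothetical; the function receives the column value directly rather than the whole example.

```python
import json


def has_single_class(coco_annotations, target_class_id=1):
    """Return True if every annotation in the JSON string has target_class_id."""
    if not coco_annotations:
        return False
    category_ids = {ann["category_id"] for ann in json.loads(coco_annotations)}
    return category_ids == {target_class_id}


# Hypothetical values standing in for the dataset's "coco_annotations" column.
rows = [
    json.dumps([{"category_id": 1}, {"category_id": 1}]),
    json.dumps([{"category_id": 1}, {"category_id": 2}]),
    "",
]
results = [has_single_class(row) for row in rows]
```

Following the suggestion above, this would be used as `dataset['test'].filter(has_single_class, input_columns=["coco_annotations"])`. Whether that resolves the memory usage reported earlier is not confirmed in this thread.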

@Mythripaluri

If your filter_function is large and computationally intensive, consider using multi-processing or multi-threading with concurrent.futures to filter the dataset. This approach allows you to process multiple tables concurrently, reducing overall processing time, especially for CPU-bound tasks. Use ThreadPoolExecutor for I/O-bound operations and ProcessPoolExecutor for CPU-bound operations.
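As a rough illustration of this suggestion, the sketch below evaluates a predicate on a thread pool with `concurrent.futures` and keeps the matching items in their original order. The helper `concurrent_filter` is a hypothetical name, and the odd-number predicate is only a placeholder for an expensive check.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")


def concurrent_filter(
    examples: Iterable[T],
    keep: Callable[[T], bool],
    max_workers: int = 4,
) -> List[T]:
    """Evaluate keep() on a thread pool and return the kept examples in order."""
    # The input is materialised here, so pass one manageable chunk at a time.
    items = list(examples)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map returns results in the same order as the inputs.
        flags = list(executor.map(keep, items))
    return [item for item, flag in zip(items, flags) if flag]


kept = concurrent_filter(range(10), lambda n: n % 2 == 1)
```

For CPU-bound predicates, `ProcessPoolExecutor` offers the same `map` interface, but the predicate must then be picklable (a module-level function rather than a lambda). For map-style datasets, the library's own `filter` also accepts a `num_proc` argument, which may be simpler than managing a pool by hand.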
