Memory leak when wrapping datasets into PyTorch Dataset without explicit deletion #7180

Closed
iamwangyabin opened this issue Sep 28, 2024 · 1 comment

Comments

@iamwangyabin

Describe the bug

I've encountered a memory leak when wrapping the HuggingFace dataset into a PyTorch Dataset. The RAM usage constantly increases during iteration if items are not explicitly deleted after use.

Steps to reproduce the bug

Create a PyTorch Dataset wrapper for 'nebula/cc12m':

import io

from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
from datasets import load_dataset
from torchvision import transforms

# Allow arbitrarily large images (disable PIL's decompression-bomb limit)
Image.MAX_IMAGE_PIXELS = None

class CC12M(Dataset):
    def __init__(self, path_or_name='nebula/cc12m', split='train', transform=None, single_caption=True):
        self.raw_dataset = load_dataset(path_or_name)[split]

        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711]
                )
            ])
        else:
            self.transform = transforms.Compose(transform)

        self.single_caption = single_caption
        self.length = len(self.raw_dataset)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        item = self.raw_dataset[index]  # each access returns the row as a plain dict
        caption = item['txt']
        with io.BytesIO(item['webp']) as buffer:  # image bytes stored in the 'webp' column
            image = Image.open(buffer).convert('RGB')
            if self.transform:
                image = self.transform(image)
        # del item  # Uncomment this line to prevent the memory leak
        return image, caption

Iterate through the dataset without the del item line in __getitem__ (a monitoring sketch is shown after these steps).

Observe RAM usage increasing steadily.

Add del item at the end of __getitem__:

def __getitem__(self, index):
    item = self.raw_dataset[index]
    caption = item['txt']
    with io.BytesIO(item['webp']) as buffer:
        image = Image.open(buffer).convert('RGB')
        if self.transform:
            image = self.transform(image)
    del item  # This line prevents the memory leak
    return image, caption

Iterate through the dataset again and observe that RAM usage remains stable.
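
A minimal sketch of the kind of loop used to watch memory (it assumes psutil is installed, reuses the CC12M wrapper above, and the DataLoader settings are only illustrative; num_workers=0 keeps all dataset access in the main process so its RSS is what matters):

import os

import psutil
from torch.utils.data import DataLoader

# Monitoring-only sketch: iterate over the wrapper and print the process RSS
# every 100 batches. psutil is an extra dependency used just for measurement.
proc = psutil.Process(os.getpid())
loader = DataLoader(CC12M(), batch_size=64, num_workers=0)

for step, (images, captions) in enumerate(loader):
    if step % 100 == 0:
        print(f"step={step:>6}  rss={proc.memory_info().rss / 1e9:.2f} GB")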

Expected behavior

Expected behavior:
RAM usage should remain stable during iteration, without needing to explicitly delete items.

Actual behavior:
RAM usage increases steadily unless items are explicitly deleted after use.

Environment info

  • datasets version: 2.21.0
  • Platform: Linux-4.18.0-513.5.1.el8_9.x86_64-x86_64-with-glibc2.28
  • Python version: 3.12.4
  • huggingface_hub version: 0.24.6
  • PyArrow version: 17.0.0
  • Pandas version: 2.2.2
  • fsspec version: 2024.6.1
@lhoestq (Member) commented Sep 30, 2024

I've encountered a memory leak when wrapping the HuggingFace dataset into a PyTorch Dataset. The RAM usage constantly increases during iteration if items are not explicitly deleted after use.

Datasets are memory-mapped, so they work like swap memory: as long as you have RAM available, the data stays in RAM and gets paged out once your system needs that RAM for something else (no OOM).
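
One way to check this (a sketch, assuming psutil is installed and a Linux /proc filesystem) is to break the process RSS down by mapping: pages backed by the memory-mapped .arrow cache files count toward RSS but are ordinary page cache the OS can reclaim, unlike anonymous allocations:

import os

import psutil

# Sketch: attribute the resident memory of this process to the memory-mapped
# .arrow cache files vs. anonymous memory (heap etc.).
proc = psutil.Process(os.getpid())

arrow_rss = 0
anon_rss = 0
for m in proc.memory_maps(grouped=True):
    if m.path.endswith(".arrow"):
        arrow_rss += m.rss
    elif m.path in ("[anon]", "[heap]"):
        anon_rss += m.rss

print(f"total rss : {proc.memory_info().rss / 1e9:.2f} GB")
print(f".arrow rss: {arrow_rss / 1e9:.2f} GB  (memory-mapped dataset pages)")
print(f"anon rss  : {anon_rss / 1e9:.2f} GB")

If most of the RSS sits in the .arrow mappings, it is reclaimable page cache rather than a leak.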

related: #4883
