From 717da186d8c0748d20160b0a4a9b233c60222656 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 24 Oct 2024 13:29:59 -0700 Subject: [PATCH] Add image documentation (#238) * Add partial image implementation Signed-off-by: Ryan Wolf * Refactor requirements Signed-off-by: Ryan Wolf * Fix bugs Signed-off-by: Ryan Wolf * Change from_map to map_partitions Signed-off-by: Ryan Wolf * Add super constructor Signed-off-by: Ryan Wolf * Add kwargs for load_object_on_worker Signed-off-by: Ryan Wolf * Get proper epoch size Signed-off-by: Ryan Wolf * Complete embedding creation loop Signed-off-by: Ryan Wolf * Change devices Signed-off-by: Ryan Wolf * Add device Signed-off-by: Ryan Wolf * Refactor embedding creation and add classifier Signed-off-by: Ryan Wolf * Fix bugs in classifiers Signed-off-by: Ryan Wolf * Refactor model names Signed-off-by: Ryan Wolf * Add model name Signed-off-by: Ryan Wolf * Fix classifier bugs Signed-off-by: Ryan Wolf * Allow postprocessing for classifiers Signed-off-by: Ryan Wolf * Fix name and add print Signed-off-by: Ryan Wolf * Fix variable name Signed-off-by: Ryan Wolf * Add NSFW Signed-off-by: Ryan Wolf * Update init for import Signed-off-by: Ryan Wolf * Fix embedding size Signed-off-by: Ryan Wolf * Add fused classifiers Signed-off-by: Ryan Wolf * Fix missing index Signed-off-by: Ryan Wolf * Update metdata for fused classifiers Signed-off-by: Ryan Wolf * Add export to webdataset Signed-off-by: Ryan Wolf * Fix missing id col Signed-off-by: Ryan Wolf * Sort embeddings by id Signed-off-by: Ryan Wolf * Add timm Signed-off-by: Ryan Wolf * Update init file Signed-off-by: Ryan Wolf * Add autocast to timm Signed-off-by: Ryan Wolf * Update requirements and transform Signed-off-by: Ryan Wolf * Add additional interpolation support Signed-off-by: Ryan Wolf * Fix transform normalization Signed-off-by: Ryan Wolf * Remove open_clip Signed-off-by: Ryan Wolf * Add index path support to wds Signed-off-by: Ryan Wolf * Address Vibhu's feedback Signed-off-by: Ryan Wolf * Add import guard for image dataset Signed-off-by: Ryan Wolf * Change default device Signed-off-by: Ryan Wolf * Remove commented code Signed-off-by: Ryan Wolf * Remove device id Signed-off-by: Ryan Wolf * Fix index issue Signed-off-by: Ryan Wolf * Add docstrings and standardize variable names Signed-off-by: Ryan Wolf * Add image curation tutorial Signed-off-by: Ryan Wolf * Add initial image docs Signed-off-by: Ryan Wolf * Remove tutorial Signed-off-by: Ryan Wolf * Add dataset docs Signed-off-by: Ryan Wolf * Add embedder documentation Signed-off-by: Ryan Wolf * Revert embedding column name change Signed-off-by: Ryan Wolf * Update user guide for images Signed-off-by: Ryan Wolf * Update README Signed-off-by: Ryan Wolf * Update README with RAPIDS nightly instructions Signed-off-by: Ryan Wolf * Fix formatting issues in image documentation Signed-off-by: Ryan Wolf * Remove extra newline in README Signed-off-by: Ryan Wolf * Address most of Sarah's feedback Signed-off-by: Ryan Wolf * Add section summary Signed-off-by: Ryan Wolf * Fix errors and REWORD GPU bullets in README Signed-off-by: Ryan Wolf * Fix how table of contents displays with new sections Signed-off-by: Ryan Wolf --------- Signed-off-by: Ryan Wolf --- README.md | 152 ++++++++---------- docs/user-guide/api/datasets.rst | 8 + docs/user-guide/api/image/classifiers.rst | 21 +++ docs/user-guide/api/image/embedders.rst | 18 +++ docs/user-guide/api/image/index.rst | 10 ++ docs/user-guide/api/index.rst | 1 + .../user-guide/{images => assets}/diagram.png | Bin 
.../sorted_sequence_dataloader.png | Bin .../{images => assets}/zeroshot_ablations.png | Bin .../distributeddataclassification.rst | 2 +- .../image/classifiers/aesthetic.rst | 97 +++++++++++ docs/user-guide/image/classifiers/index.rst | 8 + docs/user-guide/image/classifiers/nsfw.rst | 97 +++++++++++ docs/user-guide/image/datasets.rst | 121 ++++++++++++++ docs/user-guide/image/embedders.rst | 121 ++++++++++++++ docs/user-guide/image/gettingstarted.rst | 64 ++++++++ docs/user-guide/image/index.rst | 7 + docs/user-guide/index.rst | 69 ++++++-- .../datasets/image_text_pair_dataset.py | 82 +++++++++- nemo_curator/image/classifiers/aesthetic.py | 27 +++- nemo_curator/image/classifiers/base.py | 62 ++++++- nemo_curator/image/classifiers/nsfw.py | 28 +++- nemo_curator/image/embedders/base.py | 65 +++++++- nemo_curator/image/embedders/timm.py | 71 +++++++- setup.py | 2 +- 25 files changed, 1002 insertions(+), 131 deletions(-) create mode 100644 docs/user-guide/api/image/classifiers.rst create mode 100644 docs/user-guide/api/image/embedders.rst create mode 100644 docs/user-guide/api/image/index.rst rename docs/user-guide/{images => assets}/diagram.png (100%) rename docs/user-guide/{images => assets}/sorted_sequence_dataloader.png (100%) rename docs/user-guide/{images => assets}/zeroshot_ablations.png (100%) create mode 100644 docs/user-guide/image/classifiers/aesthetic.rst create mode 100644 docs/user-guide/image/classifiers/index.rst create mode 100644 docs/user-guide/image/classifiers/nsfw.rst create mode 100644 docs/user-guide/image/datasets.rst create mode 100644 docs/user-guide/image/embedders.rst create mode 100644 docs/user-guide/image/gettingstarted.rst create mode 100644 docs/user-guide/image/index.rst diff --git a/README.md b/README.md index f6ba195d..21127d34 100644 --- a/README.md +++ b/README.md @@ -9,51 +9,43 @@ # NeMo Curator -πŸš€ **The GPU-Accelerated Open Source Framework for Efficient Large Language Model Data Curation** πŸš€ +πŸš€ **The GPU-Accelerated Open Source Framework for Efficient Generative AI Model Data Curation** πŸš€ -

- diagram -

- -NeMo Curator is a Python library specifically designed for fast and scalable dataset preparation and curation for [large language model (LLM)](https://www.nvidia.com/en-us/glossary/large-language-models/) use-cases such as foundation model pretraining, domain-adaptive pretraining (DAPT), supervised fine-tuning (SFT) and paramter-efficient fine-tuning (PEFT). It greatly accelerates data curation by leveraging GPUs with [Dask](https://www.dask.org/) and [RAPIDS](https://developer.nvidia.com/rapids), resulting in significant time savings. The library provides a customizable and modular interface, simplifying pipeline expansion and accelerating model convergence through the preparation of high-quality tokens. - -At the core of the NeMo Curator is the `DocumentDataset` which serves as the the main dataset class. It acts as a straightforward wrapper around a Dask `DataFrame`. The Python library offers easy-to-use methods for expanding the functionality of your curation pipeline while eliminating scalability concerns. +NeMo Curator is a Python library specifically designed for fast and scalable dataset preparation and curation for generative AI use cases such as foundation language model pretraining, text-to-image model training, domain-adaptive pretraining (DAPT), supervised fine-tuning (SFT) and parameter-efficient fine-tuning (PEFT). It greatly accelerates data curation by leveraging GPUs with [Dask](https://www.dask.org/) and [RAPIDS](https://developer.nvidia.com/rapids), resulting in significant time savings. The library provides a customizable and modular interface, simplifying pipeline expansion and accelerating model convergence through the preparation of high-quality tokens. ## Key Features -NeMo Curator provides a collection of scalable data-mining modules. 
Some of the key features include: - -- [Data download and text extraction](docs/user-guide/download.rst) - - - Default implementations for downloading and extracting Common Crawl, Wikipedia, and ArXiv data - - Easily customize the download and extraction and extend to other datasets - -- [Language identification and separation](docs/user-guide/languageidentificationunicodeformatting.rst) with [fastText](https://fasttext.cc/docs/en/language-identification.html) and [pycld2](https://pypi.org/project/pycld2/) - -- [Text reformatting and cleaning](docs/user-guide/languageidentificationunicodeformatting.rst) to fix unicode decoding errors via [ftfy](https://ftfy.readthedocs.io/en/latest/) - -- [Quality filtering](docs/user-guide/qualityfiltering.rst) - - - Multilingual heuristic-based filtering - - Classifier-based filtering via [fastText](https://fasttext.cc/) - -- [Document-level deduplication](docs/user-guide/gpudeduplication.rst) - - - exact and fuzzy (near-identical) deduplication are accelerated using cuDF and Dask - - For fuzzy deduplication, our implementation follows the method described in [Microsoft Turing NLG 530B](https://arxiv.org/abs/2201.11990) - - For semantic deduplication, our implementation follows the method described in [SemDeDup](https://arxiv.org/pdf/2303.09540) by Meta AI (FAIR) [facebookresearch/SemDeDup](https://github.com/facebookresearch/SemDeDup) - -- [Multilingual downstream-task decontamination](docs/user-guide/taskdecontamination.rst) following the approach of [OpenAI GPT3](https://arxiv.org/pdf/2005.14165.pdf) and [Microsoft Turing NLG 530B](https://arxiv.org/abs/2201.11990) - -- [Distributed data classification](docs/user-guide/distributeddataclassification.rst) - - - Multi-node, multi-GPU classifier inference - - Provides sophisticated domain and quality classification - - Flexible interface for extending to your own classifier network - -- [Personal identifiable information (PII) redaction](docs/user-guide/personalidentifiableinformationidentificationandremoval.rst) for removing addresses, credit card numbers, social security numbers, and more - -These modules offer flexibility and permit reordering, with only a few exceptions. In addition, the [NeMo Framework Launcher](https://github.com/NVIDIA/NeMo-Megatron-Launcher) provides pre-built pipelines that can serve as a foundation for your customization use cases. +NeMo Curator provides a collection of scalable data curation modules for text and image curation. + +### Text Curation +All of our text pipelines have great multilingual support. 
+
+- [Download and Extraction](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/download.html)
+  - Default implementations for Common Crawl, Wikipedia, and ArXiv sources
+  - Easily customize and extend to other sources
+- [Language Identification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html)
+- [Unicode Reformatting](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html)
+- [Heuristic Filtering](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)
+- Classifier Filtering
+  - [fastText](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html)
+  - GPU-Accelerated models: [Domain, Quality, and Safety Classification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html)
+- **GPU-Accelerated Deduplication**
+  - [Exact Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html)
+  - [Fuzzy Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/gpudeduplication.html) via MinHash Locality Sensitive Hashing
+  - [Semantic Deduplication](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html)
+- [Downstream-task Decontamination](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/taskdecontamination.html)
+- [Personally Identifiable Information (PII) Redaction](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/personalidentifiableinformationidentificationandremoval.html)
+
+### Image Curation
+
+- [Embedding Creation](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/image/embedders.html)
+- Classifier Filtering
+  - [Aesthetic](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/image/classifiers/aesthetic.html) and [NSFW](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/image/classifiers/nsfw.html) Classification
+- GPU Deduplication
+  - [Semantic](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/semdedup.html)
+
+These modules offer flexibility and permit reordering, with only a few exceptions.
+All the modules automatically scale to multiple nodes to increase throughput.

## Resources

@@ -83,59 +75,52 @@ Before installing NeMo Curator, ensure that the following requirements are met:

   - Voltaβ„’ or higher ([compute capability 7.0+](https://developer.nvidia.com/cuda-gpus))
   - CUDA 12 (or above)

-You can install NeMo-Curator
-1. from PyPi
-2. from source
-3. get it through the [NeMo Framework container](https://github.com/NVIDIA/NeMo?tab=readme-ov-file#docker-containers).
-
-
+You can get NeMo Curator in three ways.
+1. PyPi
+2. Source
+3. NeMo Framework Container
 
-#### From PyPi
-
-To install the CPU-only modules:
+#### PyPi
 
 ```bash
 pip install cython
-pip install nemo-curator
+pip install --extra-index-url https://pypi.nvidia.com nemo-curator[all]
 ```
 
-To install the CPU and CUDA-accelerated modules:
-
+#### Source
 ```bash
+git clone https://github.com/NVIDIA/NeMo-Curator.git
 pip install cython
-pip install --extra-index-url https://pypi.nvidia.com nemo-curator[cuda12x]
+pip install ./NeMo-Curator[all]
 ```
 
-#### From Source
+#### NeMo Framework Container
 
-1. Clone the NeMo Curator repository in GitHub.
-
-   ```bash
-   git clone https://github.com/NVIDIA/NeMo-Curator.git
-   cd NeMo-Curator
-   ```
-
-2. 
Install the modules that you need. +The latest release of NeMo Curator comes preinstalled in the [NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo/tags). If you want the latest commit inside the container, you can reinstall NeMo Curator using: - To install the CPU-only modules: +```bash +pip uninstall nemo-curator +rm -r /opt/NeMo-Curator +git clone https://github.com/NVIDIA/NeMo-Curator.git /opt/NeMo-Curator +pip install --extra-index-url https://pypi.nvidia.com /opt/NeMo-Curator[all] +``` - ```bash - pip install cython - pip install . - ``` +#### Extras +NeMo Curator has a set of extras you can use to only install the necessary modules for your workload. +These extras are available for all installation methods provided. - To install the CPU and CUDA-accelerated modules: +```bash +pip install nemo-curator # Installs CPU-only text curation modules +pip install --extra-index-url https://pypi.nvidia.com nemo-curator[cuda12x] # Installs CPU + GPU text curation modules +pip install --extra-index-url https://pypi.nvidia.com nemo-curator[image] # Installs CPU + GPU text and image curation modules +pip install --extra-index-url https://pypi.nvidia.com nemo-curator[all] # Installs all of the above +``` - ```bash - pip install cython - pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]" - ``` #### Using Nightly Dependencies for RAPIDS You can also install NeMo Curator using the [RAPIDS Nightly Builds](https://docs.rapids.ai/install). To do so, you can set the environment variable `RAPIDS_NIGHTLY=1`. - ```bash # installing from pypi RAPIDS_NIGHTLY=1 pip install --extra-index-url=https://pypi.anaconda.org/rapidsai-wheels-nightly/simple "nemo-curator[cuda12x]" @@ -146,18 +131,6 @@ RAPIDS_NIGHTLY=1 pip install --extra-index-url=https://pypi.anaconda.org/rapidsa When the `RAPIDS_NIGHTLY` variable is set to 0 (which is the default), it will use the stable version of RAPIDS. -#### From the NeMo Framework Container - -The latest release of NeMo Curator comes preinstalled in the [NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo/tags). If you want the latest commit inside the container, you can reinstall NeMo Curator using: - -```bash -pip uninstall nemo-curator -rm -r /opt/NeMo-Curator -git clone https://github.com/NVIDIA/NeMo-Curator.git /opt/NeMo-Curator -pip install --extra-index-url https://pypi.nvidia.com /opt/NeMo-Curator[cuda12x] -``` -And follow the instructions for installing from source from [above](#from-source). - ## Use NeMo Curator ### Python API Quick Example @@ -189,6 +162,7 @@ To get started with NeMo Curator, you can follow the tutorials [available here]( - [`peft-curation`](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/peft-curation) which focuses on data curation for LLM parameter-efficient fine-tuning (PEFT) use-cases. - [`distributed_data_classification`](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/distributed_data_classification) which focuses on using the quality and domain classifiers to help with data annotation. - [`single_node_tutorial`](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/single_node_tutorial) which demonstrates an end-to-end data curation pipeline for curating Wikipedia data in Thai. +- [`image-curation`](https://github.com/NVIDIA/NeMo-Curator/blob/main/tutorials/image-curation/image-curation.ipynb) which explores the scalable image curation modules. 
### Access Python Modules @@ -201,9 +175,9 @@ NeMo Curator also offers CLI scripts for you to use. The scripts in `nemo_curato ### Use NeMo Framework Launcher -As an alternative method for interfacing with NeMo Curator, you can use the [NeMo Framework Launcher](https://github.com/NVIDIA/NeMo-Megatron-Launcher). The launcher enables you to easily configure the parameters and cluster. It can also automatically generate the SLURM batch scripts that wrap around the CLI scripts required to run your pipeline. +As an alternative method for interfacing with NeMo Curator, you can use the [NeMo Framework Launcher](https://github.com/NVIDIA/NeMo-Megatron-Launcher). The launcher enables you to easily configure the parameters and cluster. It can also automatically generate the Slurm batch scripts that wrap around the CLI scripts required to run your pipeline. -In addition, other methods are available to run NeMo Curator on SLURM. For example, refer to the example scripts in [`examples/slurm`](examples/slurm/) for information on how to run NeMo Curator on SLURM without the NeMo Framework Launcher. +In addition, other methods are available to run NeMo Curator on Slurm. For example, refer to the example scripts in [`examples/slurm`](examples/slurm/) for information on how to run NeMo Curator on Slurm without the NeMo Framework Launcher. ## Module Ablation and Compute Performance @@ -212,7 +186,7 @@ The modules within NeMo Curator were primarily designed to curate high-quality d The following figure shows that the use of different data curation modules implemented in NeMo Curator led to improved model zero-shot downstream task performance.

- drawing + drawing

In terms of scalability and compute performance, using the combination of RAPIDS and Dask fuzzy deduplication enabled us to deduplicate the 1.1 Trillion token Red Pajama dataset in 1.8 hours with 64 NVIDIA A100 Tensor Core GPUs. diff --git a/docs/user-guide/api/datasets.rst b/docs/user-guide/api/datasets.rst index 43e532b1..c8dba791 100644 --- a/docs/user-guide/api/datasets.rst +++ b/docs/user-guide/api/datasets.rst @@ -7,4 +7,12 @@ DocumentDataset ------------------- .. autoclass:: nemo_curator.datasets.DocumentDataset + :members: + + +------------------------------- +ImageTextPairDataset +------------------------------- + +.. autoclass:: nemo_curator.datasets.ImageTextPairDataset :members: \ No newline at end of file diff --git a/docs/user-guide/api/image/classifiers.rst b/docs/user-guide/api/image/classifiers.rst new file mode 100644 index 00000000..a43560e5 --- /dev/null +++ b/docs/user-guide/api/image/classifiers.rst @@ -0,0 +1,21 @@ +====================================== +Classifiers +====================================== + +------------------------------ +Base Class +------------------------------ + +.. autoclass:: nemo_curator.image.classifiers.ImageClassifier + :members: + + +------------------------------ +Image Classifiers +------------------------------ + +.. autoclass:: nemo_curator.image.classifiers.AestheticClassifier + :members: + +.. autoclass:: nemo_curator.image.classifiers.NsfwClassifier + :members: \ No newline at end of file diff --git a/docs/user-guide/api/image/embedders.rst b/docs/user-guide/api/image/embedders.rst new file mode 100644 index 00000000..aa1de81e --- /dev/null +++ b/docs/user-guide/api/image/embedders.rst @@ -0,0 +1,18 @@ +====================================== +Embedders +====================================== + +------------------------------ +Base Class +------------------------------ + +.. autoclass:: nemo_curator.image.embedders.ImageEmbedder + :members: + + +------------------------------ +Timm +------------------------------ + +.. autoclass:: nemo_curator.image.embedders.TimmImageEmbedder + :members: \ No newline at end of file diff --git a/docs/user-guide/api/image/index.rst b/docs/user-guide/api/image/index.rst new file mode 100644 index 00000000..c58862f4 --- /dev/null +++ b/docs/user-guide/api/image/index.rst @@ -0,0 +1,10 @@ +====================================== +Image Curation +====================================== + +.. 
toctree::
+   :maxdepth: 4
+   :titlesonly:
+
+   embedders.rst
+   classifiers.rst
\ No newline at end of file
diff --git a/docs/user-guide/api/index.rst b/docs/user-guide/api/index.rst
index 866f06b9..b76dd75b 100644
--- a/docs/user-guide/api/index.rst
+++ b/docs/user-guide/api/index.rst
@@ -18,4 +18,5 @@ API Reference
    decontamination.rst
    services.rst
    synthetic.rst
+   image/index.rst
    misc.rst
\ No newline at end of file
diff --git a/docs/user-guide/images/diagram.png b/docs/user-guide/assets/diagram.png
similarity index 100%
rename from docs/user-guide/images/diagram.png
rename to docs/user-guide/assets/diagram.png
diff --git a/docs/user-guide/images/sorted_sequence_dataloader.png b/docs/user-guide/assets/sorted_sequence_dataloader.png
similarity index 100%
rename from docs/user-guide/images/sorted_sequence_dataloader.png
rename to docs/user-guide/assets/sorted_sequence_dataloader.png
diff --git a/docs/user-guide/images/zeroshot_ablations.png b/docs/user-guide/assets/zeroshot_ablations.png
similarity index 100%
rename from docs/user-guide/images/zeroshot_ablations.png
rename to docs/user-guide/assets/zeroshot_ablations.png
diff --git a/docs/user-guide/distributeddataclassification.rst b/docs/user-guide/distributeddataclassification.rst
index 43d673b8..b411896b 100644
--- a/docs/user-guide/distributeddataclassification.rst
+++ b/docs/user-guide/distributeddataclassification.rst
@@ -201,7 +201,7 @@ The key feature of CrossFit used in NeMo Curator is the sorted sequence data loa
 - Groups sorted sequences into optimized batches.
 - Efficiently allocates batches to the provided GPU memories by estimating the memory footprint for each sequence length and batch size.
 
-.. image:: images/sorted_sequence_dataloader.png
+.. image:: assets/sorted_sequence_dataloader.png
    :alt: Sorted Sequence Data Loader
 
 Check out the `rapidsai/crossfit`_ repository for more information.
diff --git a/docs/user-guide/image/classifiers/aesthetic.rst b/docs/user-guide/image/classifiers/aesthetic.rst
new file mode 100644
index 00000000..3a43cebe
--- /dev/null
+++ b/docs/user-guide/image/classifiers/aesthetic.rst
@@ -0,0 +1,97 @@
+=========================
+Aesthetic Classifier
+=========================
+
+--------------------
+Overview
+--------------------
+Aesthetic classifiers can be used to assess the subjective quality of an image.
+NeMo Curator integrates the `improved aesthetic predictor `_, which outputs a score from 0-10, where 10 indicates an aesthetically pleasing image.
+
+--------------------
+Use Cases
+--------------------
+Filtering by aesthetic quality is common in generative image pipelines.
+For example, `Stable Diffusion `_ progressively filtered its training data by aesthetic score during training.
+
+
+--------------------
+Prerequisites
+--------------------
+Make sure you check out the `image curation getting started page `_ to install everything you will need.
+
+--------------------
+Usage
+--------------------
+
+The aesthetic classifier is a linear classifier that takes OpenAI CLIP ViT-L/14 image embeddings as input.
+This model is available through the ``vit_large_patch14_clip_quickgelu_224.openai`` identifier in ``TimmImageEmbedder``.
+First, we can compute these embeddings, then we can perform the classification.
+
+.. code-block:: python
+
+    from nemo_curator import get_client
+    from nemo_curator.datasets import ImageTextPairDataset
+    from nemo_curator.image.embedders import TimmImageEmbedder
+    from nemo_curator.image.classifiers import AestheticClassifier
+
+    client = get_client(cluster_type="gpu")
+
+    dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")
+
+    embedding_model = TimmImageEmbedder(
+        "vit_large_patch14_clip_quickgelu_224.openai",
+        pretrained=True,
+        batch_size=1024,
+        num_threads_per_worker=16,
+        normalize_embeddings=True,
+    )
+    aesthetic_classifier = AestheticClassifier()
+
+    dataset_with_embeddings = embedding_model(dataset)
+    dataset_with_aesthetic_scores = aesthetic_classifier(dataset_with_embeddings)
+
+    # Metadata will have a new column named "aesthetic_score"
+    dataset_with_aesthetic_scores.save_metadata()
+
+--------------------
+Key Parameters
+--------------------
+* ``batch_size=-1`` is the optional batch size parameter. By default, it will process all the embeddings in a shard at once. Since the aesthetic classifier is a linear model, this is usually fine.
+
+---------------------------
+Performance Considerations
+---------------------------
+Since the aesthetic model is so small, you can load it onto the GPU at the same time as the embedding model and perform inference directly after computing the embeddings.
+Check out this example:
+
+.. code-block:: python
+
+    from nemo_curator import get_client
+    from nemo_curator.datasets import ImageTextPairDataset
+    from nemo_curator.image.embedders import TimmImageEmbedder
+    from nemo_curator.image.classifiers import AestheticClassifier
+
+    client = get_client(cluster_type="gpu")
+
+    dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")
+
+    embedding_model = TimmImageEmbedder(
+        "vit_large_patch14_clip_quickgelu_224.openai",
+        pretrained=True,
+        batch_size=1024,
+        num_threads_per_worker=16,
+        normalize_embeddings=True,
+        classifiers=[AestheticClassifier()],
+    )
+
+    dataset_with_aesthetic_scores = embedding_model(dataset)
+
+    # Metadata will have a new column named "aesthetic_score"
+    dataset_with_aesthetic_scores.save_metadata()
+
+---------------------------
+Additional Resources
+---------------------------
+* `Image Curation Tutorial `_
+* `API Reference `_
\ No newline at end of file
diff --git a/docs/user-guide/image/classifiers/index.rst b/docs/user-guide/image/classifiers/index.rst
new file mode 100644
index 00000000..bbd67f0e
--- /dev/null
+++ b/docs/user-guide/image/classifiers/index.rst
@@ -0,0 +1,8 @@
+.. _data-curator-image-classifiers:
+
+.. toctree::
+   :maxdepth: 4
+   :titlesonly:
+
+   aesthetic.rst
+   nsfw.rst
\ No newline at end of file
diff --git a/docs/user-guide/image/classifiers/nsfw.rst b/docs/user-guide/image/classifiers/nsfw.rst
new file mode 100644
index 00000000..d7d3533d
--- /dev/null
+++ b/docs/user-guide/image/classifiers/nsfw.rst
@@ -0,0 +1,97 @@
+=========================
+NSFW Classifier
+=========================
+
+--------------------
+Overview
+--------------------
+Not-safe-for-work (NSFW) classifiers determine the likelihood of an image containing sexually explicit material.
+NeMo Curator integrates the `CLIP-based-NSFW-Detector `_, which outputs a value between 0 and 1, where 1 means the content is NSFW.
+
+--------------------
+Use Cases
+--------------------
+Removing unsafe content is common in most data processing pipelines to prevent your generative AI model from learning to produce unsafe material.
+
+For example, `Data Comp `_ filters out NSFW content before conducting its experiments.
+
+--------------------
+Prerequisites
+--------------------
+Make sure you check out the `image curation getting started page `_ to install everything you will need.
+
+--------------------
+Usage
+--------------------
+
+The NSFW classifier is a small MLP classifier that takes OpenAI CLIP ViT-L/14 image embeddings as input.
+This model is available through the ``vit_large_patch14_clip_quickgelu_224.openai`` identifier in ``TimmImageEmbedder``.
+First, we can compute these embeddings, then we can perform the classification.
+
+.. code-block:: python
+
+    from nemo_curator import get_client
+    from nemo_curator.datasets import ImageTextPairDataset
+    from nemo_curator.image.embedders import TimmImageEmbedder
+    from nemo_curator.image.classifiers import NsfwClassifier
+
+    client = get_client(cluster_type="gpu")
+
+    dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")
+
+    embedding_model = TimmImageEmbedder(
+        "vit_large_patch14_clip_quickgelu_224.openai",
+        pretrained=True,
+        batch_size=1024,
+        num_threads_per_worker=16,
+        normalize_embeddings=True,
+    )
+    safety_classifier = NsfwClassifier()
+
+    dataset_with_embeddings = embedding_model(dataset)
+    dataset_with_nsfw_scores = safety_classifier(dataset_with_embeddings)
+
+    # Metadata will have a new column named "nsfw_score"
+    dataset_with_nsfw_scores.save_metadata()
+
+--------------------
+Key Parameters
+--------------------
+* ``batch_size=-1`` is the optional batch size parameter. By default, it will process all the embeddings in a shard at once. Since the NSFW classifier is a small model, this is usually fine.
+
+---------------------------
+Performance Considerations
+---------------------------
+Since the NSFW model is so small, you can load it onto the GPU at the same time as the embedding model and perform inference directly after computing the embeddings.
+Check out this example:
+
+.. code-block:: python
+
+    from nemo_curator import get_client
+    from nemo_curator.datasets import ImageTextPairDataset
+    from nemo_curator.image.embedders import TimmImageEmbedder
+    from nemo_curator.image.classifiers import NsfwClassifier
+
+    client = get_client(cluster_type="gpu")
+
+    dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")
+
+    embedding_model = TimmImageEmbedder(
+        "vit_large_patch14_clip_quickgelu_224.openai",
+        pretrained=True,
+        batch_size=1024,
+        num_threads_per_worker=16,
+        normalize_embeddings=True,
+        classifiers=[NsfwClassifier()],
+    )
+
+    dataset_with_nsfw_scores = embedding_model(dataset)
+
+    # Metadata will have a new column named "nsfw_score"
+    dataset_with_nsfw_scores.save_metadata()
+
+
+---------------------------
+Additional Resources
+---------------------------
+* `Image Curation Tutorial `_
+* `API Reference `_
\ No newline at end of file
diff --git a/docs/user-guide/image/datasets.rst b/docs/user-guide/image/datasets.rst
new file mode 100644
index 00000000..11000937
--- /dev/null
+++ b/docs/user-guide/image/datasets.rst
@@ -0,0 +1,121 @@
+.. _data-curator-image-datasets:
+
+=========================
+Image-Text Pair Datasets
+=========================
+
+Image-text pair datasets are commonly used for training generative text-to-image models or CLIP models.
+NeMo Curator supports reading and writing datasets based on the `WebDataset `_ file format.
+This format allows NeMo Curator to annotate the dataset with metadata including embeddings and classifier scores.
+Its sharded format also makes it easier to distribute work to different workers processing the dataset.
+
+------------
+File Format
+------------
+
+Here is an example of what a dataset directory in the WebDataset format should look like.
+
+::
+
+    dataset/
+    β”œβ”€β”€ 00000.tar
+    β”‚   β”œβ”€β”€ 000000000.jpg
+    β”‚   β”œβ”€β”€ 000000000.json
+    β”‚   β”œβ”€β”€ 000000000.txt
+    β”‚   β”œβ”€β”€ 000000001.jpg
+    β”‚   β”œβ”€β”€ 000000001.json
+    β”‚   β”œβ”€β”€ 000000001.txt
+    β”‚   └── ...
+    β”œβ”€β”€ 00001.tar
+    β”‚   β”œβ”€β”€ 000010000.jpg
+    β”‚   β”œβ”€β”€ 000010000.json
+    β”‚   β”œβ”€β”€ 000010000.txt
+    β”‚   β”œβ”€β”€ 000010001.jpg
+    β”‚   β”œβ”€β”€ 000010001.json
+    β”‚   β”œβ”€β”€ 000010001.txt
+    β”‚   └── ...
+    β”œβ”€β”€ 00002.tar
+    β”‚   └── ...
+    β”œβ”€β”€ 00000.parquet
+    β”œβ”€β”€ 00001.parquet
+    └── 00002.parquet
+
+
+The exact format assumes a single directory with sharded ``.tar``, ``.parquet``, and (optionally)
+``.idx`` files. Each tar file should have a unique integer ID as its name (``00000.tar``,
+``00001.tar``, ``00002.tar``, etc.). The tar files should contain images in ``.jpg`` files, text captions
+in ``.txt`` files, and metadata in ``.json`` files. Each record of the dataset is identified by
+a unique ID that is a mix of the shard ID along with the offset of the record within a shard.
+For example, the 32nd record of the 43rd shard would be in ``00042.tar`` and have image ``000420031.jpg``,
+caption ``000420031.txt``, and metadata ``000420031.json`` (assuming zero indexing).
+
+In addition to the collection of tar files, NeMo Curator's ``ImageTextPairDataset`` expects there to be .parquet files
+in the root directory that follow the same naming convention as the shards (``00042.tar`` -> ``00042.parquet``).
+Each Parquet file should contain an aggregated tabular form of the metadata for each record, with
+each row in the Parquet file corresponding to a record in that shard. The metadata, both in the Parquet
+files and the JSON files, must contain a unique ID column that is the same as its record ID (000420031
+in our examples).
+
+-------
+Reading
+-------
+
+Datasets can be read in using ``ImageTextPairDataset.from_webdataset()``.
+
+.. code-block:: python
+
+    from nemo_curator.datasets import ImageTextPairDataset
+
+    dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")
+
+* ``path="/path/to/dataset"`` should point to the root directory of the WebDataset.
+* ``id_col="key"`` lets us know that the unique ID column in the dataset is named "key".
+
+A more thorough list of parameters can be found in the `API Reference `_.
+
+-------
+Writing
+-------
+
+There are two ways to write an image dataset. The first way only saves the metadata, while the second way will reshard the tar files.
+Both trigger the computation of all the tasks you have set to run beforehand.
+
+.. code-block:: python
+
+    from nemo_curator.datasets import ImageTextPairDataset
+
+    dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")
+
+    # Perform your operations (embedding creation, classifiers, etc.)
+
+    dataset.save_metadata()
+
+``save_metadata()`` will only save sharded Parquet files to the target directory. It does not modify the tar files.
+There are two optional parameters:
+
+* ``path`` allows you to change the location of where the dataset is saved. By default, it will overwrite the original Parquet files.
+* ``columns`` allows you to only save a subset of metadata. By default, all metadata will be saved.
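+
+For example, a hypothetical run that keeps only the ID column and a score column added earlier in the pipeline could look like this (both column names and the path are illustrative):
+
+.. code-block:: python
+
+    # Assumes "key" is the dataset's ID column and that a classifier
+    # added an "aesthetic_score" column earlier in the pipeline.
+    dataset.save_metadata(
+        path="/path/to/metadata_output",
+        columns=["key", "aesthetic_score"],
+    )
+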
+
+.. code-block:: python
+
+    from nemo_curator.datasets import ImageTextPairDataset
+
+    dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")
+
+    # Perform your operations (embedding creation, classifiers, etc.)
+
+    dataset.to_webdataset(path="/path/to/output", filter_column="passes_curation")
+
+``to_webdataset()`` will reshard the WebDataset to only include elements that have a value of ``True`` in the ``filter_column``.
+Resharding can take a while, so this should typically only be done at the end of your curation pipeline when you are ready to export the dataset for training.
+
+
+A more thorough list of parameters can be found in the `API Reference `_.
+
+-------------
+Index Files
+-------------
+
+NeMo Curator uses `DALI `_ for image data loading from the tar files.
+In order to speed up the data loading, you can supply ``.idx`` files in your dataset.
+The index files must be generated by DALI's wds2idx tool.
+See the `DALI documentation `_ for more information.
+Each index file must follow the same naming convention as the tar files (00042.tar -> 00042.idx).
\ No newline at end of file
diff --git a/docs/user-guide/image/embedders.rst b/docs/user-guide/image/embedders.rst
new file mode 100644
index 00000000..83033f9c
--- /dev/null
+++ b/docs/user-guide/image/embedders.rst
@@ -0,0 +1,121 @@
+.. _data-curator-image-embedding:
+
+=========================
+Image Embedders
+=========================
+
+--------------------
+Overview
+--------------------
+Many image curation features in NeMo Curator operate on image embeddings instead of images directly.
+Image embedders provide a scalable way of generating embeddings for each image in the dataset.
+
+--------------------
+Use Cases
+--------------------
+* Aesthetic and NSFW classification both use image embeddings generated from OpenAI's CLIP ViT-L variant.
+* Semantic deduplication computes the similarity of datapoints.
+
+--------------------
+Prerequisites
+--------------------
+Make sure you check out the `image curation getting started page `_ to install everything you will need.
+
+--------------------
+Timm Image Embedder
+--------------------
+
+`PyTorch Image Models (timm) `_ is a library containing SOTA computer vision models.
+Many of these models are useful in generating image embeddings for modules in NeMo Curator.
+
+.. code-block:: python
+
+    from nemo_curator import get_client
+    from nemo_curator.datasets import ImageTextPairDataset
+    from nemo_curator.image.embedders import TimmImageEmbedder
+
+    client = get_client(cluster_type="gpu")
+
+    dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")
+
+    embedding_model = TimmImageEmbedder(
+        "vit_large_patch14_clip_quickgelu_224.openai",
+        pretrained=True,
+        batch_size=1024,
+        num_threads_per_worker=16,
+        normalize_embeddings=True,
+    )
+
+    dataset_with_embeddings = embedding_model(dataset)
+
+    # Metadata will have a new column named "image_embedding"
+    dataset_with_embeddings.save_metadata()
+
+Here, we load a dataset in and compute the image embeddings using ``vit_large_patch14_clip_quickgelu_224.openai``.
+At the end of the process, our metadata files have a new column named "image_embedding" that contains the image embeddings for each datapoint.
+
+--------------------
+Key Parameters
+--------------------
+* ``pretrained=True`` ensures you download the pretrained weights of the model.
+* ``batch_size=1024`` determines the number of images processed on each individual GPU at once.
+* ``num_threads_per_worker=16`` determines the number of threads used by DALI for dataloading.
+* ``normalize_embeddings=True`` will normalize each embedding. NeMo Curator's classifiers expect normalized embeddings as input.
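+
+If you want to sanity-check the output of the embedding step, one option (not specific to NeMo Curator) is to read a metadata shard back with pandas and confirm the new column is present. The shard name below is illustrative:
+
+.. code-block:: python
+
+    import pandas as pd
+
+    # Any .parquet shard from the dataset directory works here.
+    shard = pd.read_parquet("/path/to/dataset/00000.parquet")
+
+    assert "image_embedding" in shard.columns
+    # Each entry holds one embedding vector for the corresponding record.
+    print(len(shard["image_embedding"].iloc[0]))
+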
+
+---------------------------
+Performance Considerations
+---------------------------
+
+Under the hood, the image embedding model performs the following operations:
+
+1. Download the weights of the model.
+2. Download the PyTorch image transformations (resize and center-crop for example).
+3. Convert the PyTorch image transformations to DALI transformations.
+4. Load a shard of metadata (a ``.parquet`` file) onto each GPU you have available using Dask-cuDF.
+5. Load a copy of the model onto each GPU.
+6. Repeatedly load images into batches of size ``batch_size`` onto each GPU with a given threads per worker (``num_threads_per_worker``) using DALI.
+7. The model is run on the batch (inside ``torch.autocast()``, since ``autocast=True`` by default).
+8. The output embeddings of the model are normalized since ``normalize_embeddings=True``.
+
+There are a couple of key performance considerations from this flow.
+
+* You must have an NVIDIA GPU that meets the `requirements `_.
+* You can create ``.idx`` files in the same directory of the tar files to speed up dataloading times. See the `DALI documentation `_ for more information.
+
+------------------------
+Custom Image Embedder
+------------------------
+
+To write your own custom embedder, you inherit from ``nemo_curator.image.embedders.ImageEmbedder`` and override two methods as shown below:
+
+.. code-block:: python
+
+    from typing import Callable, Iterable
+
+    from nemo_curator.image.embedders import ImageEmbedder
+
+    class MyCustomEmbedder(ImageEmbedder):
+
+        def load_dataset_shard(self, tar_path: str) -> Iterable:
+            # Implement me!
+            pass
+
+        def load_embedding_model(self, device: str) -> Callable:
+            # Implement me!
+            pass
+
+
+* ``load_dataset_shard()`` will take in a path to a tar file and return an iterable over the shard. The iterable should return a tuple of ``(a batch of data, metadata)``.
+  The batch of data can be of any form. It will be directly passed to the model returned by ``load_embedding_model()``.
+  The metadata should be a dictionary of metadata, with a field corresponding to the ``id_col`` of the dataset.
+  In our example, the metadata should include a value for ``"key"``.
+* ``load_embedding_model()`` will take a device and return a callable object.
+  This callable will take as input a batch of data produced by ``load_dataset_shard()``. A fleshed-out sketch follows below.
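+
+The following is a minimal, hypothetical implementation, shown only to illustrate the two contracts above.
+It reads records with Python's ``tarfile`` module and returns an untrained stand-in model; a real embedder
+would use an optimized loader (such as DALI) and a pretrained network. The batch size of 32 and the
+224x224 resize are arbitrary choices for this sketch.
+
+.. code-block:: python
+
+    import io
+    import json
+    import tarfile
+    from typing import Callable, Iterable
+
+    import numpy as np
+    import torch
+    from PIL import Image
+
+    from nemo_curator.image.embedders import ImageEmbedder
+
+    class MyCustomEmbedder(ImageEmbedder):
+        def load_dataset_shard(self, tar_path: str) -> Iterable:
+            # Walk the shard and yield (image_batch, metadata) tuples.
+            # Assumes the .jpg/.json/.txt record layout described in the
+            # datasets guide, where .jpg precedes .json within each record.
+            images, metadata = [], []
+            with tarfile.open(tar_path) as tar:
+                for member in tar.getmembers():
+                    if not member.isfile():
+                        continue
+                    if member.name.endswith(".jpg"):
+                        raw = tar.extractfile(member).read()
+                        image = Image.open(io.BytesIO(raw)).convert("RGB")
+                        array = np.asarray(image.resize((224, 224)), dtype=np.float32) / 255.0
+                        images.append(torch.from_numpy(array).permute(2, 0, 1))
+                    elif member.name.endswith(".json"):
+                        # Each record's metadata must include the dataset's
+                        # id_col ("key" in our example).
+                        metadata.append(json.loads(tar.extractfile(member).read()))
+                    if len(images) == 32 and len(metadata) == 32:
+                        yield torch.stack(images), metadata
+                        images, metadata = [], []
+            if images:
+                yield torch.stack(images), metadata
+
+        def load_embedding_model(self, device: str) -> Callable:
+            # Untrained stand-in; substitute a real pretrained encoder here.
+            model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.LazyLinear(512))
+            model = model.to(device).eval()
+            return model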
+
+---------------------------
+Additional Resources
+---------------------------
+
+* `Aesthetic Classifier `_
+* `NSFW Classifier `_
+* `Semantic Deduplication `_
+* `Image Curation Tutorial `_
+* `API Reference `_
\ No newline at end of file
diff --git a/docs/user-guide/image/gettingstarted.rst b/docs/user-guide/image/gettingstarted.rst
new file mode 100644
index 00000000..dae4240d
--- /dev/null
+++ b/docs/user-guide/image/gettingstarted.rst
@@ -0,0 +1,64 @@
+
+.. _data-curator-image-getting-started:
+
+================
+Get Started
+================
+
+NeMo Curator provides many tools for curating large-scale text-image pair datasets for training generative image models.
+
+---------------------
+Install NeMo Curator
+---------------------
+To install the image curation modules of NeMo Curator, ensure you meet the following requirements:
+
+* Python 3.10
+* Ubuntu 22.04/20.04
+* NVIDIA GPU
+  * Voltaβ„’ or higher (compute capability 7.0+)
+  * CUDA 12 (or above)
+
+Note: While some of the text-based NeMo Curator modules do not require a GPU, all image curation modules require a GPU.
+
+You can get NeMo Curator in three ways.
+
+1. PyPi
+2. Source
+3. NeMo Framework Container
+
+#####################
+PyPi
+#####################
+NeMo Curator's PyPi page can be found `here `_.
+
+.. code-block:: bash
+
+    pip install cython
+    pip install --extra-index-url https://pypi.nvidia.com nemo-curator[image]
+
+#####################
+Source
+#####################
+NeMo Curator's GitHub can be found `here `_.
+
+.. code-block:: bash
+
+    git clone https://github.com/NVIDIA/NeMo-Curator.git
+    pip install cython
+    pip install --extra-index-url https://pypi.nvidia.com ./NeMo-Curator[image]
+
+############################
+NeMo Framework Container
+############################
+NeMo Curator comes preinstalled in the NeMo Framework container. You can find a list of all the NeMo Framework container tags `here `_.
+
+---------------------
+Use NeMo Curator
+---------------------
+
+NeMo Curator can be run locally or on a variety of compute platforms (Slurm, k8s, and more).
+
+To get started using the image modules in NeMo Curator, we recommend you check out the following resources:
+
+* `Image Curation Tutorial `_
+* `API Reference `_
\ No newline at end of file
diff --git a/docs/user-guide/image/index.rst b/docs/user-guide/image/index.rst
new file mode 100644
index 00000000..c6a53a86
--- /dev/null
+++ b/docs/user-guide/image/index.rst
@@ -0,0 +1,7 @@
+.. toctree::
+   :maxdepth: 4
+   :titlesonly:
+
+   datasets.rst
+   embedders.rst
+   classifiers/index.rst
\ No newline at end of file
diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst
index 8ea2ea6b..1db64716 100644
--- a/docs/user-guide/index.rst
+++ b/docs/user-guide/index.rst
@@ -1,5 +1,9 @@
 .. include:: datacuration.rsts
 
+-------------------
+Text Curation
+-------------------
+
 :ref:`Downloading and Extracting Text `
    Downloading a massive public dataset is usually the first step in data curation, and it can be cumbersome due to the dataset’s massive size and hosting method. This section describes how to download and extract large corpora efficiently.
 
@@ -19,7 +23,7 @@
    Both exact and fuzzy deduplication functionalities are supported in NeMo Curator and accelerated using RAPIDS cuDF.
 
 :ref:`GPU Accelerated Semantic Deduplication `
-   NeMo-Curator provides scalable and GPU accelerated semantic deduplication functionality using RAPIDS cuML, cuDF, crossfit and Pytorch.
+   NeMo Curator provides scalable and GPU accelerated semantic deduplication functionality using RAPIDS cuML, cuDF, crossfit and PyTorch.
 
 :ref:`Distributed Data Classification `
    NeMo-Curator provides a scalable and GPU accelerated module to help users run inference with pre-trained models on large volumes of text documents.
@@ -33,6 +37,56 @@
 
 :ref:`Personally Identifiable Information Identification and Removal `
    The purpose of the personally identifiable information (PII) redaction tool is to help scrub sensitive data out of training datasets
+.. toctree::
+   :maxdepth: 4
+   :titlesonly:
+
+
+   download.rst
+   documentdataset.rst
+   cpuvsgpu.rst
+   qualityfiltering.rst
+   languageidentificationunicodeformatting.rst
+   gpudeduplication.rst
+   semdedup.rst
+   syntheticdata.rst
+   taskdecontamination.rst
+   personalidentifiableinformationidentificationandremoval.rst
+   distributeddataclassification.rst
+
+-------------------
+Image Curation
+-------------------
+
+:ref:`Get Started `
+   Install NeMo Curator's image curation modules.
+
+:ref:`Image-Text Pair Datasets `
+   Image-text pair datasets are commonly used as the basis for training multimodal generative models. NeMo Curator interfaces with the standardized WebDataset format for curating such datasets.
+
+:ref:`Image Embedding Creation `
+   Image embeddings are the backbone to many data curation operations in NeMo Curator. This section describes how to efficiently create embeddings for massive datasets.
+
+:ref:`Classifiers `
+   NeMo Curator provides several ways to use common classifiers like aesthetic scoring and not-safe-for-work (NSFW) scoring.
+
+:ref:`Semantic Deduplication `
+   Semantic deduplication with image datasets has been shown to drastically improve model performance. NeMo Curator has a semantic deduplication module that can work with any modality.
+
+.. toctree::
+   :maxdepth: 4
+   :titlesonly:
+
+   image/gettingstarted.rst
+   image/datasets.rst
+   image/embedders.rst
+   image/classifiers/index.rst
+   semdedup.rst
+
+
+-------------------
+Reference
+-------------------
+
 :ref:`NeMo Curator on Kubernetes `
    Demonstration of how to run the NeMo Curator on a Dask Cluster deployed on top of Kubernetes
 
@@ -56,19 +110,8 @@
    :titlesonly:
 
 
-   download.rst
-   documentdataset.rst
-   cpuvsgpu.rst
-   qualityfiltering.rst
-   languageidentificationunicodeformatting.rst
-   gpudeduplication.rst
-   semdedup.rst
-   syntheticdata.rst
-   taskdecontamination.rst
-   personalidentifiableinformationidentificationandremoval.rst
-   distributeddataclassification.rst
    kubernetescurator.rst
    sparkother.rst
    bestpractices.rst
    nextsteps.rst
-   api/index.rst
+   api/index.rst
\ No newline at end of file
diff --git a/nemo_curator/datasets/image_text_pair_dataset.py b/nemo_curator/datasets/image_text_pair_dataset.py
index abdf244b..b580015c 100644
--- a/nemo_curator/datasets/image_text_pair_dataset.py
+++ b/nemo_curator/datasets/image_text_pair_dataset.py
@@ -28,9 +28,43 @@
 
 
 class ImageTextPairDataset:
+    """
+    A collection of image text pairs stored in WebDataset-like format on disk or in cloud storage.
+
+    The exact format assumes a single directory with sharded .tar, .parquet, and (optionally)
+    .idx files. Each tar file should have a unique integer ID as its name (00000.tar,
+    00001.tar, 00002.tar, etc.). The tar files should contain images in .jpg files, text captions
+    in .txt files, and metadata in .json files. Each record of the dataset is identified by
+    a unique ID that is a mix of the shard ID along with the offset of the record within a shard.
+    For example, the 32nd record of the 43rd shard would be in 00042.tar and have image 000420031.jpg,
+    caption 000420031.txt, and metadata 000420031.json (assuming zero indexing).
+
+    In addition to the collection of tar files, ImageTextPairDataset expects there to be .parquet files
+    in the root directory that follow the same naming convention as the shards (00042.tar -> 00042.parquet).
+    Each Parquet file should contain an aggregated tabular form of the metadata for each record, with
+    each row in the Parquet file corresponding to a record in that shard. 
The metadata, both in the Parquet + files and the JSON files, must contain a unique ID column that is the same as its record ID (000420031 + in our examples). + + Index files may also be in the directory to speed up dataloading with DALI. + The index files must be generated by DALI's wds2idx tool. + See https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/dataloading_webdataset.html#Creating-an-index + for more information. Each index file must follow the same naming convention as the tar files + (00042.tar -> 00042.idx). + """ + def __init__( self, path: str, metadata: dd.DataFrame, tar_files: List[str], id_col: str ) -> None: + """ + Constructs an image-text pair dataset. + + Args: + path (str): The root directory of the files. + metadata (dd.DataFrame): A Dask-cuDF DataFrame of the metadata. + tar_files (List[str]): A list of paths to the tar files. + id_col (str): The column storing the unique identifier for each record. + """ self.path = path self.metadata = metadata self.tar_files = tar_files @@ -38,6 +72,13 @@ def __init__( @classmethod def from_webdataset(cls, path: str, id_col: str): + """ + Loads an ImageTextPairDataset from a WebDataset + + Args: + path (str): The path to the WebDataset-like format on disk or cloud storage. + id_col (str): The column storing the unique identifier for each record. + """ metadata = dask_cudf.read_parquet(path) metadata = metadata.map_partitions(cls._sort_partition, id_col=id_col) @@ -53,6 +94,7 @@ def _sort_partition(partition, id_col): def _get_tar_files(path: str) -> List[str]: glob_str = os.path.join(path, "*.tar") # open_files doesn't actually open a file descriptor + # tar_files is sorted by default tar_files = [file.path for file in open_files(glob_str)] return tar_files @@ -74,6 +116,16 @@ def _name_partition( def save_metadata( self, path: Optional[str] = None, columns: Optional[List[str]] = None ) -> None: + """ + Saves the metadata of the dataset to the specified path as a collection + of Parquet files. + + Args: + path (Optional[str]): The path to save the metadata to. If None, + writes to the original path. + columns (Optional[List[str]]): If specified, only saves a subset + of columns. + """ if path is None: path = self.path @@ -150,23 +202,37 @@ def _get_eligible_samples(self, output_path: str, samples_per_shard: int): yield curr_df, total_tar_samples @staticmethod - def combine_id(shard_id, sample_id, max_shards=5, max_samples_per_shard=4) -> str: + def _combine_id(shard_id, sample_id, max_shards=5, max_samples_per_shard=4) -> str: int_id = sample_id + (10**max_samples_per_shard) * shard_id n_digits = max_samples_per_shard + max_shards combined_id = f"{int_id:0{n_digits}d}" return combined_id - def split_id(combined_id: str, max_shards=5): - return int(combined_id[:max_shards]), int(combined_id[max_shards:]) - def to_webdataset( self, path: str, filter_column: str, samples_per_shard: int = 10000, - max_shards=5, - old_id_col=None, + max_shards: int = 5, + old_id_col: Optional[str] = None, ) -> None: + """ + Saves the dataset to a WebDataset format with Parquet files. + Will reshard the tar files to the specified number of samples per shard. + The ID value in ImageTextPairDataset.id_col will be overwritten with a new ID. + + Args: + path (str): The output path where the dataset should be written. + filter_column (str): A column of booleans. All samples with a value of True + in this column will be included in the output. Otherwise, the sample + will be omitted. 
+            samples_per_shard (int): The number of samples to include in each tar file.
+            max_shards (int): The order of magnitude of the maximum number of shards
+                that will be created from the dataset. Will be used to determine the
+                number of leading zeros in the shard/sample IDs.
+            old_id_col (Optional[str]): If specified, will preserve the previous
+                ID value in the given column.
+        """
         max_samples_per_shard = math.ceil(math.log10(samples_per_shard))
 
         filtered_metadata = self.metadata[self.metadata[filter_column]]
@@ -191,7 +257,7 @@ def to_webdataset(
             new_ids = np.arange(len(shard_df))
 
             convert_ids = partial(
-                self.combine_id,
+                self._combine_id,
                 shard_id,
                 max_shards=max_shards,
                 max_samples_per_shard=max_samples_per_shard,
@@ -206,7 +272,7 @@ def to_webdataset(
             for i, (member, data) in enumerate(shard_tar):
                 # Rename the each member to match the new id
                 sample_id = int(i // members_per_sample)
-                member_id = self.combine_id(
+                member_id = self._combine_id(
                     shard_id,
                     sample_id,
                     max_shards=max_shards,
diff --git a/nemo_curator/image/classifiers/aesthetic.py b/nemo_curator/image/classifiers/aesthetic.py
index 99e8c68b..c8881d12 100644
--- a/nemo_curator/image/classifiers/aesthetic.py
+++ b/nemo_curator/image/classifiers/aesthetic.py
@@ -46,6 +46,13 @@ def forward(self, x):
 
 
 class AestheticClassifier(ImageClassifier):
+    """
+    LAION-Aesthetics Predictor V2 is a linear classifier trained on top of
+    OpenAI CLIP ViT-L/14 image embeddings. It is used to assess the aesthetic
+    quality of images. More information on the model can be found here:
+    https://laion.ai/blog/laion-aesthetics/.
+    """
+
     def __init__(
         self,
         embedding_column: str = "image_embedding",
@@ -53,6 +60,22 @@ def __init__(
         batch_size: int = -1,
         model_path: Optional[str] = None,
     ) -> None:
+        """
+        Constructs the classifier.
+
+        Args:
+            embedding_column (str): The column name that stores the image
+                embeddings.
+            pred_column (str): The column name to be added where the aesthetic
+                scores will be stored.
+            pred_type (Union[str, type]): The datatype of the pred_column.
+            batch_size (int): If greater than 0, the image embeddings
+                will be processed in batches of at most this size. If less than 0,
+                all embeddings will be processed at once.
+            model_path (Optional[str]): If specified, will load the model from the
+                given path. If not specified, the model will be downloaded and
+                stored in NEMO_CURATOR_HOME by default.
+        """
         super().__init__(
             model_name="aesthetic_classifier",
             embedding_column=embedding_column,
@@ -90,11 +113,11 @@ def load_model(self, device):
         weights = torch.load(self.model_path, map_location=torch.device("cpu"))
         model.load_state_dict(weights)
         model.eval()
-        model = self.configure_forward(model)
+        model = self._configure_forward(model)
 
         return model
 
-    def configure_forward(self, model):
+    def _configure_forward(self, model):
         original_forward = model.forward
 
         def custom_forward(*args, **kwargs):
diff --git a/nemo_curator/image/classifiers/base.py b/nemo_curator/image/classifiers/base.py
index b9f5d52d..7ad9de01 100644
--- a/nemo_curator/image/classifiers/base.py
+++ b/nemo_curator/image/classifiers/base.py
@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from abc import ABC, abstractmethod
-from typing import Union
+from typing import Callable, Union
 
+import cudf
 import cupy as cp
 import torch
 
@@ -26,7 +26,12 @@
 class ImageClassifier(ABC):
     """
     An abstract base class that represents a classifier on top
-    of embeddings generated by a CLIP vision encoder
+    of embeddings generated by a CLIP vision encoder.
+
+    Subclasses only need to define how a model is loaded.
+    They may also override the postprocess method if they would like
+    to modify the output series of predictions before it gets combined into
+    the dataset. The classifier must be able to fit on a single GPU.
     """
 
     def __init__(
@@ -38,6 +43,21 @@ def __init__(
         batch_size: int,
         embedding_size: int,
     ) -> None:
+        """
+        Constructs an image classifier.
+
+        Args:
+            model_name (str): A unique name to identify the model on each worker
+                and in the logs.
+            embedding_column (str): The column name that stores the image
+                embeddings.
+            pred_column (str): The column name to be added where the classifier's
+                predictions will be stored.
+            pred_type (Union[str, type]): The datatype of the pred_column.
+            batch_size (int): If greater than 0, the image embeddings
+                will be processed in batches of at most this size. If less than 0,
+                all embeddings will be processed at once.
+        """
         self.model_name = model_name
         self.embedding_column = embedding_column
         self.pred_column = pred_column
@@ -46,6 +66,15 @@ def __init__(
         self.embedding_size = embedding_size
 
     def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset:
+        """
+        Classifies all embeddings in the dataset.
+
+        Args:
+            dataset (ImageTextPairDataset): The dataset to classify.
+
+        Returns:
+            ImageTextPairDataset: A dataset with classifier scores.
+        """
         meta = dataset.metadata.dtypes.to_dict()
         meta[self.pred_column] = self.pred_type
         embedding_df = dataset.metadata.map_partitions(self._run_inference, meta=meta)
@@ -104,8 +133,31 @@ def _run_inference(self, partition, partition_info=None):
 
         return partition
 
     @abstractmethod
-    def load_model(self, device):
+    def load_model(self, device: str) -> Callable:
+        """
+        Loads the classifier model.
+
+        Args:
+            device (str): A PyTorch device identifier that specifies what GPU
+                to load the model on.
+
+        Returns:
+            Callable: A callable model, usually a torch.nn.Module.
+                The input to this model will be batches of image embeddings
+                taken from the embedding_column of the dataset.
+        """
         pass
 
-    def postprocess(self, series):
+    def postprocess(self, series: cudf.Series) -> cudf.Series:
+        """
+        Postprocesses the predictions of the classifier before saving
+        them to the metadata.
+
+        Args:
+            series (cudf.Series): The cuDF series of raw model predictions.
+
+        Returns:
+            cudf.Series: The same series unmodified. Override in your classifier
+                if needed.
+        """
         return series
diff --git a/nemo_curator/image/classifiers/nsfw.py b/nemo_curator/image/classifiers/nsfw.py
index e66abcae..ef18fed7 100644
--- a/nemo_curator/image/classifiers/nsfw.py
+++ b/nemo_curator/image/classifiers/nsfw.py
@@ -53,6 +53,14 @@ def forward(self, x):
 
 
 class NsfwClassifier(ImageClassifier):
+    """
+    NSFW Classifier is a small MLP trained on top of
+    OpenAI's ViT-L CLIP image embeddings. It is used to assess the likelihood
+    of images containing sexually explicit material.
+    More information on the model can be found here:
+    https://github.com/LAION-AI/CLIP-based-NSFW-Detector.
+ """ + def __init__( self, embedding_column: str = "image_embedding", @@ -60,6 +68,22 @@ def __init__( batch_size: int = -1, model_path: Optional[str] = None, ) -> None: + """ + Constructs the classifier. + + Args: + embedding_column (str): The column name that stores the image + embeddings. + pred_column (str): The column name to be added where the nsfw + scores will be stored. + pred_type (Union[str, type]): The datatype of the pred_column. + batch_size (int): If greater than 0, the image embeddings + will be processed in batches of at most this size. If less than 0, + all embeddings will be processed at once. + model_path (Optional[str]): If specified, will load the model from the + given path. If not specified, will default to being stored in + NEMO_CURATOR_HOME. + """ super().__init__( model_name="nsfw_classifier", embedding_column=embedding_column, @@ -94,11 +118,11 @@ def load_model(self, device): weights = torch.load(self.model_path, map_location=torch.device("cpu")) model.load_state_dict(weights) model.eval() - model = self.configure_forward(model) + model = self._configure_forward(model) return model - def configure_forward(self, model): + def _configure_forward(self, model): original_forward = model.forward def custom_forward(*args, **kwargs): diff --git a/nemo_curator/image/embedders/base.py b/nemo_curator/image/embedders/base.py index 5acfc0f0..d910e170 100644 --- a/nemo_curator/image/embedders/base.py +++ b/nemo_curator/image/embedders/base.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from abc import ABC, abstractmethod -from typing import Iterable +from typing import Callable, Iterable import cupy as cp import torch @@ -26,17 +25,49 @@ class ImageEmbedder(ABC): + """ + An abstract base class for generating image embeddings. + + Subclasses only need to define how a model is loaded and a dataset + is read in from a tar file shard. This class handles distributing + the tasks across workers and saving the metadata to the dataset. + The embedding model must be able to fit onto a single GPU. + """ + def __init__( self, model_name: str, image_embedding_column: str, classifiers: Iterable[ImageClassifier], ) -> None: + """ + Constructs an image embedder. + + Args: + model_name (str): A unqiue name to identify the model on each worker + and in the logs. + image_embedding_column (str): The column name to be added where the + image embeddings will be saved. + classifiers (Iterable[ImageClassifier]): A collection of classifiers. If + the iterable has a nonzero length, all classifiers will be loaded + on the GPU at the same time and be passed the image embeddings + immediately after they are created. + """ self.model_name = model_name self.image_embedding_column = image_embedding_column self.classifiers = classifiers def __call__(self, dataset: ImageTextPairDataset) -> ImageTextPairDataset: + """ + Generates image embeddings for all images in the dataset. + + Args: + dataset (ImageTextPairDataset): The dataset to create image embeddings for. + + Returns: + ImageTextPairDataset: A dataset with image embeddings and potentially + classifier scores. 
+ """ meta = dataset.metadata.dtypes.to_dict() meta[self.image_embedding_column] = "object" for classifier in self.classifiers: @@ -122,8 +153,36 @@ def _run_inference(self, partition, tar_paths, id_col, partition_info=None): @abstractmethod def load_dataset_shard(self, tar_path: str) -> Iterable: + """ + Loads images and metadata from a tarfile in the dataset. + + Args: + tar_path (str): The path to a tar file shard in the input WebDataset. + + Returns: + Iterable: An iterator over the dataset. Each iteration should produce + a tuple of (image, metadata) pairs. The batch of images will be passed + directly to the model created by ImageEmbedder.load_embedding_model. + The metadata must be a list of dictionaries. Each element of the list + must correspond to the image in the batch at the same position. + Each dictionary must contain a field that is the same as + id_field in the dataset. This ID field in the metadata will be used + to match the image to the its record in the metadata (Parquet) files. + """ pass @abstractmethod - def load_embedding_model(self, device): + def load_embedding_model(self, device: str) -> Callable: + """ + Loads the model used to generate image embeddings. + + Args: + device (str): A PyTorch device identifier that specifies what GPU + to load the model on. + + Returns: + Callable: A callable model, usually a torch.nn.Module. + The input to this model will be the batches of images output + by the ImageEmbedder.load_dataset_shard. + """ pass diff --git a/nemo_curator/image/embedders/timm.py b/nemo_curator/image/embedders/timm.py index 3a7c258a..fac2fba2 100644 --- a/nemo_curator/image/embedders/timm.py +++ b/nemo_curator/image/embedders/timm.py @@ -26,6 +26,14 @@ class TimmImageEmbedder(ImageEmbedder): + """ + PyTorch Image Models (timm) is a library containing SOTA computer vision + models. Many of these models are useful in generating image embeddings + for modules in NeMo Curator. This module can also automatically convert + the image transformations from PyTorch transformations to DALI transformations + in the supported models. + """ + def __init__( self, model_name: str, @@ -36,8 +44,34 @@ def __init__( normalize_embeddings: bool = True, classifiers: Iterable = [], autocast: bool = True, - use_index_files=False, + use_index_files: bool = False, ) -> None: + """ + Constructs the embedder. + + Args: + model_name (str): The timm model to use. A list of available models + can be found by running timm.list_models() + pretrained (bool): If True, loads the pretrained weights of the model. + batch_size (int): The number of images to run inference on in a single batch. + If the batch_size is larger than the number of elements in a shard, only + the number of elements in a shard will be used. + num_threads_per_worker (int): The number of threads per worker (GPU) to use + for loading images with DALI. + image_embedding_column (str): The output column where the embeddings will be + stored in the dataset. + normalize_embeddings (bool): Whether to normalize the embeddings output by the + model. Defaults to True. + classifiers (Iterable): A collection of classifiers to immediately apply on top + of the image embeddings. + autocast (bool): If True, runs the timm model using torch.autocast(). + use_index_files (bool): If True, tries to find and use index files generated + by DALI at the same path as the tar file shards. The index files must be + generated by DALI's wds2idx tool. 
+                See https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/dataloading_webdataset.html#Creating-an-index
+                for more information. Each index file must be named "shard_id.idx",
+                where shard_id matches the integer ID of the corresponding tar file.
+                The index files must be in the same folder as the tar files.
+        """
         super().__init__(
             model_name=model_name,
             image_embedding_column=image_embedding_column,
@@ -58,6 +92,21 @@ def __init__(
         self.dali_transforms = convert_transforms_to_dali(torch_transforms)
 
     def load_dataset_shard(self, tar_path: str):
+        """
+        Loads a WebDataset tar shard using DALI.
+
+        Args:
+            tar_path (str): The path of the tar shard to load.
+
+        Returns:
+            Iterable: An iterator over the dataset. Each tar file
+                must have 3 files per record: a .jpg file, a .txt file,
+                and a .json file. The .jpg file must contain the image, the
+                .txt file must contain the associated caption, and the
+                .json file must contain the metadata for the record (including
+                its ID). Images will be loaded using DALI.
+        """
+
         # Create the DALI pipeline
         @pipeline_def(
             batch_size=self.batch_size,
@@ -115,13 +164,25 @@ def webdataset_pipeline(_tar_path: str):
             yield image, metadata
 
     def load_embedding_model(self, device="cuda"):
+        """
+        Loads the model used to generate image embeddings.
+
+        Args:
+            device (str): A PyTorch device identifier that specifies what GPU
+                to load the model on.
+
+        Returns:
+            Callable: A timm model loaded on the specified device.
+                The model's forward call may be augmented with torch.autocast()
+                or embedding normalization if specified in the constructor.
+        """
         model = timm.create_model(self.model_name, pretrained=self.pretrained).eval()
         model = model.to(device)
-        model = self.configure_forward(model)
+        model = self._configure_forward(model)
 
         return model
 
-    def configure_forward(self, model):
+    def _configure_forward(self, model):
         original_forward = model.forward
 
         def custom_forward(*args, **kwargs):
@@ -139,7 +200,3 @@ def custom_forward(*args, **kwargs):
         model.forward = custom_forward
 
         return model
-
-    @staticmethod
-    def torch_normalized(a, dim=-1):
-        return torch.nn.functional.normalize(a, dim=dim)
diff --git a/setup.py b/setup.py
index 8b35bc8e..8a182f18 100644
--- a/setup.py
+++ b/setup.py
@@ -55,7 +55,7 @@ def req_file(filename, folder="requirements"):
 
 setup(
     name="nemo_curator",
-    version="0.4.0",
+    version="0.5.0",
     description="Scalable Data Preprocessing Tool for "
     "Training Large Language Models",
     long_description=long_description,
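The load_dataset_shard / load_embedding_model contract documented in the embedder base class can be illustrated with a small, self-contained subclass. This is a sketch only: it uses plain tarfile and torchvision I/O instead of the DALI loader that TimmImageEmbedder provides, and every name in it (ToyImageEmbedder, the toy model) is illustrative rather than part of the library.

```python
# Illustrative sketch of the ImageEmbedder contract; not the library's loader.
import json
import tarfile

import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

from nemo_curator.image.embedders import ImageEmbedder


class ToyImageEmbedder(ImageEmbedder):
    """Yields one-image "batches" with plain Python I/O; DALI is not used."""

    def load_dataset_shard(self, tar_path: str):
        with tarfile.open(tar_path) as tar:
            names = sorted(m.name for m in tar.getmembers() if m.name.endswith(".jpg"))
            for name in names:
                image = Image.open(tar.extractfile(name)).convert("RGB")
                batch = to_tensor(image).unsqueeze(0).cuda()  # shape (1, C, H, W)
                # One metadata dict per image, aligned with batch order; the
                # .json sidecar is assumed to contain the dataset's ID field.
                metadata = [json.load(tar.extractfile(name[: -len(".jpg")] + ".json"))]
                yield batch, metadata

    def load_embedding_model(self, device: str):
        # Any callable from image batches to embedding batches satisfies the
        # contract; a real subclass would load a CLIP or timm model here.
        model = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten())
        return model.to(device).eval()


# The base constructor wires up the embedding column and any fused classifiers.
embedder = ToyImageEmbedder(
    model_name="toy_embedder", image_embedding_column="image_embedding", classifiers=[]
)
```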
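Putting the pieces of this patch together end to end: the sketch below assumes hypothetical input/output paths, a timm model name assumed to be listed by timm.list_models(), the classifiers' default score column names, and a "passes_curation" column that the user derives with illustrative thresholds; only embedder(dataset) and to_webdataset(path, filter_column) are confirmed by the code in this diff.

```python
# End-to-end sketch (hypothetical paths): embed, classify, filter, re-shard.
from nemo_curator.datasets import ImageTextPairDataset
from nemo_curator.image.classifiers import AestheticClassifier, NsfwClassifier
from nemo_curator.image.embedders import TimmImageEmbedder

dataset = ImageTextPairDataset.from_webdataset(path="/input/dataset", id_col="key")

# Fused pipeline: both classifiers receive each batch of CLIP embeddings
# immediately after it is computed, so embeddings stay on the GPU.
embedder = TimmImageEmbedder(
    "vit_large_patch14_clip_quickgelu_224.openai",  # assumed available in timm
    pretrained=True,
    batch_size=1024,
    num_threads_per_worker=16,
    classifiers=[AestheticClassifier(), NsfwClassifier()],
    use_index_files=True,  # expects wds2idx "shard_id.idx" files beside the tars
)
dataset = embedder(dataset)

# Derive a keep/drop decision from the scores (column names assume the
# classifiers' default pred_column values; thresholds are illustrative),
# then write filtered shards with re-assigned IDs.
dataset.metadata["passes_curation"] = (dataset.metadata["aesthetic_score"] > 5) & (
    dataset.metadata["nsfw_score"] < 0.5
)
dataset.to_webdataset(path="/output/dataset", filter_column="passes_curation")
```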