From cef1c3408684e9ce96e8125f041378407daf0a7b Mon Sep 17 00:00:00 2001 From: Umair Ahmed Date: Thu, 26 Sep 2024 00:35:02 +0530 Subject: [PATCH 1/2] Added example notebook for translation with ct2 model. Signed-off-by: Ahmed Umair --- .../translation/ct2_hindi_translation.ipynb | 562 ++++++++++++++++++ 1 file changed, 562 insertions(+) create mode 100644 tutorials/translation/ct2_hindi_translation.ipynb diff --git a/tutorials/translation/ct2_hindi_translation.ipynb b/tutorials/translation/ct2_hindi_translation.ipynb new file mode 100644 index 00000000..02a498a3 --- /dev/null +++ b/tutorials/translation/ct2_hindi_translation.ipynb @@ -0,0 +1,562 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Indic Translation\n", + "This notebook demonstrate an example use of nemo-curator for Indic language generation via translation from English language. This workflow is accelarated by CrossFit, a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets. \n", + "This example uses ctransalte2 model from [here](https://indictrans2-public.objectstore.e2enetworks.net/it2_preprint_ckpts/en-indic-preprint.zip), taken from IndicTrans2 github repo, [here](https://github.com/AI4Bharat/IndicTrans2/tree/main?tab=readme-ov-file#multilingual-translation-models)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"DASK_DATAFRAME__QUERY_PLANNING\"] = \"False\"\n", + "import argparse\n", + "import re\n", + "import time\n", + "from dataclasses import dataclass\n", + "from functools import lru_cache\n", + "\n", + "import cudf\n", + "import ctranslate2\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "from crossfit import op\n", + "from crossfit.backend.torch.hf.model import HFModel\n", + "from dask.distributed import get_worker\n", + "from nltk.tokenize import sent_tokenize\n", + "from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo_curator.classifiers.base import DistributedDataClassifier\n", + "from nemo_curator.datasets import DocumentDataset\n", + "from nemo_curator.utils.distributed_utils import get_client, load_object_on_worker\n", + "from nemo_curator.utils.script_utils import ArgumentHelper" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For pre and post processing, we are using IndicTransToolkit, required for translation using IndcTrans2 models. It is simple, modoular library for preprocessing, normalizations, postprocessing stuff." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from IndicTransToolkit import IndicProcessor\n", + "except ImportError:\n", + " raise ImportError(\n", + " \"IndicTransToolkit not found. Please install it using the following command: \\n\"\n", + " + \"pip install git+https://github.com/VarunGumma/IndicTransToolkit.git\"\n", + " )\n", + "import dask_cudf" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define CT2CustomModel\n", + "\n", + "Now, we will create a ctranslate2 model class, rquired for calling inference on ct2 model, and attch it with crossfit's HFModel class. These Model definitions can be found in CrossFit's example, [here](https://github.com/rapidsai/crossfit/pull/83/files#diff-d3c29a7456aac8be2bb3d53ba3d983e36631ea8dd36c4e52d9f3217183d4568f). A TranslationConfig dataclass is required for passing arelevent arguments." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass\n", + "class TranslationConfig:\n", + " pretrained_model_name_or_path: str\n", + " ct2_model_path: str\n", + " max_words_per_sen: int = 200\n", + " target_lang_code: str = \"hin_Deva\"\n", + "\n", + "\n", + "class CT2CustomModel:\n", + " def __init__(self, config: TranslationConfig, device=\"cuda\"):\n", + " self.config = config\n", + " self.tokenizer = AutoTokenizer.from_pretrained(\n", + " pretrained_model_name_or_path=config.pretrained_model_name_or_path,\n", + " trust_remote_code=True,\n", + " )\n", + " self.model = ctranslate2.Translator(\n", + " model_path=config.ct2_model_path, device=device\n", + " )\n", + "\n", + " def clean_extra_tokens(self, token_2d):\n", + " results = []\n", + " for token_1d in token_2d:\n", + " result = []\n", + " for t in token_1d:\n", + " if (\n", + " t == self.tokenizer.pad_token\n", + " or t == self.tokenizer.bos_token\n", + " or t == self.tokenizer.eos_token\n", + " or t == self.tokenizer.unk_token\n", + " ):\n", + " pass\n", + " else:\n", + " result.append(t)\n", + " results.append(result)\n", + " return results\n", + "\n", + " def __call__(self, batch):\n", + " token_ids_2d = batch[\"input_ids\"]\n", + " token_ids_1d = token_ids_2d.view(-1).tolist()\n", + " tokens_1d = self.tokenizer.convert_ids_to_tokens(token_ids_1d)\n", + " tokens_2d = [\n", + " tokens_1d[i : i + token_ids_2d.size(1)]\n", + " for i in range(0, len(tokens_1d), token_ids_2d.size(1))\n", + " ]\n", + " tokens = self.clean_extra_tokens(tokens_2d)\n", + "\n", + " tr_res = self.model.translate_batch(\n", + " tokens,\n", + " min_decoding_length=0,\n", + " max_decoding_length=256,\n", + " beam_size=5,\n", + " num_hypotheses=1,\n", + " )\n", + " translations = [\"\".join(x.hypotheses[0]) for x in tr_res]\n", + " return translations\n", + "\n", + "\n", + "class ModelForSeq2SeqModel(HFModel):\n", + " def __init__(self, config):\n", + " self.trans_config = config\n", + " self.config = self.load_cfg()\n", + " super().__init__(\n", + " self.trans_config.pretrained_model_name_or_path, model_output_type=\"string\"\n", + " )\n", + "\n", + " def load_model(self, device=\"cuda\"):\n", + " model = CT2CustomModel(self.trans_config)\n", + " return model\n", + "\n", + " def load_config(self):\n", + " return self.load_cfg()\n", + "\n", + " def load_tokenizer(self):\n", + " return AutoTokenizer.from_pretrained(\n", + " pretrained_model_name_or_path=self.trans_config.pretrained_model_name_or_path,\n", + " trust_remote_code=True,\n", + " )\n", + "\n", + " def max_seq_length(self) -> int:\n", + " return self.config.max_source_positions\n", + "\n", + " def load_cfg(self):\n", + " config = AutoConfig.from_pretrained(\n", + " pretrained_model_name_or_path=self.trans_config.pretrained_model_name_or_path,\n", + " trust_remote_code=True,\n", + " )\n", + " return config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define IndicTranslation class\n", + "\n", + "Now that we have created relevent model classes from crossfit side, for running the pipeline we need to inherit from __DistributedDataClassifier__ of nemo-curator, and implement its **_run_classifier** method inside __IndicTranslation__\n", + "\n", + "_run_classifier method is responsible for running the inference. For our translation use case we need to have preprocessing, filtering before call for inference and postprocessing after the inference.\n", + "\n", + "In this example we have added pre and postprocessing from _run_classifier method. Overall major steps will be as follows : \n", + "\n", + "1. Run process_input_text method which will be responsible for breaking english sentences via nltk's sentence tokenizer into sentences of specified length(default = 200 words).\n", + "2. Filter data where sentence should at least have 1 alphabet in it.\n", + "3. Left over data from step 2 won't go for translation but will be added for final data with translation as same as input text.\n", + "4. Data which passed from step 2 will go fro indic preprocessing from IndicTransToolkit.\n", + "5. CrossFit's Toeknizer and Predictor will run on data.\n", + "6. Output from step 5 will go for detokenization and indic postprocessing.\n", + "7. Combining the results from step 6 and step 3 and reutrn the data.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class IndicTranslation(DistributedDataClassifier):\n", + " def __init__(\n", + " self,\n", + " ct2_model_path: str,\n", + " pretrained_model_name_or_path: str = \"ai4bharat/indictrans2-en-indic-1B\",\n", + " input_column: str = \"indic_proc_text\",\n", + " batch_size: int = 128,\n", + " autocast: bool = False,\n", + " target_lang_code: str = \"hin_Deva\",\n", + " ):\n", + " self.input_column = input_column\n", + " self.batch_size = batch_size\n", + " self.autocast = autocast\n", + "\n", + " self.translation_config = TranslationConfig(\n", + " pretrained_model_name_or_path=pretrained_model_name_or_path,\n", + " ct2_model_path=ct2_model_path,\n", + " target_lang_code=target_lang_code,\n", + " )\n", + " self.model = ModelForSeq2SeqModel(self.translation_config)\n", + " super().__init__(\n", + " model=self.model,\n", + " batch_size=self.batch_size,\n", + " device_type=\"cuda\",\n", + " autocast=self.autocast,\n", + " labels=None,\n", + " filter_by=None,\n", + " out_dim=None,\n", + " pred_column=None,\n", + " max_chars=None,\n", + " )\n", + "\n", + " def preprocess_df(self, df: cudf.DataFrame) -> cudf.DataFrame:\n", + " ip = load_object_on_worker(\n", + " \"IndicProcessor\", IndicProcessor, {\"inference\": True}\n", + " )\n", + " indices = df[\"text\"].index.to_arrow().to_pylist()\n", + " sentences = df[\"text\"].to_arrow().to_pylist()\n", + " sentences = ip.preprocess_batch(\n", + " sentences,\n", + " src_lang=\"eng_Latn\",\n", + " tgt_lang=self.translation_config.target_lang_code, # \"hin_Deva\"\n", + " )\n", + " df[\"indic_proc_text\"] = cudf.Series(sentences, index=indices)\n", + " return df\n", + "\n", + " def translate_tokens(self, df: cudf.DataFrame) -> cudf.DataFrame:\n", + " worker = get_worker()\n", + " if hasattr(worker, \"IndicProcessor\"):\n", + " ip = getattr(worker, \"IndicProcessor\")\n", + " else:\n", + " ip = load_object_on_worker(\n", + " \"IndicProcessor\", IndicProcessor, {\"inference\": True}\n", + " )\n", + " tokenizer = self.model.load_tokenizer()\n", + " indices = df[\"translation\"].index.to_arrow().to_pylist()\n", + " generated_tokens = df[\"translation\"].to_arrow().to_pylist()\n", + " converted_tokens = []\n", + " for g in generated_tokens:\n", + " converted_tokens.append(tokenizer.convert_tokens_to_string(g))\n", + " converted_tokens = ip.postprocess_batch(\n", + " converted_tokens, lang=self.translation_config.target_lang_code\n", + " )\n", + " df[\"translation\"] = cudf.Series(data=converted_tokens, index=indices)\n", + " return df\n", + "\n", + " def has_alphabet_characters(self, text: str) -> bool:\n", + " return any(c.isalpha() for c in text)\n", + "\n", + " def custom_tokenize(self, text: str):\n", + " split_text = re.split(\n", + " r\"(\\#{2,}|\\_{2,}|\\…{2,}|\\+{2,}|\\.{2,}|\\-{3,}|\\*{2,}|\\~{2,}|\\={2,}|\\!{2,}|\\n|\\t|\\‣|\\⁃|\\⁌|\\⁍|\\●|\\○|\\•|\\·|\\◘|\\◦|\\⦾|\\⦿|\\|)\",\n", + " text,\n", + " )\n", + " split_text = [s for s in split_text if len(s) > 0]\n", + " tokenized_sentences = []\n", + " len_flag = False\n", + " for line in split_text:\n", + " # Tokenize sentences using NLTK's sent_tokenize function\n", + " if self.has_alphabet_characters(line) == True:\n", + " sentences = sent_tokenize(line)\n", + " i = 0\n", + " j = 0\n", + " curr_tokenized_snt = []\n", + " non_translation_str = \"\"\n", + " # Comparing the list of tokenized sentences (using NLTK) and actual sentence and preserving the spaces,\n", + " # newline and other special characters\n", + " while i < len(line):\n", + " if j < len(sentences):\n", + " stripped_sent = sentences[j].strip()\n", + " if len(stripped_sent) == 0:\n", + " j += 1\n", + " continue\n", + " # If tokenized sentence matches then moving to next sentence\n", + " if line[i] == stripped_sent[0]:\n", + " if non_translation_str != \"\":\n", + " curr_tokenized_snt.append(non_translation_str)\n", + " curr_tokenized_snt.append(stripped_sent)\n", + " i += len(stripped_sent)\n", + " j += 1\n", + " non_translation_str = \"\"\n", + " else:\n", + " non_translation_str += line[i]\n", + " i += 1\n", + " else:\n", + " non_translation_str += line[i]\n", + " i += 1\n", + " if non_translation_str != \"\":\n", + " curr_tokenized_snt.append(non_translation_str)\n", + " # Add the tokenized sentences to the list\n", + " tokenized_sentences.extend(curr_tokenized_snt)\n", + " else:\n", + " tokenized_sentences.append(line)\n", + "\n", + " tokenized_sentence_len = []\n", + " for sentence in tokenized_sentences:\n", + " sent = sentence.split()\n", + " # removing the sentences with word length greater than threshold as the model may not be able translate it due to constraint on output token size\n", + " if len(sent) <= self.translation_config.max_words_per_sen:\n", + " tokenized_sentence_len.append(sentence)\n", + "\n", + " return tokenized_sentence_len\n", + "\n", + " def process_input_text(self, df: cudf.DataFrame) -> cudf.DataFrame:\n", + " df = df.to_pandas()\n", + " df[\"text\"] = df[\"text\"].apply(self.custom_tokenize)\n", + " df[\"doc_id\"] = np.arange(1, len(df) + 1)\n", + " df = df.explode(\"text\", ignore_index=True)\n", + " df = df.reset_index(drop=False)\n", + " df = cudf.DataFrame.from_pandas(df)\n", + " return df\n", + "\n", + " def remove_false_fullstop(self, df: cudf.DataFrame) -> cudf.DataFrame:\n", + " engligh_stop_flag = df[\"text\"].str.endswith(\".\")\n", + " hindi_stop_flag = df[\"translation\"].str.endswith(\"|\")\n", + " df[\"translation\"][~engligh_stop_flag & hindi_stop_flag] = df[\n", + " \"translation\"\n", + " ].str.rstrip(\"|\")\n", + " df[\"translation\"] = df[\"translation\"].str.strip()\n", + " return df\n", + "\n", + " def grouping(self, df: cudf.DataFrame) -> cudf.DataFrame:\n", + " df = df.to_pandas()\n", + " agg_funcs = {\n", + " \"translation\": lambda s: \"\".join(s),\n", + " \"text\": lambda s: \"\".join(s),\n", + " }\n", + " other_columns = {\n", + " col: \"first\"\n", + " for col in df.columns\n", + " if col not in agg_funcs and col != \"doc_id\"\n", + " }\n", + "\n", + " agg_funcs.update(other_columns)\n", + " df = df.groupby(\"doc_id\").agg(agg_funcs).reset_index()\n", + " df = cudf.DataFrame.from_pandas(df)\n", + " return df\n", + "\n", + " def atleast_letter(self, df: cudf.DataFrame, column_name: str) -> cudf.DataFrame:\n", + " df = df.to_pandas()\n", + " df[\"isalpha\"] = df[column_name].apply(self.has_alphabet_characters)\n", + " df = cudf.DataFrame(df)\n", + " return df\n", + "\n", + " def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:\n", + " ddf = dataset.df\n", + " # Applying process_input_text for following :\n", + " # 1. nltk tokenization to break doc into sentences\n", + " # 2. craeting a row w.r.t each sentence.\n", + " # 3. Process sentences strip symbols from start and end\n", + " ddf_true = ddf.map_partitions(self.process_input_text, enforce_metadata=False)\n", + " ddf_true[\"text\"] = ddf_true[\"text\"].astype(\"str\")\n", + "\n", + " # To filter for atleast one unicode letter in text\n", + " has_letter = ddf_true.map_partitions(self.atleast_letter, column_name=\"text\")\n", + " ddf = ddf_true[has_letter[\"isalpha\"]]\n", + " ## ddf false operations\n", + " ddf_false = ddf_true[~has_letter[\"isalpha\"]]\n", + " ddf_false[\"translation\"] = ddf_false[\"text\"]\n", + " # Applying preprocess_df for Indic preprocessing\n", + " ddf[\"text\"] = ddf[\"text\"].astype(\"str\")\n", + " ddf_meta = ddf._meta.copy()\n", + " ddf_meta[\"indic_proc_text\"] = \"\"\n", + " ddf = ddf.map_partitions(self.preprocess_df, meta=ddf_meta)\n", + "\n", + " columns = ddf.columns.tolist()\n", + " pipe = op.Sequential(\n", + " op.Tokenizer(\n", + " self.model,\n", + " cols=[self.input_column],\n", + " tokenizer_type=\"default\",\n", + " max_length=255,\n", + " ),\n", + " op.Predictor(\n", + " self.model,\n", + " sorted_data_loader=True,\n", + " batch_size=self.batch_size,\n", + " pred_output_col=\"translation\",\n", + " ),\n", + " keep_cols=columns,\n", + " )\n", + " ddf = pipe(ddf)\n", + " translated_meta = ddf._meta.copy()\n", + " translated_meta[\"translation\"] = \"DUMMY_STRING\"\n", + " ddf = ddf.map_partitions(self.translate_tokens, meta=translated_meta)\n", + " ddf = ddf.map_partitions(self.remove_false_fullstop, meta=translated_meta)\n", + "\n", + " # Merging translated and non-translated samples\n", + " ddf_true[\"false_translation\"] = ddf_false[\"translation\"]\n", + " ddf_true[\"false_translation\"] = ddf_true[\"false_translation\"].fillna(\"\")\n", + " ddf_true[\"translation\"] = ddf[\"translation\"]\n", + " ddf_true[\"translation\"] = ddf_true[\"translation\"].fillna(\"\")\n", + " ddf_true[\"translation\"] = (\n", + " ddf_true[\"translation\"] + ddf_true[\"false_translation\"]\n", + " )\n", + "\n", + " ddf = ddf_true.map_partitions(self.grouping)\n", + " return DocumentDataset(ddf)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Start dask client." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "client = get_client(cluster_type=\"gpu\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define input" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "text = [\n", + " \"Quantum computing is set to revolutionize the field of cryptography.\",\n", + " \"Investing in index funds is a popular strategy for long-term financial growth.\",\n", + " \"Recent advancements in gene therapy offer new hope for treating genetic disorders.\",\n", + " \"Online learning platforms have transformed the way students access educational resources.\",\n", + " \"Traveling to Europe during the off-season can be a more budget-friendly option.\",\n", + " \"Training regimens for athletes have become more sophisticated with the use of data analytics.\",\n", + " \"Streaming services are changing the way people consume television and film content.\",\n", + " \"Vegan recipes have gained popularity as more people adopt plant-based diets.\",\n", + " \"Climate change research is critical for developing sustainable environmental policies.\",\n", + " \"Telemedicine has become increasingly popular due to its convenience and accessibility.\",\n", + "]\n", + "df = cudf.DataFrame({\"text\": text})\n", + "input_dataset = DocumentDataset(dask_cudf.from_cudf(df, npartitions=1))\n", + "write_to_filename = False" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define output directory\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "output_data_dir = \"out_data\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Start the translation\n", + "\n", + "IndicTranslation will need ct2_model_path, the model path of ctranslate2 converted model(which is downloaded from [here](https://indictrans2-public.objectstore.e2enetworks.net/it2_preprint_ckpts/en-indic-preprint.zip))." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/uahmed/Desktop/NeMo-Curator/.ndc/lib/python3.10/site-packages/sklearn/base.py:376: InconsistentVersionWarning: Trying to unpickle estimator LinearRegression from version 1.4.2 when using version 1.5.1. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n", + "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n", + " warnings.warn(\n", + "GPU: 0, Part: 0: 0%| | 0/10 [05:33 Date: Thu, 3 Oct 2024 10:59:23 +0530 Subject: [PATCH 2/2] Restructured according to comments. Signed-off-by: Ahmed Umair --- .../translation/ct2_hindi_translation.ipynb | 115 ++++++++++++++---- 1 file changed, 88 insertions(+), 27 deletions(-) diff --git a/tutorials/translation/ct2_hindi_translation.ipynb b/tutorials/translation/ct2_hindi_translation.ipynb index 02a498a3..32960280 100644 --- a/tutorials/translation/ct2_hindi_translation.ipynb +++ b/tutorials/translation/ct2_hindi_translation.ipynb @@ -5,10 +5,19 @@ "metadata": {}, "source": [ "# Indic Translation\n", - "This notebook demonstrate an example use of nemo-curator for Indic language generation via translation from English language. This workflow is accelarated by CrossFit, a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets. \n", + "This notebook demonstrate an example use of nemo-curator for Indic language generation via translation from English language which can be scaleup to use multiple node multiple gpus. This workflow is accelarated by CrossFit, a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets. \n", "This example uses ctransalte2 model from [here](https://indictrans2-public.objectstore.e2enetworks.net/it2_preprint_ckpts/en-indic-preprint.zip), taken from IndicTrans2 github repo, [here](https://github.com/AI4Bharat/IndicTrans2/tree/main?tab=readme-ov-file#multilingual-translation-models)" ] }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Imports section" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -22,36 +31,42 @@ "import re\n", "import time\n", "from dataclasses import dataclass\n", - "from functools import lru_cache\n", "\n", "import cudf\n", "import ctranslate2\n", "import numpy as np\n", "import torch\n", - "import torch.nn as nn\n", - "from crossfit import op\n", - "from crossfit.backend.torch.hf.model import HFModel\n", "from dask.distributed import get_worker\n", + "import dask_cudf\n", "from nltk.tokenize import sent_tokenize\n", "from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### crossfit and nemo_curator imports" + ] + }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ + "from crossfit import op\n", + "from crossfit.backend.torch.hf.model import HFModel\n", "from nemo_curator.classifiers.base import DistributedDataClassifier\n", "from nemo_curator.datasets import DocumentDataset\n", - "from nemo_curator.utils.distributed_utils import get_client, load_object_on_worker\n", - "from nemo_curator.utils.script_utils import ArgumentHelper" + "from nemo_curator.utils.distributed_utils import get_client, load_object_on_worker" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ + "### IndicTransToolkit import\n", "For pre and post processing, we are using IndicTransToolkit, required for translation using IndcTrans2 models. It is simple, modoular library for preprocessing, normalizations, postprocessing stuff." ] }, @@ -67,17 +82,41 @@ " raise ImportError(\n", " \"IndicTransToolkit not found. Please install it using the following command: \\n\"\n", " + \"pip install git+https://github.com/VarunGumma/IndicTransToolkit.git\"\n", - " )\n", - "import dask_cudf" + " )\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Define CT2CustomModel\n", + "## Table of content\n", + "1. [Ctranslate2 model integration](#ctranslate2-model-integration).\n", + "2. [Define IndicTranslation class](#define-indictranslation-class).\n", + "3. [Start the dask cluster](#start-the-dask-cluster).\n", + "4. [Define input](#define-input).\n", + "5. [Define output directory](#define-output-directory).\n", + "6. [Start the trasnaltion](#start-the-translation)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "### CTranslate2 Model Integration\n", + "\n", + "We'll now create a custom CTranslate2 model class, which is essential for performing inference on the CT2 converted model. This example uses ctransalte2 model from [here](https://indictrans2-public.objectstore.e2enetworks.net/it2_preprint_ckpts/en-indic-preprint.zip), taken from IndicTrans2 github repo, [here](https://github.com/AI4Bharat/IndicTrans2/tree/main?tab=readme-ov-file#multilingual-translation-models). CTranslate2 is a C++ and Python library for efficient inference with Transformer models. You can read more about it [here](https://github.com/OpenNMT/CTranslate2). One of the features of it is, it enables fast and efficient execution on both CPU and GPU.\n", + "\n", + "This class will be integrated with CrossFit's `HFModel` class to leverage CrossFit's efficient batching and processing capabilities.\n", "\n", - "Now, we will create a ctranslate2 model class, rquired for calling inference on ct2 model, and attch it with crossfit's HFModel class. These Model definitions can be found in CrossFit's example, [here](https://github.com/rapidsai/crossfit/pull/83/files#diff-d3c29a7456aac8be2bb3d53ba3d983e36631ea8dd36c4e52d9f3217183d4568f). A TranslationConfig dataclass is required for passing arelevent arguments." + "#### Key Components:\n", + "\n", + "1. **CT2CustomModel**: A custom class for CTranslate2 model inference.\n", + "2. **ModelForSeq2SeqModel**: An extension of CrossFit's `HFModel` class, tailored for our translation task.\n", + "3. **TranslationConfig**: A dataclass for managing translation-specific configuration parameters.\n", + "\n", + "These model definitions are inspired by examples from the CrossFit project. For reference, you can find similar implementations in the [CrossFit GitHub repository](https://github.com/rapidsai/crossfit/pull/83/files#diff-d3c29a7456aac8be2bb3d53ba3d983e36631ea8dd36c4e52d9f3217183d4568f).\n" ] }, { @@ -177,9 +216,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, "source": [ - "## Define IndicTranslation class\n", + "### Define IndicTranslation class\n", "\n", "Now that we have created relevent model classes from crossfit side, for running the pipeline we need to inherit from __DistributedDataClassifier__ of nemo-curator, and implement its **_run_classifier** method inside __IndicTranslation__\n", "\n", @@ -198,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -265,6 +306,7 @@ " converted_tokens = ip.postprocess_batch(\n", " converted_tokens, lang=self.translation_config.target_lang_code\n", " )\n", + " print(f\"Translated samples :\\n{converted_tokens}\")\n", " df[\"translation\"] = cudf.Series(data=converted_tokens, index=indices)\n", " return df\n", "\n", @@ -417,16 +459,19 @@ " ddf_true[\"translation\"] = (\n", " ddf_true[\"translation\"] + ddf_true[\"false_translation\"]\n", " )\n", - "\n", + " ddf_true.drop(columns=['false_translation'])\n", " ddf = ddf_true.map_partitions(self.grouping)\n", " return DocumentDataset(ddf)" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, "source": [ - "## Start dask client." + "### Start the Dask Cluster\n", + "NeMo Curator runs on Dask and Dask-cuDF to distribute computation. You can read more about it [in the documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/cpuvsgpu.html). All of the image curation modules are GPU-based, so we need to start a GPU-based local Dask cluster before we can use them." ] }, { @@ -440,9 +485,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, "source": [ - "## Define input" + "### Define input" ] }, { @@ -470,9 +517,11 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, "source": [ - "## Define output directory\n" + "### Define output directory\n" ] }, { @@ -488,14 +537,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Start the translation\n", + "### Start the translation\n", "\n", "IndicTranslation will need ct2_model_path, the model path of ctranslate2 converted model(which is downloaded from [here](https://indictrans2-public.objectstore.e2enetworks.net/it2_preprint_ckpts/en-indic-preprint.zip))." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -504,9 +553,22 @@ "text": [ "/home/uahmed/Desktop/NeMo-Curator/.ndc/lib/python3.10/site-packages/sklearn/base.py:376: InconsistentVersionWarning: Trying to unpickle estimator LinearRegression from version 1.4.2 when using version 1.5.1. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n", "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n", - " warnings.warn(\n", - "GPU: 0, Part: 0: 0%| | 0/10 [05:33