diff --git a/colabs/huggingface/LLM_Finetuning_Notebook.ipynb b/colabs/huggingface/LLM_Finetuning_Notebook.ipynb new file mode 100644 index 00000000..942e60bc --- /dev/null +++ b/colabs/huggingface/LLM_Finetuning_Notebook.ipynb @@ -0,0 +1,379 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LLM Finetuning with HuggingFace and Weights and Biases\n", + "\n", + "- Fine-tune a lightweight LLM (OPT-125M) with LoRA and 8-bit quantization using Launch\n", + "- Checkpoint the LoRA adapter weights as artifacts\n", + "- Link the best checkpoint in Model Registry\n", + "- Run inference on a quantized model\n", + "\n", + "The same workflow and principles from this notebook can be applied to fine-tuning some of the stronger OSS LLMs (e.g. Llama2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Fine-tune large models using 🤗 `peft` adapters, `transformers` & `bitsandbytes`\n", + "\n", + "In this tutorial we will cover how we can fine-tune large language models using the very recent `peft` library and `bitsandbytes` for loading large models in 8-bit.\n", + "The fine-tuning method will rely on a recent method called \"Low Rank Adapters\" (LoRA), instead of fine-tuning the entire model you just have to fine-tune these adapters and load them properly inside the model.\n", + "After fine-tuning the model you can also share your adapters on the 🤗 Hub and load them very easily. Let's get started!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Install requirements\n", + "\n", + "First, run the cells below to install the requirements:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q bitsandbytes datasets accelerate loralib\n", + "!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git\n", + "!pip install -q wandb\n", + "!pip install -q ctranslate2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model Loading\n", + "\n", + "- Here we leverage 8-bit quantization to reduce the memory footprint of the model during training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", + "import torch\n", + "import torch.nn as nn\n", + "import bitsandbytes as bnb\n", + "from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " \"facebook/opt-125m\",\n", + " load_in_8bit=True,\n", + " device_map='auto',\n", + ")\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-125m\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Post-processing on the model\n", + "\n", + "Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for param in model.parameters():\n", + " param.requires_grad = False # freeze the model - train adapters later\n", + " if param.ndim == 1:\n", + " # cast the small parameters (e.g. 
layernorm) to fp32 for stability\n", + " param.data = param.data.to(torch.float32)\n", + "\n", + "model.gradient_checkpointing_enable() # reduce number of stored activations\n", + "model.enable_input_require_grads()\n", + "\n", + "class CastOutputToFloat(nn.Sequential):\n", + " def forward(self, x): return super().forward(x).to(torch.float32)\n", + "model.lm_head = CastOutputToFloat(model.lm_head)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Apply LoRA\n", + "\n", + "Here comes the magic with `peft`! Let's load a `PeftModel` and specify that we are going to use low-rank adapters (LoRA) using `get_peft_model` utility function from `peft`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def print_trainable_parameters(model):\n", + " \"\"\"\n", + " Prints the number of trainable parameters in the model.\n", + " \"\"\"\n", + " trainable_params = 0\n", + " all_param = 0\n", + " for _, param in model.named_parameters():\n", + " all_param += param.numel()\n", + " if param.requires_grad:\n", + " trainable_params += param.numel()\n", + " print(\n", + " f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from peft import LoraConfig, get_peft_model\n", + "\n", + "config = LoraConfig(\n", + " r=16,\n", + " lora_alpha=32,\n", + " target_modules=[\"q_proj\", \"v_proj\"],\n", + " lora_dropout=0.05,\n", + " bias=\"none\",\n", + " task_type=\"CAUSAL_LM\"\n", + ")\n", + "\n", + "model = get_peft_model(model, config)\n", + "print_trainable_parameters(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training\n", + "- [W&B HuggingFace integration](https://docs.wandb.ai/guides/integrations/huggingface) automatically tracks important metrics during the course of training\n", + "- Also track the HF checkpoints as artifacts and register them in the model registry!\n", + "- Change the number of steps to 200+ for real results!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import transformers\n", + "from datasets import load_dataset\n", + "import wandb\n", + "\n", + "project_name = \"llm-finetuning\" #@param\n", + "entity = \"wandb\" #@param\n", + "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\"\n", + "\n", + "wandb.init(project=project_name,\n", + " entity=entity,\n", + " job_type=\"training\")\n", + "\n", + "data = load_dataset(\"Abirate/english_quotes\")\n", + "data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)\n", + "\n", + "trainer = transformers.Trainer(\n", + " model=model,\n", + " train_dataset=data['train'],\n", + " args=transformers.TrainingArguments(\n", + " per_device_train_batch_size=4,\n", + " gradient_accumulation_steps=4,\n", + " report_to=\"wandb\",\n", + " warmup_steps=5,\n", + " max_steps=25,\n", + " learning_rate=2e-4,\n", + " fp16=True,\n", + " logging_steps=1,\n", + " save_steps=5,\n", + " output_dir='outputs'\n", + " ),\n", + " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n", + ")\n", + "model.config.use_cache = False # silence the warnings. 
Please re-enable for inference!\n",
+    "trainer.train()\n",
+    "wandb.finish()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Adding Model Weights to W&B Model Registry\n",
+    "- Here we take the best checkpoint from the fine-tuning run and register it as our best model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "last_run_id = \"zz0lxkc8\" #@param\n",
+    "wandb.init(project=project_name, entity=entity, job_type=\"registering_best_model\")\n",
+    "best_model = wandb.use_artifact(f'{entity}/{project_name}/checkpoint-{last_run_id}:latest')\n",
+    "registered_model_name = \"OPT-125M-english\" #@param {type: \"string\"}\n",
+    "wandb.run.link_artifact(best_model, f'{entity}/model-registry/{registered_model_name}', aliases=['staging'])\n",
+    "wandb.finish()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Consuming the Model from the Registry and Quantizing with CTranslate2\n",
+    "- LLMs are typically too large to run in full precision on even decent hardware.\n",
+    "- You can quantize the model to run it more efficiently with minimal loss in accuracy.\n",
+    "  - CTranslate2 is a great first pass at quantization, but it doesn't do \"smart\", data-aware quantization: it simply converts the weights to the requested lower precision (int8 in the code below).\n",
+    "  - Check out GPTQ and AutoGPTQ for SOTA quantization at scale"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Pull the model from the registry\n",
+    "\n",
+    "wandb.init(project=project_name, entity=entity, job_type=\"ctranslate2\")\n",
+    "best_model = wandb.use_artifact(f'{entity}/model-registry/{registered_model_name}:latest')\n",
+    "best_model.download(root=f'model-registry/{registered_model_name}:latest')\n",
+    "wandb.finish()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from peft import PeftModel, PeftConfig\n",
+    "\n",
+    "def convert_qlora2ct2(adapter_path=f'model-registry/{registered_model_name}:latest',\n",
+    "                      full_model_path=\"opt125m-finetuned\",\n",
+    "                      offload_path=\"opt125m-offload\",\n",
+    "                      ct2_path=\"opt125m-finetuned-ct2\",\n",
+    "                      quantization=\"int8\"):\n",
+    "\n",
+    "    peft_model_id = adapter_path\n",
+    "    peftconfig = PeftConfig.from_pretrained(peft_model_id)\n",
+    "\n",
+    "    model = AutoModelForCausalLM.from_pretrained(\n",
+    "        \"facebook/opt-125m\",\n",
+    "        offload_folder=offload_path,\n",
+    "        device_map='auto',\n",
+    "    )\n",
+    "\n",
+    "    tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-125m\")\n",
+    "\n",
+    "    model = PeftModel.from_pretrained(model, peft_model_id)\n",
+    "\n",
+    "    print(\"PEFT model loaded\")\n",
+    "\n",
+    "    merged_model = model.merge_and_unload()\n",
+    "\n",
+    "    merged_model.save_pretrained(full_model_path)\n",
+    "    tokenizer.save_pretrained(full_model_path)\n",
+    "\n",
+    "    if not quantization:\n",
+    "        os.system(f\"ct2-transformers-converter --model {full_model_path} --output_dir {ct2_path} --force\")\n",
+    "    else:\n",
+    "        os.system(f\"ct2-transformers-converter --model {full_model_path} --output_dir {ct2_path} --quantization {quantization} --force\")\n",
+    "    print(\"Converted successfully\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "convert_qlora2ct2(adapter_path=f'model-registry/{registered_model_name}:latest')"
+   ]
+  },
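+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Optional: Version the Quantized Model as an Artifact\n",
+    "- A minimal sketch, assuming you also want the converted CTranslate2 directory tracked in W&B alongside the adapter checkpoints; the artifact name and `type=\"model\"` below are just example choices."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sketch: log the CTranslate2 output directory as a model artifact.\n",
+    "# The name \"opt125m-finetuned-ct2\" simply mirrors the default ct2_path used above.\n",
+    "run = wandb.init(project=project_name, entity=entity, job_type=\"log_quantized_model\")\n",
+    "ct2_artifact = wandb.Artifact(\"opt125m-finetuned-ct2\", type=\"model\")\n",
+    "ct2_artifact.add_dir(\"opt125m-finetuned-ct2\")\n",
+    "run.log_artifact(ct2_artifact)\n",
+    "run.finish()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Run Inference Using 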
Quantized CTranslate2 Model\n", + "- Record the results in a W&B Table!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ctranslate2\n", + "\n", + "\n", + "run = wandb.init(project=project_name, entity=entity, job_type=\"inference\")\n", + "generator = ctranslate2.Generator(\"opt125m-finetuned-ct2\")\n", + "\n", + "prompts = [\"Hey, are you conscious? Can you talk to me?\",\n", + " \"What is machine learning?\",\n", + " \"What is W&B?\"]\n", + "\n", + "\n", + "wandb_table = wandb.Table(columns=['prompt', 'completion'])\n", + "for prompt in prompts:\n", + " start_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))\n", + " results = generator.generate_batch([start_tokens], max_length=30)\n", + " output = tokenizer.decode(results[0].sequences_ids[0])\n", + " wandb_table.add_data(prompt, output)\n", + "\n", + "wandb.log({\"inference_table\": wandb_table})\n", + "wandb.finish()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "include_colab_link": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/colabs/prompts/WandB_LLM_QA_bot.ipynb b/colabs/prompts/WandB_LLM_QA_bot.ipynb new file mode 100644 index 00000000..6cf57f5b --- /dev/null +++ b/colabs/prompts/WandB_LLM_QA_bot.ipynb @@ -0,0 +1,621 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Building an LLM App for Document Retrieval / Extraction\n", + "\n", + "This tutorial runs through [this report](https://wandb.ai/gladiator/gradient_dissent_qabot/reports/Building-a-Q-A-Bot-for-Weights-Biases-Gradient-Dissent-Podcast--Vmlldzo0MTcyMDQz) on how to build a basic LLM App for retrieval-augmented question-answering.\n", + "- Track datasets and embeddings as artifacts\n", + "- Track prompts and chain executions\n", + "- Log token counts and cost" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qqq wandb langchain pytube tiktoken openai youtube-transcript-api chromadb" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set up OpenAI API Key" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from getpass import getpass\n", + "import os\n", + "\n", + "if os.getenv(\"OPENAI_API_KEY\") is None:\n", + " if any(['VSCODE' in x for x in os.environ.keys()]):\n", + " print('Please enter password in the VS Code prompt at the top of your VS Code window!')\n", + " os.environ[\"OPENAI_API_KEY\"] = getpass(\"Paste your OpenAI key from: https://platform.openai.com/account/api-keys\\n\")\n", + "\n", + "assert os.getenv(\"OPENAI_API_KEY\", \"\").startswith(\"sk-\"), \"This doesn't look like a valid OpenAI API key\"\n", + "print(\"OpenAI API key configured\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set up config and environment variables\n", + "- NOTE: set the `entity` to your username or team name\n", + "- Set wandb [environment variables](https://docs.wandb.ai/guides/track/environment-variables) to change behavior of logging\n", + "- `ENTITY` - username or team where your projects live\n", + "- `PROJECT` - project where your runs will live\n", + "- `LANGCHAIN_WANDB_TRACING` - automatically logs 
langchain traces, inputs and outputs as part of runs in Weights and Biases" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dataclasses import dataclass\n", + "from pathlib import Path\n", + "import os\n", + "\n", + "project_name = \"gradient-dissent-qabot\" #@param\n", + "entity = \"wandb\" #@param\n", + "TOTAL_EPISODES = 5\n", + "\n", + "playlist_url = \"https://www.youtube.com/playlist?list=PLD80i8An1OEEb1jP0sjEyiLG8ULRXFob_\"\n", + "root_data_dir = Path(\"/contents/data\")\n", + "root_artifact_dir = Path(\"downloaded_artifacts\")\n", + "yt_podcast_data_artifact = f\"{entity}/{project_name}/yt_podcast_transcript:latest\"\n", + "summarized_data_artifact = f\"{entity}/{project_name}/summarized_podcasts:latest\"\n", + "summarized_que_data_artifact = f\"{entity}/{project_name}/summarized_que_podcasts:latest\"\n", + "transcript_embeddings_artifact = f\"{entity}/{project_name}/transcript_embeddings:latest\"\n", + "\n", + "os.makedirs(\"/contents/data\", exist_ok=True)\n", + "os.environ[\"LANGCHAIN_WANDB_TRACING\"] = \"true\"\n", + "os.environ['WANDB_PROJECT'] = project_name\n", + "os.environ['WANDB_ENTITY'] = entity" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Log in to W&B\n", + "- You can explicitly login using `wandb login` or `wandb.login()` (See below)\n", + "- Alternatively you can set environment variables. There are several env variables which you can set to change the behavior of W&B logging. The most important are:\n", + " - `WANDB_API_KEY` - find this in your \"Settings\" section under your profile\n", + " - `WANDB_BASE_URL` - this is the url of the W&B server (You only need this if you are using a private instance)\n", + "- Find your API Token in \"Profile\" -> \"Setttings\" in the W&B App\n", + "\n", + "![api_token](https://drive.google.com/uc?export=view&id=1Xn7hnn0rfPu_EW0A_-32oCXqDmpA0-kx)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n", + "wandb.login()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import pandas as pd\n", + "import wandb\n", + "from langchain.document_loaders import YoutubeLoader\n", + "from pytube import Playlist, YouTube\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "def retry_access_yt_object(url, max_retries=5, interval_secs=5):\n", + " \"\"\"\n", + " Retries creating a YouTube object with the given URL and accessing its title several times\n", + " with a given interval in seconds, until it succeeds or the maximum number of attempts is reached.\n", + " If the object still cannot be created or the title cannot be accessed after the maximum number\n", + " of attempts, the last exception is raised.\n", + " \"\"\"\n", + " last_exception = None\n", + " for i in range(max_retries):\n", + " try:\n", + " yt = YouTube(url)\n", + " title = yt.title # Access the title of the YouTube object.\n", + " return yt # Return the YouTube object if successful.\n", + " except Exception as err:\n", + " last_exception = err # Keep track of the last exception raised.\n", + " print(\n", + " f\"Failed to create YouTube object or access title. Retrying... 
({i+1}/{max_retries})\"\n",
+    "            )\n",
+    "            time.sleep(interval_secs)  # Wait for the specified interval before retrying.\n",
+    "\n",
+    "    # If the YouTube object still cannot be created or the title cannot be accessed after the maximum number of attempts, raise the last exception.\n",
+    "    raise last_exception"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Log Data Snapshots as Artifacts\n",
+    "\n",
+    "W&B is very unopinionated with regard to how you track your experiments. We could log data in any number of ways:\n",
+    "* Log one artifact that represents all of the data: training, validation, and test\n",
+    "* Log several artifacts, one for each of the training, validation, and test data loaders\n",
+    "\n",
+    "It is a matter of what best suits your needs, workflows, and expectations.\n",
+    "\n",
+    "### Anatomy of an artifact\n",
+    "\n",
+    "The `Artifact` class corresponds to an entry in the W&B Artifact registry. An artifact has:\n",
+    "* a name\n",
+    "* a type\n",
+    "* metadata\n",
+    "* a description\n",
+    "* files, a directory of files, or references\n",
+    "\n",
+    "Example usage:\n",
+    "```\n",
+    "run = wandb.init(project = \"my-project\")\n",
+    "artifact = wandb.Artifact(name = \"my_artifact\", type = \"data\")\n",
+    "artifact.add_file(\"/path/to/my/file.txt\")\n",
+    "run.log_artifact(artifact)\n",
+    "run.finish()\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run = wandb.init(project=project_name, entity=entity, job_type=\"dataset\")\n",
+    "\n",
+    "playlist = Playlist(playlist_url)\n",
+    "playlist_video_urls = playlist.video_urls[0:TOTAL_EPISODES]\n",
+    "\n",
+    "print(f\"There are a total of {len(playlist_video_urls)} videos in the playlist.\")\n",
+    "\n",
+    "video_data = []\n",
+    "for video in tqdm(playlist_video_urls, total=len(playlist_video_urls)):\n",
+    "    try:\n",
+    "        curr_video_data = {}\n",
+    "        yt = retry_access_yt_object(video, max_retries=25, interval_secs=2)\n",
+    "        curr_video_data[\"title\"] = yt.title\n",
+    "        curr_video_data[\"url\"] = video\n",
+    "        curr_video_data[\"duration\"] = yt.length\n",
+    "        curr_video_data[\"publish_date\"] = yt.publish_date.strftime(\"%Y-%m-%d\")\n",
+    "        loader = YoutubeLoader.from_youtube_url(video)\n",
+    "        transcript = loader.load()[0].page_content\n",
+    "        transcript = \" \".join(transcript.split())\n",
+    "        curr_video_data[\"transcript\"] = transcript\n",
+    "        curr_video_data[\"total_words\"] = len(transcript.split())\n",
+    "        video_data.append(curr_video_data)\n",
+    "    except Exception as inst:\n",
+    "        print(type(inst))  # the exception type\n",
+    "        print(inst.args)   # arguments stored in .args\n",
+    "        print(inst)\n",
+    "        print(f\"Failed to scrape {video}\")\n",
+    "\n",
+    "print(f\"Total podcast episodes scraped: {len(video_data)}\")\n",
+    "\n",
+    "# save the scraped data to a csv file\n",
+    "df = pd.DataFrame(video_data)\n",
+    "data_path = root_data_dir / \"yt_podcast_transcript.csv\"\n",
+    "df.to_csv(data_path, index=False)\n",
+    "\n",
+    "# upload the scraped data to wandb\n",
+    "artifact = wandb.Artifact(\"yt_podcast_transcript\", type=\"dataset\")\n",
+    "artifact.add_file(data_path)\n",
+    "run.log_artifact(artifact)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Log a wandb Table to interact with your data\n",
+    "- Here we log the dataframe of metadata about the YouTube transcripts (urls, lengths, transcripts)\n",
+    "- This allows us to interrogate the original data (filtering, 
grouping, etc.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create wandb table\n", + "table = wandb.Table(dataframe=df)\n", + "run.log({\"yt_podcast_transcript\": table})\n", + "run.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summarize YouTube Transcripts\n", + "- Here we summarize the transcripts in chunks, summarizing each chunk and then summarizing the summaries using the LangChain `load_summarize_chain`\n", + "- We can do this in parallel since each chunk of a transcript can be summarized independently so we employ `map_reduce`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "from langchain.callbacks import get_openai_callback\n", + "from langchain.chains.summarize import load_summarize_chain\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.document_loaders import DataFrameLoader\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.text_splitter import TokenTextSplitter\n", + "from tqdm import tqdm\n", + "\n", + "import wandb\n", + "\n", + "\n", + "def get_data(artifact_name: str, total_episodes: int = None):\n", + " podcast_artifact = wandb.use_artifact(artifact_name)\n", + " podcast_artifact_dir = podcast_artifact.download(root_artifact_dir)\n", + " filename = artifact_name.split(\":\")[0].split(\"/\")[-1]\n", + " df = pd.read_csv(os.path.join(podcast_artifact_dir, f\"{filename}.csv\"))\n", + " if total_episodes is not None:\n", + " df = df.iloc[:total_episodes]\n", + " return df\n", + "\n", + "\n", + "def summarize_episode(episode_df: pd.DataFrame):\n", + " # load docs into langchain format\n", + " loader = DataFrameLoader(episode_df, page_content_column=\"transcript\")\n", + " data = loader.load()\n", + "\n", + " # split the documents\n", + " text_splitter = TokenTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0)\n", + " docs = text_splitter.split_documents(data)\n", + " print(f\"Number of documents for podcast {data[0].metadata['title']}: {len(docs)}\")\n", + "\n", + " # initialize LLM\n", + " llm = ChatOpenAI(model_name=\"gpt-3.5-turbo\", temperature=0)\n", + "\n", + " # define map prompt\n", + " map_prompt = \"\"\"Write a concise summary of the following short transcript from a podcast.\n", + " Don't add your opinions or interpretations.\n", + "\n", + " {text}\n", + "\n", + " CONCISE SUMMARY:\"\"\"\n", + "\n", + " # define combine prompt\n", + " combine_prompt = \"\"\"You have been provided with summaries of chunks of transcripts from a podcast.\n", + " Your task is to merge these intermediate summaries to create a brief and comprehensive summary of the entire podcast.\n", + " The summary should encompass all the crucial points of the podcast.\n", + " Ensure that the summary is atleast 2 paragraph long and effectively captures the essence of the podcast.\n", + " {text}\n", + "\n", + " SUMMARY:\"\"\"\n", + "\n", + " map_prompt_template = PromptTemplate(template=map_prompt, input_variables=[\"text\"])\n", + " combine_prompt_template = PromptTemplate(template=combine_prompt, input_variables=[\"text\"])\n", + "\n", + " # initialize the summarizer chain\n", + " chain = load_summarize_chain(\n", + " llm,\n", + " chain_type=\"map_reduce\",\n", + " return_intermediate_steps=True,\n", + " map_prompt=map_prompt_template,\n", + " combine_prompt=combine_prompt_template,\n", + " )\n", + "\n", + " summary 
= chain({\"input_documents\": docs})\n", + " return summary" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Execute Summary Chain and log results\n", + "- You can instantiate a `WandbTracer` and pass in additional config about this LangChain run.\n", + "- Log the outputs of the chain like tokens used, cost, etc.\n", + "- Log the resulting summaries as artifacts" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.callbacks.tracers import WandbTracer\n", + "\n", + "tracer = WandbTracer(run_args = {\"job_type\": \"summarize\"})\n", + "\n", + "# get scraped data\n", + "df = get_data(artifact_name=yt_podcast_data_artifact, total_episodes=TOTAL_EPISODES)\n", + "\n", + "summaries = []\n", + "with get_openai_callback() as cb:\n", + " for episode in tqdm(df.iterrows(), total=len(df), desc=\"Summarizing episodes\"):\n", + " episode_data = episode[1].to_frame().T\n", + "\n", + " summary = summarize_episode(episode_data)\n", + " summaries.append(summary[\"output_text\"])\n", + "\n", + " print(\"*\" * 25)\n", + " print(cb)\n", + " print(\"*\" * 25)\n", + "\n", + " wandb.log(\n", + " {\n", + " \"total_prompt_tokens\": cb.prompt_tokens,\n", + " \"total_completion_tokens\": cb.completion_tokens,\n", + " \"total_tokens\": cb.total_tokens,\n", + " \"total_cost\": cb.total_cost,\n", + " }\n", + " )\n", + "\n", + "df[\"summary\"] = summaries\n", + "\n", + "# save data\n", + "path_to_save = os.path.join(root_data_dir, \"summarized_podcasts.csv\")\n", + "df.to_csv(path_to_save, index=False)\n", + "\n", + "# log to wandb artifact\n", + "artifact = wandb.Artifact(\"summarized_podcasts\", type=\"dataset\")\n", + "artifact.add_file(path_to_save)\n", + "wandb.log_artifact(artifact)\n", + "\n", + "# create wandb table\n", + "table = wandb.Table(dataframe=df)\n", + "wandb.log({\"summarized_podcasts\": table})\n", + "\n", + "tracer.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Embed the contents of the YouTube transcripts\n", + "- Here we use OpenAI embeddings and [ChromaDB](https://www.trychroma.com/) to embed the summaries to make them queriable via vector similarity search when we ask contextual questions to the LLM\n", + "- Use `wandb.log` and artifacts to log the resulting ChromaDB serialized embeddings." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from dataclasses import asdict\n", + "\n", + "import pandas as pd\n", + "from langchain.callbacks import get_openai_callback\n", + "from langchain.document_loaders import DataFrameLoader\n", + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.text_splitter import TokenTextSplitter\n", + "from langchain.vectorstores import Chroma\n", + "from tqdm import tqdm\n", + "from wandb.integration.langchain import WandbTracer\n", + "\n", + "import wandb\n", + "\n", + "\n", + "def get_data(artifact_name: str, total_episodes=None):\n", + " podcast_artifact = wandb.use_artifact(artifact_name, type=\"dataset\")\n", + " podcast_artifact_dir = podcast_artifact.download(root_artifact_dir)\n", + " filename = artifact_name.split(\":\")[0].split(\"/\")[-1]\n", + " df = pd.read_csv(os.path.join(podcast_artifact_dir, f\"{filename}.csv\"))\n", + " if total_episodes is not None:\n", + " df = df.iloc[:total_episodes]\n", + " return df\n", + "\n", + "\n", + "def create_embeddings(episode_df: pd.DataFrame, index: int):\n", + " # load docs into langchain format\n", + " loader = DataFrameLoader(episode_df, page_content_column=\"transcript\")\n", + " data = loader.load()\n", + "\n", + " # split the documents\n", + " text_splitter = TokenTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0)\n", + " docs = text_splitter.split_documents(data)\n", + "\n", + " title = data[0].metadata[\"title\"]\n", + " print(f\"Number of documents for podcast {title}: {len(docs)}\")\n", + "\n", + " # initialize embedding engine\n", + " embeddings = OpenAIEmbeddings()\n", + "\n", + " db = Chroma.from_documents(\n", + " docs,\n", + " embeddings,\n", + " persist_directory=os.path.join(root_data_dir / \"chromadb\", str(index)),\n", + " )\n", + " db.persist()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tracer = WandbTracer(run_args = {\"job_type\": \"embed_transcripts\"})\n", + "\n", + "# get data\n", + "df = get_data(artifact_name=summarized_data_artifact, total_episodes=TOTAL_EPISODES)\n", + "\n", + "# create embeddings\n", + "with get_openai_callback() as cb:\n", + " for episode in tqdm(df.iterrows(), total=len(df), desc=\"Embedding transcripts\"):\n", + " episode_data = episode[1].to_frame().T\n", + "\n", + " create_embeddings(episode_data, index=episode[0])\n", + "\n", + " print(\"*\" * 25)\n", + " print(cb)\n", + " print(\"*\" * 25)\n", + "\n", + " wandb.log(\n", + " {\n", + " \"total_prompt_tokens\": cb.prompt_tokens,\n", + " \"total_completion_tokens\": cb.completion_tokens,\n", + " \"total_tokens\": cb.total_tokens,\n", + " \"total_cost\": cb.total_cost,\n", + " }\n", + " )\n", + "\n", + "# log embeddings to wandb artifact\n", + "artifact = wandb.Artifact(\"transcript_embeddings\", type=\"dataset\")\n", + "artifact.add_dir(root_data_dir / \"chromadb\")\n", + "wandb.log_artifact(artifact)\n", + "\n", + "tracer.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Ask Questions Against your Summarized Documents\n", + "\n", + "Finally we tie everything together:\n", + "1. We can pull down our ChromaDB embeddings from W&B\n", + "2. Pass them along with a prompt template for QA to the `RetrievalQA` chain and start asking questions!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains import RetrievalQA\n", + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "\n", + "def get_answer(podcast: str, question: str):\n", + " index = df[df[\"title\"] == podcast].index[0]\n", + " db_dir = os.path.join(chromadb_dir, str(index))\n", + " embeddings = OpenAIEmbeddings()\n", + " db = Chroma(persist_directory=db_dir, embedding_function=embeddings)\n", + "\n", + " prompt_template = \"\"\"Use the following pieces of context to answer the question.\n", + " If you don't know the answer, just say that you don't know, don't try to make up an answer.\n", + " Don't add your opinions or interpretations. Ensure that you complete the answer.\n", + " If the question is not relevant to the context, just say that it is not relevant.\n", + "\n", + " CONTEXT:\n", + " {context}\n", + "\n", + " QUESTION: {question}\n", + "\n", + " ANSWER:\"\"\"\n", + "\n", + " prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n", + "\n", + " retriever = db.as_retriever()\n", + " retriever.search_kwargs[\"k\"] = 2\n", + "\n", + " qa = RetrievalQA.from_chain_type(\n", + " llm=ChatOpenAI(temperature=0),\n", + " chain_type=\"stuff\",\n", + " retriever=retriever,\n", + " chain_type_kwargs={\"prompt\": prompt},\n", + " return_source_documents=True,\n", + " )\n", + "\n", + " with get_openai_callback() as cb:\n", + " result = qa({\"query\": question})\n", + " print(cb)\n", + "\n", + " answer = result[\"result\"]\n", + " return answer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# download and read data\n", + "api = wandb.Api()\n", + "artifact_df = api.artifact(summarized_data_artifact)\n", + "artifact_df.download(root_data_dir)\n", + "\n", + "artifact_embeddings = api.artifact(transcript_embeddings_artifact)\n", + "chromadb_dir = artifact_embeddings.download(root_data_dir / \"chromadb\")\n", + "\n", + "df_path = root_data_dir / \"summarized_podcasts.csv\"\n", + "df = pd.read_csv(df_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df[\"title\"].tolist()[0:TOTAL_EPISODES]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tracer = WandbTracer(run_args = {\"job_type\": \"retriealQA\"})\n", + "\n", + "answer = get_answer('Enabling LLM-Powered Applications with Harrison Chase of LangChain', \"What did Harrison Chase say?\")\n", + "print(answer)\n", + "\n", + "tracer.finish()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "include_colab_link": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}