diff --git a/colabs/huggingface/LLM_Finetuning_Notebook.ipynb b/colabs/huggingface/LLM_Finetuning_Notebook.ipynb
new file mode 100644
index 00000000..942e60bc
--- /dev/null
+++ b/colabs/huggingface/LLM_Finetuning_Notebook.ipynb
@@ -0,0 +1,379 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# LLM Finetuning with HuggingFace and Weights and Biases\n",
+ "\n",
+ "- Fine-tune a lightweight LLM (OPT-125M) with LoRA and 8-bit quantization using Launch\n",
+ "- Checkpoint the LoRA adapter weights as artifacts\n",
+ "- Link the best checkpoint in Model Registry\n",
+ "- Run inference on a quantized model\n",
+ "\n",
+ "The same workflow and principles from this notebook can be applied to fine-tuning some of the stronger OSS LLMs (e.g. Llama2)"
+ ]
+ },
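+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A rough sketch of swapping in a larger base model (illustrative only: `meta-llama/Llama-2-7b-hf` is a gated checkpoint that requires access approval on the Hugging Face Hub, and it needs far more GPU memory than OPT-125M):\n",
+ "\n",
+ "```python\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+ "from peft import LoraConfig, get_peft_model\n",
+ "\n",
+ "# Same recipe as below, only the base model changes\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ "    \"meta-llama/Llama-2-7b-hf\",\n",
+ "    load_in_8bit=True,\n",
+ "    device_map=\"auto\",\n",
+ ")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-2-7b-hf\")\n",
+ "\n",
+ "config = LoraConfig(\n",
+ "    r=16,\n",
+ "    lora_alpha=32,\n",
+ "    target_modules=[\"q_proj\", \"v_proj\"],  # Llama attention projections use the same names\n",
+ "    lora_dropout=0.05,\n",
+ "    bias=\"none\",\n",
+ "    task_type=\"CAUSAL_LM\",\n",
+ ")\n",
+ "model = get_peft_model(model, config)\n",
+ "```"
+ ]
+ },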
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Fine-tune large models using 🤗 `peft` adapters, `transformers` & `bitsandbytes`\n",
+ "\n",
+ "In this tutorial we will cover how we can fine-tune large language models using the very recent `peft` library and `bitsandbytes` for loading large models in 8-bit.\n",
+ "The fine-tuning method will rely on a recent method called \"Low Rank Adapters\" (LoRA), instead of fine-tuning the entire model you just have to fine-tune these adapters and load them properly inside the model.\n",
+ "After fine-tuning the model you can also share your adapters on the 🤗 Hub and load them very easily. Let's get started!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Install requirements\n",
+ "\n",
+ "First, run the cells below to install the requirements:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -q bitsandbytes datasets accelerate loralib\n",
+ "!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git\n",
+ "!pip install -q wandb\n",
+ "!pip install -q ctranslate2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Model Loading\n",
+ "\n",
+ "- Here we leverage 8-bit quantization to reduce the memory footprint of the model during training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import bitsandbytes as bnb\n",
+ "from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM\n",
+ "\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ " \"facebook/opt-125m\",\n",
+ " load_in_8bit=True,\n",
+ " device_map='auto',\n",
+ ")\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-125m\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Post-processing on the model\n",
+ "\n",
+ "Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for param in model.parameters():\n",
+ " param.requires_grad = False # freeze the model - train adapters later\n",
+ " if param.ndim == 1:\n",
+ " # cast the small parameters (e.g. layernorm) to fp32 for stability\n",
+ " param.data = param.data.to(torch.float32)\n",
+ "\n",
+ "model.gradient_checkpointing_enable() # reduce number of stored activations\n",
+ "model.enable_input_require_grads()\n",
+ "\n",
+ "class CastOutputToFloat(nn.Sequential):\n",
+ " def forward(self, x): return super().forward(x).to(torch.float32)\n",
+ "model.lm_head = CastOutputToFloat(model.lm_head)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Apply LoRA\n",
+ "\n",
+ "Here comes the magic with `peft`! Let's load a `PeftModel` and specify that we are going to use low-rank adapters (LoRA) using `get_peft_model` utility function from `peft`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def print_trainable_parameters(model):\n",
+ " \"\"\"\n",
+ " Prints the number of trainable parameters in the model.\n",
+ " \"\"\"\n",
+ " trainable_params = 0\n",
+ " all_param = 0\n",
+ " for _, param in model.named_parameters():\n",
+ " all_param += param.numel()\n",
+ " if param.requires_grad:\n",
+ " trainable_params += param.numel()\n",
+ " print(\n",
+ " f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from peft import LoraConfig, get_peft_model\n",
+ "\n",
+ "config = LoraConfig(\n",
+ " r=16,\n",
+ " lora_alpha=32,\n",
+ " target_modules=[\"q_proj\", \"v_proj\"],\n",
+ " lora_dropout=0.05,\n",
+ " bias=\"none\",\n",
+ " task_type=\"CAUSAL_LM\"\n",
+ ")\n",
+ "\n",
+ "model = get_peft_model(model, config)\n",
+ "print_trainable_parameters(model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Training\n",
+ "- [W&B HuggingFace integration](https://docs.wandb.ai/guides/integrations/huggingface) automatically tracks important metrics during the course of training\n",
+ "- Also track the HF checkpoints as artifacts and register them in the model registry!\n",
+ "- Change the number of steps to 200+ for real results!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import transformers\n",
+ "from datasets import load_dataset\n",
+ "import wandb\n",
+ "\n",
+ "project_name = \"llm-finetuning\" #@param\n",
+ "entity = \"wandb\" #@param\n",
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\"\n",
+ "\n",
+ "wandb.init(project=project_name,\n",
+ " entity=entity,\n",
+ " job_type=\"training\")\n",
+ "\n",
+ "data = load_dataset(\"Abirate/english_quotes\")\n",
+ "data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)\n",
+ "\n",
+ "trainer = transformers.Trainer(\n",
+ " model=model,\n",
+ " train_dataset=data['train'],\n",
+ " args=transformers.TrainingArguments(\n",
+ " per_device_train_batch_size=4,\n",
+ " gradient_accumulation_steps=4,\n",
+ " report_to=\"wandb\",\n",
+ " warmup_steps=5,\n",
+ " max_steps=25,\n",
+ " learning_rate=2e-4,\n",
+ " fp16=True,\n",
+ " logging_steps=1,\n",
+ " save_steps=5,\n",
+ " output_dir='outputs'\n",
+ " ),\n",
+ " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
+ ")\n",
+ "model.config.use_cache = False # silence the warnings. Please re-enable for inference!\n",
+ "trainer.train()\n",
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Adding Model Weights to W&B Model Registry\n",
+ "- Here we get our best checkpoint from the finetuning run and register it as our best model"
+ ]
+ },
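+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you don't want to copy the run ID by hand, one option (a sketch, assuming the most recently created run in the project is your training run, and using the `entity` and `project_name` defined above) is to look it up with the public API:\n",
+ "\n",
+ "```python\n",
+ "api = wandb.Api()\n",
+ "# order=\"-created_at\" returns runs newest-first\n",
+ "last_run_id = api.runs(f\"{entity}/{project_name}\", order=\"-created_at\")[0].id\n",
+ "print(last_run_id)\n",
+ "```"
+ ]
+ },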
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "last_run_id = \"zz0lxkc8\" #@param\n",
+ "wandb.init(project=project_name, entity=entity, job_type=\"registering_best_model\")\n",
+ "best_model = wandb.use_artifact(f'{entity}/{project_name}/checkpoint-{last_run_id}:latest')\n",
+ "registered_model_name = \"OPT-125M-english\" #@param {type: \"string\"}\n",
+ "wandb.run.link_artifact(best_model, f'{entity}/model-registry/{registered_model_name}', aliases=['staging'])\n",
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Consuming Model From Registry and Quantizing using ctranslate2\n",
+ "- LLMs are typically too large to run in full-precision on even decent hardware.\n",
+ "- You can quantize the model to run it more efficiently with minimal loss in accuracy.\n",
+ " - CTranslate2 is a great first pass at quantization but doesn't do \"smart\" quantization. It just converts all weights to half precision.\n",
+ " - Checkout out GPTQ and AutoGPTQ for SOTA quantization at scale"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Pull model from the registry\n",
+ "\n",
+ "wandb.init(project=project_name, entity=entity, job_type=\"ctranslate2\")\n",
+ "best_model = wandb.use_artifact(f'{entity}/model-registry/{registered_model_name}:latest')\n",
+ "best_model.download(root=f'model-registry/{registered_model_name}:latest')\n",
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from peft import PeftModel, PeftConfig\n",
+ "\n",
+ "def convert_qlora2ct2(adapter_path=f'model-registry/{registered_model_name}:latest',\n",
+ " full_model_path=\"opt125m-finetuned\",\n",
+ " offload_path=\"opt125m-offload\",\n",
+ " ct2_path=\"opt125m-finetuned-ct2\",\n",
+ " quantization=\"int8\"):\n",
+ "\n",
+ "\n",
+ " peft_model_id = adapter_path\n",
+ " peftconfig = PeftConfig.from_pretrained(peft_model_id)\n",
+ "\n",
+ " model = AutoModelForCausalLM.from_pretrained(\n",
+ " \"facebook/opt-125m\",\n",
+ " offload_folder = offload_path,\n",
+ " device_map='auto',\n",
+ " )\n",
+ "\n",
+ " tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-125m\")\n",
+ "\n",
+ " model = PeftModel.from_pretrained(model, peft_model_id)\n",
+ "\n",
+ " print(\"Peft model loaded\")\n",
+ "\n",
+ " merged_model = model.merge_and_unload()\n",
+ "\n",
+ " merged_model.save_pretrained(full_model_path)\n",
+ " tokenizer.save_pretrained(full_model_path)\n",
+ "\n",
+ " if quantization == False:\n",
+ " os.system(f\"ct2-transformers-converter --model {full_model_path} --output_dir {ct2_path} --force\")\n",
+ " else:\n",
+ " os.system(f\"ct2-transformers-converter --model {full_model_path} --output_dir {ct2_path} --quantization {quantization} --force\")\n",
+ " print(\"Convert successfully\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "convert_qlora2ct2(adapter_path=f'model-registry/{registered_model_name}:latest')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run Inference Using Quantized CTranslate2 Model\n",
+ "- Record the results in a W&B Table!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import ctranslate2\n",
+ "\n",
+ "\n",
+ "run = wandb.init(project=project_name, entity=entity, job_type=\"inference\")\n",
+ "generator = ctranslate2.Generator(\"opt125m-finetuned-ct2\")\n",
+ "\n",
+ "prompts = [\"Hey, are you conscious? Can you talk to me?\",\n",
+ " \"What is machine learning?\",\n",
+ " \"What is W&B?\"]\n",
+ "\n",
+ "\n",
+ "wandb_table = wandb.Table(columns=['prompt', 'completion'])\n",
+ "for prompt in prompts:\n",
+ " start_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))\n",
+ " results = generator.generate_batch([start_tokens], max_length=30)\n",
+ " output = tokenizer.decode(results[0].sequences_ids[0])\n",
+ " wandb_table.add_data(prompt, output)\n",
+ "\n",
+ "wandb.log({\"inference_table\": wandb_table})\n",
+ "wandb.finish()"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/colabs/prompts/WandB_LLM_QA_bot.ipynb b/colabs/prompts/WandB_LLM_QA_bot.ipynb
new file mode 100644
index 00000000..6cf57f5b
--- /dev/null
+++ b/colabs/prompts/WandB_LLM_QA_bot.ipynb
@@ -0,0 +1,621 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Building an LLM App for Document Retrieval / Extraction\n",
+ "\n",
+ "This tutorial runs through [this report](https://wandb.ai/gladiator/gradient_dissent_qabot/reports/Building-a-Q-A-Bot-for-Weights-Biases-Gradient-Dissent-Podcast--Vmlldzo0MTcyMDQz) on how to build a basic LLM App for retrieval-augmented question-answering.\n",
+ "- Track datasets and embeddings as artifacts\n",
+ "- Track prompts and chain executions\n",
+ "- Log token counts and cost"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -qqq wandb langchain pytube tiktoken openai youtube-transcript-api chromadb"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Set up OpenAI API Key"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from getpass import getpass\n",
+ "import os\n",
+ "\n",
+ "if os.getenv(\"OPENAI_API_KEY\") is None:\n",
+ " if any(['VSCODE' in x for x in os.environ.keys()]):\n",
+ " print('Please enter password in the VS Code prompt at the top of your VS Code window!')\n",
+ " os.environ[\"OPENAI_API_KEY\"] = getpass(\"Paste your OpenAI key from: https://platform.openai.com/account/api-keys\\n\")\n",
+ "\n",
+ "assert os.getenv(\"OPENAI_API_KEY\", \"\").startswith(\"sk-\"), \"This doesn't look like a valid OpenAI API key\"\n",
+ "print(\"OpenAI API key configured\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Set up config and environment variables\n",
+ "- NOTE: set the `entity` to your username or team name\n",
+ "- Set wandb [environment variables](https://docs.wandb.ai/guides/track/environment-variables) to change behavior of logging\n",
+ "- `ENTITY` - username or team where your projects live\n",
+ "- `PROJECT` - project where your runs will live\n",
+ "- `LANGCHAIN_WANDB_TRACING` - automatically logs langchain traces, inputs and outputs as part of runs in Weights and Biases"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dataclasses import dataclass\n",
+ "from pathlib import Path\n",
+ "import os\n",
+ "\n",
+ "project_name = \"gradient-dissent-qabot\" #@param\n",
+ "entity = \"wandb\" #@param\n",
+ "TOTAL_EPISODES = 5\n",
+ "\n",
+ "playlist_url = \"https://www.youtube.com/playlist?list=PLD80i8An1OEEb1jP0sjEyiLG8ULRXFob_\"\n",
+ "root_data_dir = Path(\"/contents/data\")\n",
+ "root_artifact_dir = Path(\"downloaded_artifacts\")\n",
+ "yt_podcast_data_artifact = f\"{entity}/{project_name}/yt_podcast_transcript:latest\"\n",
+ "summarized_data_artifact = f\"{entity}/{project_name}/summarized_podcasts:latest\"\n",
+ "summarized_que_data_artifact = f\"{entity}/{project_name}/summarized_que_podcasts:latest\"\n",
+ "transcript_embeddings_artifact = f\"{entity}/{project_name}/transcript_embeddings:latest\"\n",
+ "\n",
+ "os.makedirs(\"/contents/data\", exist_ok=True)\n",
+ "os.environ[\"LANGCHAIN_WANDB_TRACING\"] = \"true\"\n",
+ "os.environ['WANDB_PROJECT'] = project_name\n",
+ "os.environ['WANDB_ENTITY'] = entity"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Log in to W&B\n",
+ "- You can explicitly login using `wandb login` or `wandb.login()` (See below)\n",
+ "- Alternatively you can set environment variables. There are several env variables which you can set to change the behavior of W&B logging. The most important are:\n",
+ " - `WANDB_API_KEY` - find this in your \"Settings\" section under your profile\n",
+ " - `WANDB_BASE_URL` - this is the url of the W&B server (You only need this if you are using a private instance)\n",
+ "- Find your API Token in \"Profile\" -> \"Setttings\" in the W&B App\n",
+ "\n",
+ "![api_token](https://drive.google.com/uc?export=view&id=1Xn7hnn0rfPu_EW0A_-32oCXqDmpA0-kx)"
+ ]
+ },
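+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of the environment-variable alternative (the values below are placeholders, not real credentials):\n",
+ "\n",
+ "```python\n",
+ "import os\n",
+ "\n",
+ "os.environ[\"WANDB_API_KEY\"] = \"<your-api-key>\"  # from the Settings page of your profile\n",
+ "# os.environ[\"WANDB_BASE_URL\"] = \"https://<your-private-instance>\"  # only for private W&B servers\n",
+ "```"
+ ]
+ },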
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import wandb\n",
+ "wandb.login()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "import pandas as pd\n",
+ "import wandb\n",
+ "from langchain.document_loaders import YoutubeLoader\n",
+ "from pytube import Playlist, YouTube\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "\n",
+ "def retry_access_yt_object(url, max_retries=5, interval_secs=5):\n",
+ " \"\"\"\n",
+ " Retries creating a YouTube object with the given URL and accessing its title several times\n",
+ " with a given interval in seconds, until it succeeds or the maximum number of attempts is reached.\n",
+ " If the object still cannot be created or the title cannot be accessed after the maximum number\n",
+ " of attempts, the last exception is raised.\n",
+ " \"\"\"\n",
+ " last_exception = None\n",
+ " for i in range(max_retries):\n",
+ " try:\n",
+ " yt = YouTube(url)\n",
+ " title = yt.title # Access the title of the YouTube object.\n",
+ " return yt # Return the YouTube object if successful.\n",
+ " except Exception as err:\n",
+ " last_exception = err # Keep track of the last exception raised.\n",
+ " print(\n",
+ " f\"Failed to create YouTube object or access title. Retrying... ({i+1}/{max_retries})\"\n",
+ " )\n",
+ " time.sleep(interval_secs) # Wait for the specified interval before retrying.\n",
+ "\n",
+ " # If the YouTube object still cannot be created or the title cannot be accessed after the maximum number of attempts, raise the last exception.\n",
+ " raise last_exception"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Log Data Snapshots as Artifacts\n",
+ "\n",
+ "W&B is very unopinionated with regard to how you track your experiments. We could log data in any number of ways. \n",
+ "* Log one artifact which represents all the data - training, validation, and test data to one artifact\n",
+ "* Log several artifacts - one for each of the training, validation, and test data loaders. \n",
+ "\n",
+ "It is a matter of what best suites your needs and workflows and expectations. \n",
+ "\n",
+ "### Anatomy of an artifact\n",
+ "\n",
+ "The `Artifact` class will correspond to an entry in the W&B Artifact registry. The artifact has\n",
+ "* a name\n",
+ "* a type\n",
+ "* metadata\n",
+ "* description\n",
+ "* files, directory of files, or references\n",
+ "\n",
+ "Example usage\n",
+ "```\n",
+ "run = wandb.init(project = \"my-project\")\n",
+ "artifact = wandb.Artifact(name = \"my_artifact\", type = \"data\")\n",
+ "artifact.add_file(\"/path/to/my/file.txt\")\n",
+ "run.log_artifact(artifact)\n",
+ "run.finish()\n",
+ "```"
+ ]
+ },
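+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For the \"references\" case (not used in this notebook), an artifact can point at data that stays in external storage; W&B then tracks checksums and metadata rather than copying the bytes. A sketch with a hypothetical S3 bucket:\n",
+ "\n",
+ "```python\n",
+ "run = wandb.init(project = \"my-project\")\n",
+ "artifact = wandb.Artifact(name = \"my_reference_artifact\", type = \"data\")\n",
+ "artifact.add_reference(\"s3://my-bucket/datasets/train/\")  # hypothetical bucket path\n",
+ "run.log_artifact(artifact)\n",
+ "run.finish()\n",
+ "```"
+ ]
+ },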
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "run = wandb.init(project=project_name, entity=entity, job_type=\"dataset\")\n",
+ "\n",
+ "playlist = Playlist(playlist_url)\n",
+ "playlist_video_urls = playlist.video_urls[0:TOTAL_EPISODES]\n",
+ "\n",
+ "print(f\"There are total {len(playlist_video_urls)} videos in the playlist.\")\n",
+ "\n",
+ "video_data = []\n",
+ "for video in tqdm(playlist_video_urls, total=len(playlist_video_urls)):\n",
+ " try:\n",
+ " curr_video_data = {}\n",
+ " yt = retry_access_yt_object(video, max_retries=25, interval_secs=2)\n",
+ " curr_video_data[\"title\"] = yt.title\n",
+ " curr_video_data[\"url\"] = video\n",
+ " curr_video_data[\"duration\"] = yt.length\n",
+ " curr_video_data[\"publish_date\"] = yt.publish_date.strftime(\"%Y-%m-%d\")\n",
+ " loader = YoutubeLoader.from_youtube_url(video)\n",
+ " transcript = loader.load()[0].page_content\n",
+ " transcript = \" \".join(transcript.split())\n",
+ " curr_video_data[\"transcript\"] = transcript\n",
+ " curr_video_data[\"total_words\"] = len(transcript.split())\n",
+ " video_data.append(curr_video_data)\n",
+ " except Exception as inst:\n",
+ " print(type(inst)) # the exception type\n",
+ " print(inst.args) # arguments stored in .args\n",
+ " print(inst)\n",
+ " print(f\"Failed to scrape {video}\")\n",
+ "\n",
+ "print(f\"Total podcast episodes scraped: {len(video_data)}\")\n",
+ "\n",
+ "# save the scraped data to a csv file\n",
+ "df = pd.DataFrame(video_data)\n",
+ "data_path = root_data_dir / \"yt_podcast_transcript.csv\"\n",
+ "df.to_csv(data_path, index=False)\n",
+ "\n",
+ "# upload the scraped data to wandb\n",
+ "artifact = wandb.Artifact(\"yt_podcast_transcript\", type=\"dataset\")\n",
+ "artifact.add_file(data_path)\n",
+ "run.log_artifact(artifact)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Log a wandb Table to interact with your data\n",
+ "- Here we log the dataframe of metadata about the youtube transcripts (urls, length, transcripts)\n",
+ "- This allows us to interrogate the original data (filtering, grouping, etc.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create wandb table\n",
+ "table = wandb.Table(dataframe=df)\n",
+ "run.log({\"yt_podcast_transcript\": table})\n",
+ "run.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summarize YouTube Transcripts\n",
+ "- Here we summarize the transcripts in chunks, summarizing each chunk and then summarizing the summaries using the LangChain `load_summarize_chain`\n",
+ "- We can do this in parallel since each chunk of a transcript can be summarized independently so we employ `map_reduce`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import pandas as pd\n",
+ "from langchain.callbacks import get_openai_callback\n",
+ "from langchain.chains.summarize import load_summarize_chain\n",
+ "from langchain.chat_models import ChatOpenAI\n",
+ "from langchain.document_loaders import DataFrameLoader\n",
+ "from langchain.prompts import PromptTemplate\n",
+ "from langchain.text_splitter import TokenTextSplitter\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "import wandb\n",
+ "\n",
+ "\n",
+ "def get_data(artifact_name: str, total_episodes: int = None):\n",
+ " podcast_artifact = wandb.use_artifact(artifact_name)\n",
+ " podcast_artifact_dir = podcast_artifact.download(root_artifact_dir)\n",
+ " filename = artifact_name.split(\":\")[0].split(\"/\")[-1]\n",
+ " df = pd.read_csv(os.path.join(podcast_artifact_dir, f\"{filename}.csv\"))\n",
+ " if total_episodes is not None:\n",
+ " df = df.iloc[:total_episodes]\n",
+ " return df\n",
+ "\n",
+ "\n",
+ "def summarize_episode(episode_df: pd.DataFrame):\n",
+ " # load docs into langchain format\n",
+ " loader = DataFrameLoader(episode_df, page_content_column=\"transcript\")\n",
+ " data = loader.load()\n",
+ "\n",
+ " # split the documents\n",
+ " text_splitter = TokenTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0)\n",
+ " docs = text_splitter.split_documents(data)\n",
+ " print(f\"Number of documents for podcast {data[0].metadata['title']}: {len(docs)}\")\n",
+ "\n",
+ " # initialize LLM\n",
+ " llm = ChatOpenAI(model_name=\"gpt-3.5-turbo\", temperature=0)\n",
+ "\n",
+ " # define map prompt\n",
+ " map_prompt = \"\"\"Write a concise summary of the following short transcript from a podcast.\n",
+ " Don't add your opinions or interpretations.\n",
+ "\n",
+ " {text}\n",
+ "\n",
+ " CONCISE SUMMARY:\"\"\"\n",
+ "\n",
+ " # define combine prompt\n",
+ " combine_prompt = \"\"\"You have been provided with summaries of chunks of transcripts from a podcast.\n",
+ " Your task is to merge these intermediate summaries to create a brief and comprehensive summary of the entire podcast.\n",
+ " The summary should encompass all the crucial points of the podcast.\n",
+ " Ensure that the summary is atleast 2 paragraph long and effectively captures the essence of the podcast.\n",
+ " {text}\n",
+ "\n",
+ " SUMMARY:\"\"\"\n",
+ "\n",
+ " map_prompt_template = PromptTemplate(template=map_prompt, input_variables=[\"text\"])\n",
+ " combine_prompt_template = PromptTemplate(template=combine_prompt, input_variables=[\"text\"])\n",
+ "\n",
+ " # initialize the summarizer chain\n",
+ " chain = load_summarize_chain(\n",
+ " llm,\n",
+ " chain_type=\"map_reduce\",\n",
+ " return_intermediate_steps=True,\n",
+ " map_prompt=map_prompt_template,\n",
+ " combine_prompt=combine_prompt_template,\n",
+ " )\n",
+ "\n",
+ " summary = chain({\"input_documents\": docs})\n",
+ " return summary"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Execute Summary Chain and log results\n",
+ "- You can instantiate a `WandbTracer` and pass in additional config about this LangChain run.\n",
+ "- Log the outputs of the chain like tokens used, cost, etc.\n",
+ "- Log the resulting summaries as artifacts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain.callbacks.tracers import WandbTracer\n",
+ "\n",
+ "tracer = WandbTracer(run_args = {\"job_type\": \"summarize\"})\n",
+ "\n",
+ "# get scraped data\n",
+ "df = get_data(artifact_name=yt_podcast_data_artifact, total_episodes=TOTAL_EPISODES)\n",
+ "\n",
+ "summaries = []\n",
+ "with get_openai_callback() as cb:\n",
+ " for episode in tqdm(df.iterrows(), total=len(df), desc=\"Summarizing episodes\"):\n",
+ " episode_data = episode[1].to_frame().T\n",
+ "\n",
+ " summary = summarize_episode(episode_data)\n",
+ " summaries.append(summary[\"output_text\"])\n",
+ "\n",
+ " print(\"*\" * 25)\n",
+ " print(cb)\n",
+ " print(\"*\" * 25)\n",
+ "\n",
+ " wandb.log(\n",
+ " {\n",
+ " \"total_prompt_tokens\": cb.prompt_tokens,\n",
+ " \"total_completion_tokens\": cb.completion_tokens,\n",
+ " \"total_tokens\": cb.total_tokens,\n",
+ " \"total_cost\": cb.total_cost,\n",
+ " }\n",
+ " )\n",
+ "\n",
+ "df[\"summary\"] = summaries\n",
+ "\n",
+ "# save data\n",
+ "path_to_save = os.path.join(root_data_dir, \"summarized_podcasts.csv\")\n",
+ "df.to_csv(path_to_save, index=False)\n",
+ "\n",
+ "# log to wandb artifact\n",
+ "artifact = wandb.Artifact(\"summarized_podcasts\", type=\"dataset\")\n",
+ "artifact.add_file(path_to_save)\n",
+ "wandb.log_artifact(artifact)\n",
+ "\n",
+ "# create wandb table\n",
+ "table = wandb.Table(dataframe=df)\n",
+ "wandb.log({\"summarized_podcasts\": table})\n",
+ "\n",
+ "tracer.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Embed the contents of the YouTube transcripts\n",
+ "- Here we use OpenAI embeddings and [ChromaDB](https://www.trychroma.com/) to embed the summaries to make them queriable via vector similarity search when we ask contextual questions to the LLM\n",
+ "- Use `wandb.log` and artifacts to log the resulting ChromaDB serialized embeddings."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from dataclasses import asdict\n",
+ "\n",
+ "import pandas as pd\n",
+ "from langchain.callbacks import get_openai_callback\n",
+ "from langchain.document_loaders import DataFrameLoader\n",
+ "from langchain.embeddings.openai import OpenAIEmbeddings\n",
+ "from langchain.text_splitter import TokenTextSplitter\n",
+ "from langchain.vectorstores import Chroma\n",
+ "from tqdm import tqdm\n",
+ "from wandb.integration.langchain import WandbTracer\n",
+ "\n",
+ "import wandb\n",
+ "\n",
+ "\n",
+ "def get_data(artifact_name: str, total_episodes=None):\n",
+ " podcast_artifact = wandb.use_artifact(artifact_name, type=\"dataset\")\n",
+ " podcast_artifact_dir = podcast_artifact.download(root_artifact_dir)\n",
+ " filename = artifact_name.split(\":\")[0].split(\"/\")[-1]\n",
+ " df = pd.read_csv(os.path.join(podcast_artifact_dir, f\"{filename}.csv\"))\n",
+ " if total_episodes is not None:\n",
+ " df = df.iloc[:total_episodes]\n",
+ " return df\n",
+ "\n",
+ "\n",
+ "def create_embeddings(episode_df: pd.DataFrame, index: int):\n",
+ " # load docs into langchain format\n",
+ " loader = DataFrameLoader(episode_df, page_content_column=\"transcript\")\n",
+ " data = loader.load()\n",
+ "\n",
+ " # split the documents\n",
+ " text_splitter = TokenTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0)\n",
+ " docs = text_splitter.split_documents(data)\n",
+ "\n",
+ " title = data[0].metadata[\"title\"]\n",
+ " print(f\"Number of documents for podcast {title}: {len(docs)}\")\n",
+ "\n",
+ " # initialize embedding engine\n",
+ " embeddings = OpenAIEmbeddings()\n",
+ "\n",
+ " db = Chroma.from_documents(\n",
+ " docs,\n",
+ " embeddings,\n",
+ " persist_directory=os.path.join(root_data_dir / \"chromadb\", str(index)),\n",
+ " )\n",
+ " db.persist()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tracer = WandbTracer(run_args = {\"job_type\": \"embed_transcripts\"})\n",
+ "\n",
+ "# get data\n",
+ "df = get_data(artifact_name=summarized_data_artifact, total_episodes=TOTAL_EPISODES)\n",
+ "\n",
+ "# create embeddings\n",
+ "with get_openai_callback() as cb:\n",
+ " for episode in tqdm(df.iterrows(), total=len(df), desc=\"Embedding transcripts\"):\n",
+ " episode_data = episode[1].to_frame().T\n",
+ "\n",
+ " create_embeddings(episode_data, index=episode[0])\n",
+ "\n",
+ " print(\"*\" * 25)\n",
+ " print(cb)\n",
+ " print(\"*\" * 25)\n",
+ "\n",
+ " wandb.log(\n",
+ " {\n",
+ " \"total_prompt_tokens\": cb.prompt_tokens,\n",
+ " \"total_completion_tokens\": cb.completion_tokens,\n",
+ " \"total_tokens\": cb.total_tokens,\n",
+ " \"total_cost\": cb.total_cost,\n",
+ " }\n",
+ " )\n",
+ "\n",
+ "# log embeddings to wandb artifact\n",
+ "artifact = wandb.Artifact(\"transcript_embeddings\", type=\"dataset\")\n",
+ "artifact.add_dir(root_data_dir / \"chromadb\")\n",
+ "wandb.log_artifact(artifact)\n",
+ "\n",
+ "tracer.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Ask Questions Against your Summarized Documents\n",
+ "\n",
+ "Finally we tie everything together:\n",
+ "1. We can pull down our ChromaDB embeddings from W&B\n",
+ "2. Pass them along with a prompt template for QA to the `RetrievalQA` chain and start asking questions!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain.chains import RetrievalQA\n",
+ "from langchain.embeddings.openai import OpenAIEmbeddings\n",
+ "\n",
+ "def get_answer(podcast: str, question: str):\n",
+ " index = df[df[\"title\"] == podcast].index[0]\n",
+ " db_dir = os.path.join(chromadb_dir, str(index))\n",
+ " embeddings = OpenAIEmbeddings()\n",
+ " db = Chroma(persist_directory=db_dir, embedding_function=embeddings)\n",
+ "\n",
+ " prompt_template = \"\"\"Use the following pieces of context to answer the question.\n",
+ " If you don't know the answer, just say that you don't know, don't try to make up an answer.\n",
+ " Don't add your opinions or interpretations. Ensure that you complete the answer.\n",
+ " If the question is not relevant to the context, just say that it is not relevant.\n",
+ "\n",
+ " CONTEXT:\n",
+ " {context}\n",
+ "\n",
+ " QUESTION: {question}\n",
+ "\n",
+ " ANSWER:\"\"\"\n",
+ "\n",
+ " prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n",
+ "\n",
+ " retriever = db.as_retriever()\n",
+ " retriever.search_kwargs[\"k\"] = 2\n",
+ "\n",
+ " qa = RetrievalQA.from_chain_type(\n",
+ " llm=ChatOpenAI(temperature=0),\n",
+ " chain_type=\"stuff\",\n",
+ " retriever=retriever,\n",
+ " chain_type_kwargs={\"prompt\": prompt},\n",
+ " return_source_documents=True,\n",
+ " )\n",
+ "\n",
+ " with get_openai_callback() as cb:\n",
+ " result = qa({\"query\": question})\n",
+ " print(cb)\n",
+ "\n",
+ " answer = result[\"result\"]\n",
+ " return answer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# download and read data\n",
+ "api = wandb.Api()\n",
+ "artifact_df = api.artifact(summarized_data_artifact)\n",
+ "artifact_df.download(root_data_dir)\n",
+ "\n",
+ "artifact_embeddings = api.artifact(transcript_embeddings_artifact)\n",
+ "chromadb_dir = artifact_embeddings.download(root_data_dir / \"chromadb\")\n",
+ "\n",
+ "df_path = root_data_dir / \"summarized_podcasts.csv\"\n",
+ "df = pd.read_csv(df_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df[\"title\"].tolist()[0:TOTAL_EPISODES]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tracer = WandbTracer(run_args = {\"job_type\": \"retriealQA\"})\n",
+ "\n",
+ "answer = get_answer('Enabling LLM-Powered Applications with Harrison Chase of LangChain', \"What did Harrison Chase say?\")\n",
+ "print(answer)\n",
+ "\n",
+ "tracer.finish()"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}