From 9e11cd4f9d0e424fa3b751dceb56d78c5abe2803 Mon Sep 17 00:00:00 2001
From: dnth
Date: Tue, 29 Oct 2024 18:06:45 +0800
Subject: [PATCH] patch llama32 class

---
 nbs/llama.ipynb                 | 48 ++++++++++++++++++++++++----
 xinfer/transformers/__init__.py |  2 +-
 xinfer/transformers/llama32.py  | 56 ++++++++++++++++++++++++++-------
 3 files changed, 88 insertions(+), 18 deletions(-)

diff --git a/nbs/llama.ipynb b/nbs/llama.ipynb
index a14f382..90f65fa 100644
--- a/nbs/llama.ipynb
+++ b/nbs/llama.ipynb
@@ -2,26 +2,62 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2024-10-29 17:54:30.075\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mxinfer.models\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mModel: meta-llama/Llama-3.2-11B-Vision-Instruct\u001b[0m\n",
+      "\u001b[32m2024-10-29 17:54:30.076\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mxinfer.models\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m64\u001b[0m - \u001b[1mDevice: cuda\u001b[0m\n",
+      "\u001b[32m2024-10-29 17:54:30.076\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mxinfer.models\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mDtype: bfloat16\u001b[0m\n",
+      "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a54363be84904348a857e9d36e3e1462",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]"
diff --git a/xinfer/transformers/llama32.py b/xinfer/transformers/llama32.py
+    def infer(self, image: str, prompt: str, **generate_kwargs) -> str:
+        image = super().parse_images(image)
+
+        # Format prompt for base vision model
+        input_text = f"<|image|><|begin_of_text|>{prompt}"
+
+        # Process inputs without adding special tokens
+        inputs = self.processor(image, input_text, return_tensors="pt").to(
+            self.model.device
+        )
+
+        with torch.inference_mode():
+            output = self.model.generate(**inputs, **generate_kwargs)
+
+        return self.processor.decode(output[0], skip_special_tokens=True)
+
+    def infer_batch(self, images: list[str], prompts: list[str], **generate_kwargs):
+        images = super().parse_images(images)
+
+        # Format prompts for base vision model
+        input_texts = [f"<|image|><|begin_of_text|>{prompt}" for prompt in prompts]
+
+        inputs = self.processor(images, input_texts, return_tensors="pt").to(
+            self.model.device
+        )
+
+        with torch.inference_mode():
+            outputs = self.model.generate(**inputs, **generate_kwargs)
+
+        return self.processor.batch_decode(outputs, skip_special_tokens=True)
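
Note on the prompt format added in llama32.py: `<|image|><|begin_of_text|>{prompt}` is the convention for the base (non-instruct) Llama 3.2 Vision checkpoints in Hugging Face transformers, rather than the chat template used by the Instruct variant. The following is a minimal standalone sketch of the call pattern the patched infer() wraps; the model id, image URL, and max_new_tokens value are illustrative assumptions, not taken from this patch.

    import requests
    import torch
    from PIL import Image
    from transformers import AutoProcessor, MllamaForConditionalGeneration

    model_id = "meta-llama/Llama-3.2-11B-Vision"  # assumed base checkpoint
    model = MllamaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    processor = AutoProcessor.from_pretrained(model_id)

    # Placeholder image URL; any RGB image works
    image = Image.open(requests.get("https://example.com/cat.jpg", stream=True).raw)
    prompt = "<|image|><|begin_of_text|>Describe this image."  # same format as the patch

    inputs = processor(image, prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=100)
    print(processor.decode(output[0], skip_special_tokens=True))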
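
For completeness, a hedged usage sketch that exercises the patched methods through xinfer itself, matching the model, device, and dtype shown in the notebook log output above. It assumes xinfer exposes a create_model factory; the image URL, prompts, and generation kwargs are placeholders.

    import xinfer

    # Assumed entry point; model/device/dtype match the values logged in the notebook cell above
    model = xinfer.create_model(
        "meta-llama/Llama-3.2-11B-Vision-Instruct", device="cuda", dtype="bfloat16"
    )

    image = "https://example.com/cat.jpg"  # placeholder image path or URL
    print(model.infer(image, "Describe this image.", max_new_tokens=100))

    # Batched call through the new infer_batch()
    captions = model.infer_batch(
        [image, image],
        ["Describe this image.", "What objects are visible?"],
        max_new_tokens=50,
    )
    print(captions)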