Commit 9e11cd4
patch llama32 class
dnth committed Oct 29, 2024
1 parent dc613f2 commit 9e11cd4
Showing 3 changed files with 88 additions and 18 deletions.
48 changes: 42 additions & 6 deletions nbs/llama.ipynb
@@ -2,26 +2,62 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2024-10-29 17:54:30.075\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mxinfer.models\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mModel: meta-llama/Llama-3.2-11B-Vision-Instruct\u001b[0m\n",
"\u001b[32m2024-10-29 17:54:30.076\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mxinfer.models\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m64\u001b[0m - \u001b[1mDevice: cuda\u001b[0m\n",
"\u001b[32m2024-10-29 17:54:30.076\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mxinfer.models\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mDtype: bfloat16\u001b[0m\n",
"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a54363be84904348a857e9d36e3e1462",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import xinfer \n",
"\n",
"model = xinfer.create_model(\"meta-llama/Llama-3.2-11B-Vision\", device=\"cuda\", dtype=\"bfloat16\")\n",
"model = xinfer.create_model(\"meta-llama/Llama-3.2-11B-Vision-Instruct\", device=\"cuda\", dtype=\"bfloat16\")\n",
"# model = xinfer.create_model(\"Salesforce/blip2-opt-2.7b\", device=\"cuda\", dtype=\"bfloat16\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'The image is a screenshot from an anime featuring a character with long white hair, elf-like ears, and green eyes. She'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"image = \"https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg\"\n",
"prompt = \"Describe image.\"\n",
"\n",
"# model.infer(image, prompt, max_new_tokens=25)"
"model.infer(image, prompt, max_new_tokens=25)"
]
},
{
2 changes: 1 addition & 1 deletion xinfer/transformers/__init__.py
@@ -1,6 +1,6 @@
from .blip2 import BLIP2
from .joycaption import JoyCaption
- from .llama32 import Llama32
+ from .llama32 import Llama32Vision, Llama32VisionInstruct
from .moondream import Moondream
from .vision2seq import Vision2SeqModel
from .vlrm_blip2 import VLRMBlip2
56 changes: 45 additions & 11 deletions xinfer/transformers/llama32.py
@@ -10,22 +10,12 @@
"transformers",
ModelInputOutput.IMAGE_TEXT_TO_TEXT,
)
- @register_model(
- "meta-llama/Llama-3.2-11B-Vision",
- "transformers",
- ModelInputOutput.IMAGE_TEXT_TO_TEXT,
- )
@register_model(
"meta-llama/Llama-3.2-90B-Vision-Instruct",
"transformers",
ModelInputOutput.IMAGE_TEXT_TO_TEXT,
)
- @register_model(
- "meta-llama/Llama-3.2-90B-Vision",
- "transformers",
- ModelInputOutput.IMAGE_TEXT_TO_TEXT,
- )
- class Llama32(BaseModel):
+ class Llama32VisionInstruct(BaseModel):
def __init__(
self, model_id: str, device: str = "cpu", dtype: str = "float32", **kwargs
):
@@ -97,3 +87,47 @@ def infer_batch(self, images: list[str], prompts: list[str], **generate_kwargs):
decoded = self.processor.batch_decode(outputs, skip_special_tokens=True)
# Remove the prompt and assistant marker for each response
return [d.split("assistant")[-1].strip() for d in decoded]


+ @register_model(
+ "meta-llama/Llama-3.2-11B-Vision",
+ "transformers",
+ ModelInputOutput.IMAGE_TEXT_TO_TEXT,
+ )
+ @register_model(
+ "meta-llama/Llama-3.2-90B-Vision",
+ "transformers",
+ ModelInputOutput.IMAGE_TEXT_TO_TEXT,
+ )
+ class Llama32Vision(Llama32VisionInstruct):
+ @track_inference
+ def infer(self, image: str, prompt: str, **generate_kwargs) -> str:
+ image = super().parse_images(image)
+
+ # Format prompt for base vision model
+ input_text = f"<|image|><|begin_of_text|>{prompt}"
+
+ # Process inputs without adding special tokens
+ inputs = self.processor(image, input_text, return_tensors="pt").to(
+ self.model.device
+ )
+
+ with torch.inference_mode():
+ output = self.model.generate(**inputs, **generate_kwargs)
+
+ return self.processor.decode(output[0], skip_special_tokens=True)
+
+ def infer_batch(self, images: list[str], prompts: list[str], **generate_kwargs):
+ images = super().parse_images(images)
+
+ # Format prompts for base vision model
+ input_texts = [f"<|image|><|begin_of_text|>{prompt}" for prompt in prompts]
+
+ inputs = self.processor(images, input_texts, return_tensors="pt").to(
+ self.model.device
+ )
+
+ with torch.inference_mode():
+ outputs = self.model.generate(**inputs, **generate_kwargs)
+
+ return self.processor.batch_decode(outputs, skip_special_tokens=True)
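A quick usage sketch (not part of the commit): assuming the xinfer API stays as used in the notebook above (create_model and infer with the same arguments), the Instruct checkpoints now resolve to Llama32VisionInstruct, while the base Vision checkpoints resolve to the new Llama32Vision class and its raw <|image|><|begin_of_text|> prompt formatting.

import xinfer

# Instruct checkpoint -> Llama32VisionInstruct (existing chat-style prompt handling)
instruct_model = xinfer.create_model(
    "meta-llama/Llama-3.2-11B-Vision-Instruct", device="cuda", dtype="bfloat16"
)

# Base checkpoint -> Llama32Vision (raw <|image|><|begin_of_text|> formatting added in this commit)
base_model = xinfer.create_model(
    "meta-llama/Llama-3.2-11B-Vision", device="cuda", dtype="bfloat16"
)

image = "https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg"
prompt = "Describe image."

# Outputs depend on the model weights; max_new_tokens mirrors the notebook example.
print(instruct_model.infer(image, prompt, max_new_tokens=25))
print(base_model.infer(image, prompt, max_new_tokens=25))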
