initial blip2 model integration
dnth committed Oct 10, 2024
1 parent a7d24d0 commit c3a4a62
Showing 7 changed files with 160 additions and 0 deletions.
2 changes: 2 additions & 0 deletions InferX/__init__.py
@@ -3,3 +3,5 @@
__author__ = """Dickson Neoh"""
__email__ = "dickson.neoh@gmail.com"
__version__ = "0.0.1"

from .model_factory import get_model
19 changes: 19 additions & 0 deletions InferX/base_model.py
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod


class BaseModel(ABC):
    @abstractmethod
    def load_model(self, **kwargs):
        pass

    @abstractmethod
    def preprocess(self, input_data):
        pass

    @abstractmethod
    def predict(self, processed_data):
        pass

    @abstractmethod
    def postprocess(self, prediction):
        pass
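The ABC above fixes the lifecycle every backend must implement: load_model, preprocess, predict, postprocess. As a minimal sketch of how a concrete backend plugs into this contract (the class name and no-op behavior below are illustrative, not part of this commit):

```python
from InferX.base_model import BaseModel


class EchoModel(BaseModel):
    """Hypothetical no-op backend illustrating the BaseModel contract."""

    def load_model(self, **kwargs):
        pass  # a real backend would load weights here

    def preprocess(self, input_data):
        return input_data  # a real backend would tokenize/transform here

    def predict(self, processed_data):
        return processed_data  # a real backend would run inference here

    def postprocess(self, prediction):
        return prediction  # a real backend would decode outputs here
```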
9 changes: 9 additions & 0 deletions InferX/model_factory.py
@@ -0,0 +1,9 @@
from .transformers.blip2 import BLIP2


def get_model(model_type: str, implementation: str, **kwargs):
    if implementation == "transformers":
        if model_type == "Salesforce/blip2-opt-2.7b":
            return BLIP2(model_name="Salesforce/blip2-opt-2.7b", **kwargs)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
    else:
        raise ValueError(f"Unsupported implementation: {implementation}")
47 changes: 47 additions & 0 deletions InferX/transformers/blip2.py
@@ -0,0 +1,47 @@
import requests
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

from ..base_model import BaseModel


class BLIP2(BaseModel):
    def __init__(self, model_name: str, **kwargs):
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model(**kwargs)

    def load_model(self, **kwargs):
        self.processor = Blip2Processor.from_pretrained(self.model_name, **kwargs)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            self.model_name, **kwargs
        ).to(self.device, torch.bfloat16)

        # Compile the model for faster repeated inference
        self.model = torch.compile(self.model, mode="max-autotune")

        self.model.eval()

    def preprocess(self, image, prompt=None):
        if isinstance(image, str):
            if image.startswith(("http://", "https://")):
                image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
            else:
                raise ValueError("Input string must be an image URL for BLIP2")
        elif not isinstance(image, Image.Image):
            raise ValueError(
                "Input must be either an image URL or a PIL Image for BLIP2"
            )

        return self.processor(images=image, text=prompt, return_tensors="pt").to(
            self.device
        )

    def predict(self, processed_data):
        # bfloat16 autocast matches the dtype the weights were cast to in load_model
        with torch.inference_mode(), torch.amp.autocast(
            device_type=self.device, dtype=torch.bfloat16
        ):
            return self.model.generate(**processed_data)

    def postprocess(self, prediction):
        return self.processor.batch_decode(prediction, skip_special_tokens=True)[0]
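Because `preprocess` also accepts a PIL image, a local file can be captioned without a URL. A short end-to-end sketch (the file path and prompt below are placeholders):

```python
from PIL import Image

from InferX import get_model

model = get_model("Salesforce/blip2-opt-2.7b", implementation="transformers")

image = Image.open("cat.jpg").convert("RGB")  # placeholder local file
inputs = model.preprocess(image, prompt="Question: what is in this image? Answer:")
caption = model.postprocess(model.predict(inputs))
print(caption)
```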
5 changes: 5 additions & 0 deletions README.md
@@ -26,6 +26,11 @@ Install InferX using pip:
pip install inferx
```

To use the Transformers backend, install the optional extra:
```bash
pip install -e ".[transformers]"
```

## Getting Started

Here's a quick example demonstrating how to use InferX with a Transformers model:
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -37,6 +37,11 @@ extra = [
"pandas",
]

transformers = [
    "transformers",
]


[tool]
[tool.setuptools.packages.find]
73 changes: 73 additions & 0 deletions tests/debug.ipynb
@@ -0,0 +1,73 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/dnth/mambaforge-pypy3/envs/inferx/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 6.57it/s]\n"
]
}
],
"source": [
"from InferX import get_model\n",
"\n",
"# Instantiate a Transformers model\n",
"model = get_model(\"Salesforce/blip2-opt-2.7b\", implementation=\"transformers\")\n",
"\n",
"# Input data (can be text, image URL, or PIL Image)\n",
"input_data = \"https://img.freepik.com/free-photo/adorable-black-white-kitty-with-monochrome-wall-her_23-2148955182.jpg\"\n",
"question = \"What's in this image? Answer:\"\n",
"\n",
"# Run inference\n",
"processed_input = model.preprocess(image=input_data, prompt=question)\n",
"\n",
"prediction = model.predict(processed_input)\n",
"output = model.postprocess(prediction)\n",
"\n",
"print(output)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "inferx",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
