Skip to content

Commit

Permalink
Add Molmo models from vLLM (#26)
Browse files Browse the repository at this point in the history
* ad molmo

* update molmo supported by vllm

* strip whitespace

* update moondream to use parse method from baseclass

* remove hardcoded local model

* Bump version: 0.0.8 → 0.0.9

* update readme

* update actions
  • Loading branch information
dnth authored Oct 23, 2024
1 parent ba3f431 commit a62c1f0
Show file tree
Hide file tree
Showing 13 changed files with 823 additions and 122 deletions.
9 changes: 0 additions & 9 deletions .github/workflows/docs-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,6 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"

- name: Install GDAL
run: |
python -m pip install --upgrade pip
pip install --find-links=https://girder.github.io/large_image_wheels --no-cache GDAL pyproj
- name: Test GDAL installation
run: |
python -c "from osgeo import gdal"
gdalinfo --version
- name: Install dependencies
run: |
pip install --no-cache-dir Cython
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ jobs:
pip install --user --no-cache-dir Cython
pip install --user -r requirements.txt -r requirements_dev.txt
pip install .
- name: Discover typos with codespell
run: |
codespell --skip="*.csv,*.geojson,*.json,*.js,*.html,*cff,./.git,*.py,*.ipynb" --ignore-words-list="aci,hist"
# - name: Discover typos with codespell
# run: |
# codespell --skip="*.csv,*.geojson,*.json,*.js,*.html,*cff,./.git,*.py,*.ipynb" --ignore-words-list="aci,hist"
- name: PKG-TEST
run: |
python -m unittest discover tests/
Expand Down
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,31 @@ model = UltralyticsModel("yolov5n6u")
model = xinfer.create_model(model)
```

vLLM:

<table>
<thead>
<tr>
<th>Model</th>
<th>Usage</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="https://huggingface.co/allenai/Molmo-72B-0924">Molmo-72B</a></td>
<td><code>xinfer.create_model("allenai/Molmo-72B-0924")</code></td>
</tr>
<tr>
<td><a href="https://huggingface.co/allenai/Molmo-7B-D-0924">Molmo-7B-D</a></td>
<td><code>xinfer.create_model("allenai/Molmo-7B-D-0924")</code></td>
</tr>
<tr>
<td><a href="https://huggingface.co/allenai/Molmo-7B-O-0924">Molmo-7B-O</a></td>
<td><code>xinfer.create_model("allenai/Molmo-7B-O-0924")</code></td>
</tr>
</tbody>
</table>


### 🔧 Adding New Models

Expand Down
125 changes: 50 additions & 75 deletions nbs/transformers.ipynb

Large diffs are not rendered by default.

611 changes: 611 additions & 0 deletions nbs/vllm.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "xinfer"
version = "0.0.8"
version = "0.0.9"
dynamic = [
"dependencies",
]
Expand Down Expand Up @@ -46,7 +46,7 @@ universal = true


[tool.bumpversion]
current_version = "0.0.8"
current_version = "0.0.9"
commit = true
tag = true

Expand Down
11 changes: 9 additions & 2 deletions xinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,26 @@

__author__ = """Dickson Neoh"""
__email__ = "dickson.neoh@gmail.com"
__version__ = "0.0.8"
__version__ = "0.0.9"

from .core import create_model, list_models
from .model_registry import ModelInputOutput, register_model
from .models import BaseModel
from .utils import timm_available, transformers_available, ultralytics_available
from .utils import (
timm_available,
transformers_available,
ultralytics_available,
vllm_available,
)

if timm_available:
from .timm import *
if transformers_available:
from .transformers import *
if ultralytics_available:
from .ultralytics import *
if vllm_available:
from .vllm import *

__all__ = [
"create_model",
Expand Down
16 changes: 16 additions & 0 deletions xinfer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@


def create_model(model: str | TimmModel | Vision2SeqModel | UltralyticsModel, **kwargs):
"""
Create a model instance.
Parameters
----------
model : str | TimmModel | Vision2SeqModel | UltralyticsModel
The model to create.
TIMM, Vision2Seq, and Ultralytics models type here is to support user passing in the models directly.
This is useful for models not registered in the model registry.
Eg:
```python
model = UltralyticsModel("yolov5n6u")
model = xinfer.create_model(model)
```
"""
if isinstance(model, (TimmModel, Vision2SeqModel, UltralyticsModel)):
return model
return model_registry.get_model(model, **kwargs)
Expand Down
43 changes: 43 additions & 0 deletions xinfer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager

import requests
import torch
from loguru import logger
from PIL import Image
from rich import box
from rich.console import Console
from rich.table import Table
Expand Down Expand Up @@ -88,3 +90,44 @@ def __repr__(self) -> str:
f"device='{self.device}', "
f"dtype='{self.dtype}', "
)

def parse_images(
self,
images: str | list[str],
) -> list[Image.Image]:
"""
Preprocess one or more images from file paths or URLs.
Loads and converts images to RGB format from either local file paths or URLs.
Can handle both single image input or multiple images as a list.
Args:
images (Union[str, List[str]]): Either a single image path/URL as a string,
or a list of image paths/URLs. Accepts both local file paths and HTTP(S) URLs.
Returns:
List[PIL.Image.Image]: List of processed PIL Image objects in RGB format.
"""

if not isinstance(images, list):
images = [images]

parsed_images = []
for image_path in images:
if not isinstance(image_path, str):
raise ValueError("Input must be a string (local path or URL)")

if image_path.startswith(("http://", "https://")):
image = Image.open(requests.get(image_path, stream=True).raw).convert(
"RGB"
)
else:
# Assume it's a local path
try:
image = Image.open(image_path).convert("RGB")
except FileNotFoundError:
raise ValueError(f"Local file not found: {image_path}")

parsed_images.append(image)

return parsed_images
33 changes: 2 additions & 31 deletions xinfer/transformers/moondream.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

from ..model_registry import ModelInputOutput, register_model
Expand All @@ -22,33 +20,6 @@ def __init__(
self.revision = revision
self.load_model()

def preprocess(
self,
images: str | list[str],
):
if not isinstance(images, list):
images = [images]

processed_images = []
for image_path in images:
if not isinstance(image_path, str):
raise ValueError("Input must be a string (local path or URL)")

if image_path.startswith(("http://", "https://")):
image = Image.open(requests.get(image_path, stream=True).raw).convert(
"RGB"
)
else:
# Assume it's a local path
try:
image = Image.open(image_path).convert("RGB")
except FileNotFoundError:
raise ValueError(f"Local file not found: {image_path}")

processed_images.append(image)

return processed_images

def load_model(self):
self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, trust_remote_code=True, revision=self.revision
Expand All @@ -60,7 +31,7 @@ def load_model(self):

def infer(self, image: str, prompt: str = None, **generate_kwargs):
with self.track_inference_time():
image = self.preprocess(image)
image = self.parse_images(image)
encoded_image = self.model.encode_image(image)
output = self.model.answer_question(
question=prompt,
Expand All @@ -74,7 +45,7 @@ def infer(self, image: str, prompt: str = None, **generate_kwargs):

def infer_batch(self, images: list[str], prompts: list[str], **generate_kwargs):
with self.track_inference_time():
images = self.preprocess(images)
images = self.parse_images(images)
prompts = [prompt for prompt in prompts]

outputs = self.model.batch_answer(
Expand Down
1 change: 1 addition & 0 deletions xinfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def soft_import(name: str):
timm_available = soft_import("timm")
transformers_available = soft_import("transformers")
ultralytics_available = soft_import("ultralytics")
vllm_available = soft_import("vllm")


# Create placeholder classes
Expand Down
1 change: 1 addition & 0 deletions xinfer/vllm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .molmo import Molmo
60 changes: 60 additions & 0 deletions xinfer/vllm/molmo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from vllm import LLM, SamplingParams

from ..model_registry import ModelInputOutput, register_model
from ..models import BaseModel


@register_model("allenai/Molmo-72B-0924", "vllm", ModelInputOutput.IMAGE_TEXT_TO_TEXT)
@register_model("allenai/Molmo-7B-O-0924", "vllm", ModelInputOutput.IMAGE_TEXT_TO_TEXT)
@register_model("allenai/Molmo-7B-D-0924", "vllm", ModelInputOutput.IMAGE_TEXT_TO_TEXT)
class Molmo(BaseModel):
def __init__(
self,
model_id: str,
device: str = "cuda",
dtype: str = "float32",
**kwargs,
):
super().__init__(model_id, device, dtype)
self.load_model(**kwargs)

def load_model(self, **kwargs):
self.model = LLM(
model=self.model_id,
trust_remote_code=True,
dtype=self.dtype,
**kwargs,
)

def infer_batch(self, images: list[str], prompts: list[str], **sampling_kwargs):
images = self.parse_images(images)

sampling_params = SamplingParams(**sampling_kwargs)
with self.track_inference_time():
batch_inputs = [
{
"prompt": f"USER: <image>\n{prompt}\nASSISTANT:",
"multi_modal_data": {"image": image},
}
for image, prompt in zip(images, prompts)
]

results = self.model.generate(batch_inputs, sampling_params)

self.update_inference_count(len(images))
return [output.outputs[0].text.strip() for output in results]

def infer(self, image: str, prompt: str, **sampling_kwargs):
with self.track_inference_time():
image = self.parse_images(image)

inputs = {
"prompt": prompt,
"multi_modal_data": {"image": image},
}

sampling_params = SamplingParams(**sampling_kwargs)
outputs = self.model.generate(inputs, sampling_params)
generated_text = outputs[0].outputs[0].text.strip()
self.update_inference_count(1)
return generated_text

0 comments on commit a62c1f0

Please sign in to comment.