From ae9964b43e4ae4c66697fdf3a94fd2b7c0f2fc04 Mon Sep 17 00:00:00 2001
From: Dickson Neoh
Date: Tue, 29 Oct 2024 21:20:10 +0800
Subject: [PATCH] Add Florence 2 model series by Microsoft (#42)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* initial implementation
* simplify
* add other florence 2
* add florence 2 test
* Bump version: 0.1.1 → 0.1.2
* update readme
* update quickstart
---
 README.md                        |   6 +-
 nbs/florence-2.ipynb             | 372 +++++++++++++++++++++++++++++++
 nbs/quickstart.ipynb             |   6 +-
 pyproject.toml                   |   4 +-
 tests/smoke/test_florence2.py    |  40 ++++
 xinfer/__init__.py               |   2 +-
 xinfer/transformers/__init__.py  |   1 +
 xinfer/transformers/florence2.py |  83 +++++++
 8 files changed, 507 insertions(+), 7 deletions(-)
 create mode 100644 nbs/florence-2.ipynb
 create mode 100644 tests/smoke/test_florence2.py
 create mode 100644 xinfer/transformers/florence2.py

diff --git a/README.md b/README.md
index ef15353..53892f6 100644
--- a/README.md
+++ b/README.md
@@ -256,9 +256,13 @@ pip install -e .
xinfer.create_model("fancyfeast/llama-joycaption-alpha-two-hf-llava")
- Llama-3.2 Vision
+ Llama-3.2 Vision Series
xinfer.create_model("meta-llama/Llama-3.2-11B-Vision-Instruct")
+
+ Florence-2 Series
+ xinfer.create_model("microsoft/Florence-2-base-ft")
+
diff --git a/nbs/florence-2.ipynb b/nbs/florence-2.ipynb
new file mode 100644
index 0000000..4ea21b8
--- /dev/null
+++ b/nbs/florence-2.ipynb
@@ -0,0 +1,372 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "                            Available Models                            \n",
+       "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Implementation  Model ID                       Input --> Output    ┃\n",
+       "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ transformers    microsoft/Florence-2-base-ft   image-text --> text │\n",
+       "│ transformers    microsoft/Florence-2-large-ft  image-text --> text │\n",
+       "│ transformers    microsoft/Florence-2-base      image-text --> text │\n",
+       "│ transformers    microsoft/Florence-2-large     image-text --> text │\n",
+       "└────────────────┴───────────────────────────────┴─────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Available Models \u001b[0m\n", + "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mImplementation\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mModel ID \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mInput --> Output \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36mtransformers \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35mmicrosoft/Florence-2-base-ft \u001b[0m\u001b[35m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mimage-text --> text\u001b[0m\u001b[32m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36mtransformers \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35mmicrosoft/Florence-2-large-ft\u001b[0m\u001b[35m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mimage-text --> text\u001b[0m\u001b[32m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36mtransformers \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35mmicrosoft/Florence-2-base \u001b[0m\u001b[35m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mimage-text --> text\u001b[0m\u001b[32m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36mtransformers \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35mmicrosoft/Florence-2-large \u001b[0m\u001b[35m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mimage-text --> text\u001b[0m\u001b[32m \u001b[0m│\n", + "└────────────────┴───────────────────────────────┴─────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import xinfer\n", + "\n", + "xinfer.list_models(\"florence-2\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-10-29 21:05:27.411\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mxinfer.models\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m63\u001b[0m - \u001b[1mModel: microsoft/Florence-2-large-ft\u001b[0m\n", + "\u001b[32m2024-10-29 21:05:27.412\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mxinfer.models\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m64\u001b[0m - \u001b[1mDevice: cuda\u001b[0m\n", + "\u001b[32m2024-10-29 21:05:27.412\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mxinfer.models\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m65\u001b[0m - \u001b[1mDtype: float16\u001b[0m\n", + "/home/dnth/mambaforge-pypy3/envs/xinfer/lib/python3.10/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers\n", + " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.layers\", FutureWarning)\n", + "Florence2LanguageForConditionalGeneration has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n", + " - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the model with an auto class. 
See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes\n", + " - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n", + " - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n" + ] + }, + { + "data": { + "text/plain": [ + "'A woman with glasses is gesturing with both hands.'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = xinfer.create_model(\"microsoft/Florence-2-large-ft\", device=\"cuda\", dtype=\"float16\")\n", + "\n", + "image = \"../assets/demo/0a6ee446579d2885.jpg\"\n", + "prompt = \"\"\n", + "model.infer(image, prompt)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.',\n", + " 'A woman with glasses is gesturing with both hands.']" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_size 
= 40\n", + "model.infer_batch([image] * batch_size, [prompt] * batch_size)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'In this image I can see a woman wearing green color dress and spectacles. Background is in black color.'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "image = \"../assets/demo/0a6ee446579d2885.jpg\"\n", + "prompt = \"\"\n", + "model.infer(image, prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'A woman is standing in front of a white wall. She is wearing a green shirt and has bracelets on her wrists. The woman has glasses on and her hair is long and brown. '" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "image = \"../assets/demo/0a6ee446579d2885.jpg\"\n", + "prompt = \"\"\n", + "model.infer(image, prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'bboxes': [], 'labels': []}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "image = \"../assets/demo/0a6ee446579d2885.jpg\"\n", + "prompt = \"\"\n", + "model.infer(image, prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'bboxes': [[325.1200256347656,\n", + " 221.33999633789062,\n", + " 468.4800109863281,\n", + " 268.94000244140625],\n", + " [337.40802001953125,\n", + " 179.1800079345703,\n", + " 466.4320373535156,\n", + " 340.3399963378906],\n", + " [237.05601501464844, 124.0999984741211, 632.3200073242188, 678.97998046875]],\n", + " 'labels': ['glasses', 'human face', 'woman']}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "image = \"../assets/demo/0a6ee446579d2885.jpg\"\n", + "prompt = \"\"\n", + "model.infer(image, prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'bboxes': [[221.69601440429688,\n", + " 124.0999984741211,\n", + " 643.5840454101562,\n", + " 678.97998046875],\n", + " [337.40802001953125,\n", + " 179.86000061035156,\n", + " 466.4320373535156,\n", + " 340.3399963378906],\n", + " [325.1200256347656,\n", + " 221.33999633789062,\n", + " 468.4800109863281,\n", + " 268.94000244140625]],\n", + " 'labels': ['woman in green shirt with glasses on stage',\n", + " 'human face',\n", + " 'glasses']}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "image = \"../assets/demo/0a6ee446579d2885.jpg\"\n", + "prompt = \"\"\n", + "model.infer(image, prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'bboxes': [[237.05601501464844,\n", + " 124.0999984741211,\n", + " 632.3200073242188,\n", + " 678.97998046875],\n", + " [299.52001953125, 126.81999969482422, 501.2480163574219, 343.05999755859375],\n", + " [338.4320068359375, 179.1800079345703, 467.4560241699219, 340.3399963378906],\n", + " [505.3440246582031, 542.97998046875, 612.864013671875, 621.1799926757812],\n", + " [345.6000061035156, 557.9400024414062, 461.31201171875, 627.97998046875],\n", + " [325.1200256347656,\n", + " 
220.66000366210938,\n", + " 468.4800109863281,\n", + " 268.94000244140625],\n", + " [390.656005859375, 241.05999755859375, 429.5680236816406, 281.8600158691406],\n", + " [390.656005859375, 289.3399963378906, 440.83203125, 308.3800048828125]],\n", + " 'labels': ['', '', '', '', '', '', '', '']}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "image = \"../assets/demo/0a6ee446579d2885.jpg\"\n", + "prompt = \"\"\n", + "model.infer(image, prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7860\n", + "\n", + "To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.launch_gradio()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "xinfer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nbs/quickstart.ipynb b/nbs/quickstart.ipynb index dd7a086..9f55348 100644 --- a/nbs/quickstart.ipynb +++ b/nbs/quickstart.ipynb @@ -110,7 +110,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "It's recommended to restart the kernel once all the dependencies are installed. Uncomment the following line to restart the kernel." + "It's recommended to restart the kernel once all the dependencies are installed." ] }, { @@ -119,8 +119,8 @@ "metadata": {}, "outputs": [], "source": [ - "# from IPython import get_ipython\n", - "# get_ipython().kernel.do_shutdown(restart=True)" + "from IPython import get_ipython\n", + "get_ipython().kernel.do_shutdown(restart=True)" ] }, { diff --git a/pyproject.toml b/pyproject.toml index b81d2f7..ed0f48e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "xinfer" -version = "0.1.1" +version = "0.1.2" dynamic = [ "dependencies", ] @@ -48,7 +48,7 @@ universal = true [tool.bumpversion] -current_version = "0.1.1" +current_version = "0.1.2" commit = true tag = true diff --git a/tests/smoke/test_florence2.py b/tests/smoke/test_florence2.py new file mode 100644 index 0000000..6471b09 --- /dev/null +++ b/tests/smoke/test_florence2.py @@ -0,0 +1,40 @@ +from pathlib import Path + +import pytest +import torch + +import xinfer + + +@pytest.fixture +def model(): + return xinfer.create_model( + "microsoft/Florence-2-base-ft", device="cpu", dtype="float32" + ) + + +@pytest.fixture +def test_image(): + return str(Path(__file__).parent.parent / "test_data" / "test_image_1.jpg") + + +def test_florence2_initialization(model): + assert model.model_id == "microsoft/Florence-2-base-ft" + assert model.device == "cpu" + assert model.dtype == torch.float32 + + +def test_florence2_inference(model, test_image): + prompt = "" + result = model.infer(test_image, prompt) + + assert isinstance(result, str) + assert len(result) > 0 + + +def test_florence2_batch_inference(model, test_image): + prompt = "" + result = model.infer_batch([test_image, test_image], [prompt, prompt]) + + assert isinstance(result, list) + assert len(result) == 2 diff --git a/xinfer/__init__.py b/xinfer/__init__.py index 21ce123..330a5e0 100644 --- a/xinfer/__init__.py +++ b/xinfer/__init__.py @@ -2,7 +2,7 @@ __author__ = """Dickson Neoh""" __email__ = "dickson.neoh@gmail.com" -__version__ = "0.1.1" +__version__ = "0.1.2" from .core import create_model, list_models from .model_registry import ModelInputOutput, register_model diff --git a/xinfer/transformers/__init__.py b/xinfer/transformers/__init__.py index 1953b3a..075cb6f 100644 --- a/xinfer/transformers/__init__.py +++ b/xinfer/transformers/__init__.py @@ -1,4 +1,5 @@ from .blip2 import BLIP2 +from .florence2 import Florence2 from .joycaption import JoyCaption from .llama32 import Llama32Vision, Llama32VisionInstruct from .moondream import Moondream diff 
--git a/xinfer/transformers/florence2.py b/xinfer/transformers/florence2.py
new file mode 100644
index 0000000..228dee3
--- /dev/null
+++ b/xinfer/transformers/florence2.py
@@ -0,0 +1,83 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoProcessor
+
+from ..model_registry import ModelInputOutput, register_model
+from ..models import BaseModel, track_inference
+
+
+@register_model(
+    "microsoft/Florence-2-large", "transformers", ModelInputOutput.IMAGE_TEXT_TO_TEXT
+)
+@register_model(
+    "microsoft/Florence-2-base",
+    "transformers",
+    ModelInputOutput.IMAGE_TEXT_TO_TEXT,
+)
+@register_model(
+    "microsoft/Florence-2-large-ft",
+    "transformers",
+    ModelInputOutput.IMAGE_TEXT_TO_TEXT,
+)
+@register_model(
+    "microsoft/Florence-2-base-ft",
+    "transformers",
+    ModelInputOutput.IMAGE_TEXT_TO_TEXT,
+)
+class Florence2(BaseModel):
+    def __init__(
+        self,
+        model_id: str,
+        device: str = "cpu",
+        dtype: str = "float32",
+    ):
+        super().__init__(model_id, device, dtype)
+        self.load_model()
+
+    def load_model(self):
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_id, trust_remote_code=True
+        ).to(self.device, self.dtype)
+        self.model.eval()
+        self.model = torch.compile(self.model, mode="max-autotune")
+        self.processor = AutoProcessor.from_pretrained(
+            self.model_id, trust_remote_code=True
+        )
+
+    @track_inference
+    def infer(self, image: str, prompt: str = None, **generate_kwargs) -> str:
+        output = self.infer_batch([image], [prompt], **generate_kwargs)
+        return output[0]
+
+    @track_inference
+    def infer_batch(
+        self, images: list[str], prompts: list[str] = None, **generate_kwargs
+    ) -> list[str]:
+        images = self.parse_images(images)
+        inputs = self.processor(text=prompts, images=images, return_tensors="pt").to(
+            self.device, self.dtype
+        )
+
+        if "max_new_tokens" not in generate_kwargs:
+            generate_kwargs["max_new_tokens"] = 1024
+        if "num_beams" not in generate_kwargs:
+            generate_kwargs["num_beams"] = 3
+
+        with torch.inference_mode():
+            generated_ids = self.model.generate(
+                input_ids=inputs["input_ids"],
+                pixel_values=inputs["pixel_values"],
+                **generate_kwargs,
+            )
+
+        generated_text = self.processor.batch_decode(
+            generated_ids, skip_special_tokens=False
+        )
+
+        parsed_answers = [
+            self.processor.post_process_generation(
+                text, task=prompt, image_size=(img.width, img.height)
+            ).get(prompt)
+            for text, prompt, img in zip(generated_text, prompts, images)
+        ]
+
+        return parsed_answers
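
Usage sketch (not part of the patch above): the notebook output lost its Florence-2 task prompts when the HTML was rendered, leaving prompt = "" in the cells, so the snippet below shows how the new Florence2 wrapper is typically driven through the xinfer API. It assumes the standard Florence-2 task tokens such as "<CAPTION>" and "<OD>", and the image path is a placeholder; treat it as an illustrative sketch rather than part of the change set.

    import xinfer

    # Load one of the newly registered Florence-2 checkpoints.
    model = xinfer.create_model(
        "microsoft/Florence-2-base-ft", device="cuda", dtype="float16"
    )

    image = "path/to/image.jpg"  # placeholder path

    # Florence-2 selects its task via special prompt tokens (assumed standard tokens).
    caption = model.infer(image, "<CAPTION>")
    detailed_caption = model.infer(image, "<MORE_DETAILED_CAPTION>")

    # Detection-style tasks return a dict of bboxes and labels after post-processing.
    detections = model.infer(image, "<OD>")

    # Batched inference pairs each image with its own prompt.
    results = model.infer_batch([image, image], ["<CAPTION>", "<OD>"])

    # Optional interactive demo, as at the end of the notebook above.
    # model.launch_gradio()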