Skip to content

Commit

Permalink
update florence test
Browse files Browse the repository at this point in the history
  • Loading branch information
dnth committed Nov 8, 2024
1 parent cc062a2 commit 72daf0e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
25 changes: 17 additions & 8 deletions tests/smoke/test_florence2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import xinfer

TEST_DATA_DIR = Path(__file__).parent.parent / "test_data"


@pytest.fixture
def model():
Expand All @@ -14,8 +16,11 @@ def model():


@pytest.fixture
def test_image():
return str(Path(__file__).parent.parent / "test_data" / "test_image_1.jpg")
def test_images():
return [
str(TEST_DATA_DIR / "test_image_1.jpg"),
str(TEST_DATA_DIR / "test_image_2.jpg"),
]


def test_florence2_initialization(model):
Expand All @@ -24,17 +29,21 @@ def test_florence2_initialization(model):
assert model.dtype == torch.float32


def test_florence2_inference(model, test_image):
def test_florence2_inference(model, test_images):
prompt = "<CAPTION>"
result = model.infer(test_image, prompt)
result = model.infer(test_images[0], prompt)

assert isinstance(result, str)
assert len(result) > 0
assert isinstance(result.text, str)
assert len(result.text) > 0


def test_florence2_batch_inference(model, test_image):
def test_florence2_batch_inference(model, test_images):
prompt = "<CAPTION>"
result = model.infer_batch([test_image, test_image], [prompt, prompt])
result = model.infer_batch(test_images, [prompt, prompt])

assert isinstance(result, list)
assert len(result) == 2
assert isinstance(result[0].text, str)
assert isinstance(result[1].text, str)
assert len(result[0].text) > 0
assert len(result[1].text) > 0
6 changes: 2 additions & 4 deletions xinfer/transformers/florence2.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def load_model(self):
@track_inference
def infer(self, image: str, text: str, **generate_kwargs) -> Result:
output = self.infer_batch([image], [text], **generate_kwargs)
return Result(text=output[0])
return output[0]

@track_inference
def infer_batch(
Expand Down Expand Up @@ -81,6 +81,4 @@ def infer_batch(
for text, prompt, img in zip(generated_text, texts, images)
]

results = [Result(text=text) for text in parsed_answers]

return results
return [Result(text=text) for text in parsed_answers]

0 comments on commit 72daf0e

Please sign in to comment.