Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
dnth committed Oct 10, 2024
1 parent f707477 commit 49bc949
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
4 changes: 2 additions & 2 deletions InferX/transformers/blip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def preprocess(self, image: str | Image.Image, prompt: str = None):
self.device
)

def predict(self, processed_data, **generate_kwargs):
def predict(self, preprocessed_input, **generate_kwargs):
with torch.inference_mode(), torch.amp.autocast(
device_type=self.device, dtype=torch.bfloat16
):
return self.model.generate(**processed_data, **generate_kwargs)
return self.model.generate(**preprocessed_input, **generate_kwargs)

def postprocess(self, prediction):
return self.processor.batch_decode(prediction, skip_special_tokens=True)[0]
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,18 @@ output = model.postprocess(prediction)
print(output)

>>> A cat on a yellow background


image = "https://img.freepik.com/free-photo/adorable-black-white-kitty-with-monochrome-wall-her_23-2148955182.jpg"
prompt = "Describe this image in concise detail. Answer:"


processed_input = model.preprocess(image, prompt)

prediction = model.predict(processed_input, max_new_tokens=200)
output = model.postprocess(prediction)

print(output)
>>> a black and white cat sitting on a table looking up at the camera

```

0 comments on commit 49bc949

Please sign in to comment.