Skip to content

Commit

Permalink
autocast
Browse files Browse the repository at this point in the history
  • Loading branch information
dnth committed Nov 8, 2024
1 parent dafbb9d commit 3b28a2c
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions xinfer/transformers/llama32.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def infer(self, image: str, text: str, **generate_kwargs) -> Result:
self.model.device
)

with torch.inference_mode():
with torch.inference_mode(), torch.amp.autocast(
device_type=self.device, dtype=self.dtype
):
output = self.model.generate(**inputs, **generate_kwargs)

return Result(text=self.processor.decode(output[0], skip_special_tokens=True))
Expand All @@ -143,7 +145,9 @@ def infer_batch(
self.model.device
)

with torch.inference_mode():
with torch.inference_mode(), torch.amp.autocast(
device_type=self.device, dtype=self.dtype
):
outputs = self.model.generate(**inputs, **generate_kwargs)

return [
Expand Down

0 comments on commit 3b28a2c

Please sign in to comment.