Skip to content

Commit

Permalink
autocast
Browse files Browse the repository at this point in the history
  • Loading branch information
dnth committed Nov 8, 2024
1 parent 862fe43 commit dafbb9d
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions xinfer/transformers/llama32.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def infer(self, image: str, text: str, **generate_kwargs) -> Result:
return_tensors="pt",
).to(self.model.device)

with torch.inference_mode():
with torch.inference_mode(), torch.amp.autocast(
device_type=self.device, dtype=self.dtype
):
output = self.model.generate(**inputs, **generate_kwargs)

decoded = self.processor.decode(output[0], skip_special_tokens=True)
Expand Down Expand Up @@ -91,7 +93,9 @@ def infer_batch(
padding=True,
).to(self.model.device)

with torch.inference_mode():
with torch.inference_mode(), torch.amp.autocast(
device_type=self.device, dtype=self.dtype
):
outputs = self.model.generate(**inputs, **generate_kwargs)

decoded = self.processor.batch_decode(outputs, skip_special_tokens=True)
Expand Down

0 comments on commit dafbb9d

Please sign in to comment.