diff --git a/xinfer/transformers/llama32.py b/xinfer/transformers/llama32.py
index 3e64e60..34ab1e6 100644
--- a/xinfer/transformers/llama32.py
+++ b/xinfer/transformers/llama32.py
@@ -54,7 +54,9 @@ def infer(self, image: str, text: str, **generate_kwargs) -> Result:
             return_tensors="pt",
         ).to(self.model.device)
 
-        with torch.inference_mode():
+        with torch.inference_mode(), torch.amp.autocast(
+            device_type=self.device, dtype=self.dtype
+        ):
             output = self.model.generate(**inputs, **generate_kwargs)
 
         decoded = self.processor.decode(output[0], skip_special_tokens=True)
@@ -91,7 +93,9 @@ def infer_batch(
             padding=True,
         ).to(self.model.device)
 
-        with torch.inference_mode():
+        with torch.inference_mode(), torch.amp.autocast(
+            device_type=self.device, dtype=self.dtype
+        ):
             outputs = self.model.generate(**inputs, **generate_kwargs)
 
         decoded = self.processor.batch_decode(outputs, skip_special_tokens=True)