From 3b28a2cd133490da120763621ff00676e27dcb06 Mon Sep 17 00:00:00 2001 From: dnth Date: Fri, 8 Nov 2024 16:19:09 +0800 Subject: [PATCH] autocast --- xinfer/transformers/llama32.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xinfer/transformers/llama32.py b/xinfer/transformers/llama32.py index 34ab1e6..54c4398 100644 --- a/xinfer/transformers/llama32.py +++ b/xinfer/transformers/llama32.py @@ -126,7 +126,9 @@ def infer(self, image: str, text: str, **generate_kwargs) -> Result: self.model.device ) - with torch.inference_mode(): + with torch.inference_mode(), torch.amp.autocast( + device_type=self.device, dtype=self.dtype + ): output = self.model.generate(**inputs, **generate_kwargs) return Result(text=self.processor.decode(output[0], skip_special_tokens=True)) @@ -143,7 +145,9 @@ def infer_batch( self.model.device ) - with torch.inference_mode(): + with torch.inference_mode(), torch.amp.autocast( + device_type=self.device, dtype=self.dtype + ): outputs = self.model.generate(**inputs, **generate_kwargs) return [