From 572e698cc9fb1faec5e453a987e3d1bf3700a9c9 Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Fri, 28 Jun 2024 15:40:56 +0200
Subject: [PATCH] fix(trainer): use torch.cuda.amp.GradScaler

torch.amp.GradScaler is only available from PyTorch 2.3
---
 trainer/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/trainer/trainer.py b/trainer/trainer.py
index bc6cd58..2300baa 100644
--- a/trainer/trainer.py
+++ b/trainer/trainer.py
@@ -944,7 +944,7 @@ def _set_grad_clip_per_optimizer(config: Coqpit, optimizer_idx: int):
     def _compute_grad_norm(self, optimizer: torch.optim.Optimizer):
         return torch.norm(torch.cat([param.grad.view(-1) for param in self.master_params(optimizer)], dim=0), p=2)

-    def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: torch.amp.GradScaler):
+    def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler):
         """Perform gradient clipping"""
         if grad_clip is not None and grad_clip > 0:
             if scaler:
@@ -960,7 +960,7 @@ def optimize(
         batch: dict,
         model: nn.Module,
         optimizer: torch.optim.Optimizer,
-        scaler: torch.amp.GradScaler,
+        scaler: torch.cuda.amp.GradScaler,
         criterion: nn.Module,
         scheduler: Union[torch.optim.lr_scheduler._LRScheduler, list, dict],  # pylint: disable=protected-access
         config: Coqpit,
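
Note (not part of the patch): if the goal were to support both older and newer PyTorch releases instead of pinning the annotations to torch.cuda.amp.GradScaler, a minimal compatibility sketch could look like the following. The try/except import is illustrative only and is not code from trainer.py.

    # Illustrative compatibility shim (assumption, not taken from trainer.py):
    # torch.amp.GradScaler exists only from PyTorch 2.3 onward, while
    # torch.cuda.amp.GradScaler is available in older releases, so fall back
    # when the newer name is missing.
    import torch

    try:
        from torch.amp import GradScaler  # PyTorch >= 2.3
    except ImportError:
        from torch.cuda.amp import GradScaler  # older PyTorch

    # Either class can be constructed the same way for CUDA mixed precision;
    # disabling the scaler on CPU-only machines keeps the code path uniform.
    scaler = GradScaler(enabled=torch.cuda.is_available())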