diff --git a/trainer/trainer.py b/trainer/trainer.py
index bc6cd58..2300baa 100644
--- a/trainer/trainer.py
+++ b/trainer/trainer.py
@@ -944,7 +944,7 @@ def _set_grad_clip_per_optimizer(config: Coqpit, optimizer_idx: int):
     def _compute_grad_norm(self, optimizer: torch.optim.Optimizer):
         return torch.norm(torch.cat([param.grad.view(-1) for param in self.master_params(optimizer)], dim=0), p=2)

-    def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: torch.amp.GradScaler):
+    def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler):
         """Perform gradient clipping"""
         if grad_clip is not None and grad_clip > 0:
             if scaler:
@@ -960,7 +960,7 @@ def optimize(
         batch: dict,
         model: nn.Module,
         optimizer: torch.optim.Optimizer,
-        scaler: torch.amp.GradScaler,
+        scaler: torch.cuda.amp.GradScaler,
         criterion: nn.Module,
         scheduler: Union[torch.optim.lr_scheduler._LRScheduler, list, dict],  # pylint: disable=protected-access
         config: Coqpit,
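
For context only (not part of this diff): a minimal sketch of the standard AMP training step that the annotated `scaler` participates in, illustrating why `_grad_clipping` unscales gradients before clipping. The model, data, and `grad_clip` value below are placeholders, and the sketch assumes a CUDA device is available.

```python
import torch
from torch import nn

# Placeholder model/optimizer; the real trainer wires these up from its config.
model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()  # matches the annotation used in the diff
grad_clip = 1.0                       # placeholder clip threshold

for _ in range(3):
    batch = torch.randn(4, 10, device="cuda")
    target = torch.randn(4, 1, device="cuda")
    optimizer.zero_grad(set_to_none=True)

    with torch.cuda.amp.autocast():
        loss = nn.functional.mse_loss(model(batch), target)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # undo loss scaling so the clip threshold applies to true grad norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)      # skips the step if inf/nan gradients are found
    scaler.update()
```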