fix(trainer): use torch.cuda.amp.GradScaler
torch.amp.GradScaler is only available from PyTorch 2.3 onwards; keep using torch.cuda.amp.GradScaler so older PyTorch versions remain supported.
eginhard committed Jun 28, 2024
1 parent f8d5748 commit 572e698
Showing 1 changed file with 2 additions and 2 deletions.
trainer/trainer.py: 2 additions & 2 deletions
```diff
@@ -944,7 +944,7 @@ def _set_grad_clip_per_optimizer(config: Coqpit, optimizer_idx: int):
     def _compute_grad_norm(self, optimizer: torch.optim.Optimizer):
         return torch.norm(torch.cat([param.grad.view(-1) for param in self.master_params(optimizer)], dim=0), p=2)
 
-    def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: torch.amp.GradScaler):
+    def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler):
         """Perform gradient clipping"""
         if grad_clip is not None and grad_clip > 0:
             if scaler:
@@ -960,7 +960,7 @@ def optimize(
         batch: dict,
         model: nn.Module,
         optimizer: torch.optim.Optimizer,
-        scaler: torch.amp.GradScaler,
+        scaler: torch.cuda.amp.GradScaler,
         criterion: nn.Module,
         scheduler: Union[torch.optim.lr_scheduler._LRScheduler, list, dict],  # pylint: disable=protected-access
         config: Coqpit,
```
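For context: `torch.amp.GradScaler` only gained its device-agnostic constructor in PyTorch 2.3, while `torch.cuda.amp.GradScaler` exists on the older releases this change keeps supporting, which is why the annotations are reverted here. A guarded import is one way to cover both APIs at once; the sketch below is illustrative only and is not what this commit does, and the `make_grad_scaler` helper is a hypothetical name, not part of trainer/trainer.py.

```python
# Compatibility sketch (assumption, not part of trainer/trainer.py):
# bind whichever GradScaler class the installed PyTorch provides.
try:
    # PyTorch >= 2.3: the device-agnostic scaler lives in torch.amp.
    from torch.amp import GradScaler as _GradScaler

    def make_grad_scaler(enabled: bool = True):
        # Passing "cuda" selects the CUDA scaler, matching the old class's behaviour.
        return _GradScaler("cuda", enabled=enabled)

except ImportError:
    # Older PyTorch: fall back to the CUDA-specific class this commit keeps using.
    from torch.cuda.amp import GradScaler as _GradScaler

    def make_grad_scaler(enabled: bool = True):
        return _GradScaler(enabled=enabled)
```

Whichever class ends up bound, the usual AMP workflow from the PyTorch docs applies: scale the loss, call `scaler.unscale_(optimizer)` before gradient clipping (as in `_grad_clipping` above), then `scaler.step(optimizer)` and `scaler.update()`.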
