Skip to content

Commit

Permalink
chore(trainer): handle pytorch autocast deprecation
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Oct 15, 2024
1 parent 47fe7cc commit e92aad7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [
"fsspec>=2023.6.0",
"numpy>=1.24.3; python_version < '3.12'",
"numpy>=1.26.0; python_version >= '3.12'",
"packaging>=21.0",
"psutil>=5",
"soundfile>=0.12.0",
"tensorboard>=2.17.0",
Expand Down
6 changes: 5 additions & 1 deletion trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
import torch.distributed as dist
from coqpit import Coqpit
from packaging.version import Version
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -883,7 +884,10 @@ def _model_train_step(

def _get_autocast_args(self, mixed_precision: bool, precision: str):
device = "cpu"
dtype = torch.get_autocast_cpu_dtype()
if Version(torch.__version__) >= Version("2.4"):
dtype = torch.get_autocast_dtype("cpu")
else:
dtype = torch.get_autocast_cpu_dtype()
if self.use_cuda:
device = "cuda"
dtype = torch.float32
Expand Down

0 comments on commit e92aad7

Please sign in to comment.