diff --git a/pyproject.toml b/pyproject.toml index 7f58c70..32c9b60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "fsspec>=2023.6.0", "numpy>=1.24.3; python_version < '3.12'", "numpy>=1.26.0; python_version >= '3.12'", + "packaging>=21.0", "psutil>=5", "soundfile>=0.12.0", "tensorboard>=2.17.0", diff --git a/trainer/trainer.py b/trainer/trainer.py index 2300baa..8407a95 100644 --- a/trainer/trainer.py +++ b/trainer/trainer.py @@ -14,6 +14,7 @@ import torch import torch.distributed as dist from coqpit import Coqpit +from packaging.version import Version from torch import nn from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader @@ -883,7 +884,10 @@ def _model_train_step( def _get_autocast_args(self, mixed_precision: bool, precision: str): device = "cpu" - dtype = torch.get_autocast_cpu_dtype() + if Version(torch.__version__) >= Version("2.4"): + dtype = torch.get_autocast_dtype("cpu") + else: + dtype = torch.get_autocast_cpu_dtype() if self.use_cuda: device = "cuda" dtype = torch.float32