Skip to content

Commit

Permalink
Merge pull request #8 from shavit/torch-load-weights-only
Browse files Browse the repository at this point in the history
Load weights only in torch.load
  • Loading branch information
eginhard authored Sep 12, 2024
2 parents 8025277 + f796733 commit 2f13956
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ jobs:
- name: Upload coverage data
uses: actions/upload-artifact@v4
with:
include-hidden-files: true
name: coverage-data-${{ matrix.python-version }}-${{ matrix.uv-resolution }}
path: .coverage.*
if-no-files-found: ignore
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ include = ["trainer*"]

[project]
name = "coqui-tts-trainer"
version = "0.1.4"
version = "0.1.5"
description = "General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui."
readme = "README.md"
requires-python = ">=3.9, <3.13"
Expand Down
4 changes: 2 additions & 2 deletions trainer/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ def load_fsspec(
filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
mode="rb",
) as f:
return torch.load(f, map_location=map_location, **kwargs)
return torch.load(f, map_location=map_location, weights_only=True, **kwargs)
else:
with fsspec.open(str(path), "rb") as f:
return torch.load(f, map_location=map_location, **kwargs)
return torch.load(f, map_location=map_location, weights_only=True, **kwargs)


def load_checkpoint(
Expand Down

0 comments on commit 2f13956

Please sign in to comment.