-
Notifications
You must be signed in to change notification settings - Fork 0
/
checkpoint.py
171 lines (136 loc) · 4.79 KB
/
checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import json
import os
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple, Union

import torch
class CheckpointManager:
    """Persist and restore versioned model checkpoints on disk.

    A checkpoint directory holds:
      - ``version.txt``          : the current checkpoint version (an epoch number)
      - ``config.json``          : the training configuration
      - ``training_stats.json``  : one JSON object per line, appended per epoch
      - ``model_m{n}_v{v}.pt``   : model state dicts, one file per model per version
    """

    def __init__(self, path: str) -> None:
        """Create (if needed) the checkpoint directory and read its version.

        :param path: Directory in which checkpoint files are stored.
        """
        self.path = Path(path)
        # Version currently on disk (0 when no checkpoint exists yet).
        self.version = self.load_version()
        # Version that preceded the latest write; used by remove_old_version().
        self.old_version = 0
        # Ensure the checkpoint directory exists.
        self.prepare()
        # In-memory training stats (not used by the methods below).
        self.stats: Dict = {}
        # Number of separate model state dicts saved per version.
        self.n_models = 1

    def prepare(self) -> None:
        """Create the checkpoint directory, including parents, if missing."""
        self.path.mkdir(parents=True, exist_ok=True)

    def get_version_file(self, path: Optional[Path] = None) -> Path:
        """Return the path of the version file (defaults to ``self.path``)."""
        if path is None:
            path = self.path
        return path / "version.txt"

    def get_config_file(self, path: Optional[Path] = None) -> Path:
        """Return the path of the config file (defaults to ``self.path``)."""
        if path is None:
            path = self.path
        return path / "config.json"

    def get_model_file(self, version: int, path: Optional[Path] = None) -> List[Path]:
        """Return the model-file paths for *version*, one per model.

        :param version: Checkpoint version the filenames should carry.
        :param path: Base directory (defaults to ``self.path``).
        :return: List of ``self.n_models`` paths.
        """
        if path is None:
            path = self.path
        return [path / f"model_m{n}_v{version}.pt" for n in range(self.n_models)]

    def get_stats_file(self, path: Optional[Path] = None) -> Path:
        """Return the path of the training-stats file (defaults to ``self.path``)."""
        if path is None:
            path = self.path
        return path / "training_stats.json"

    def save_config(self, config: Dict) -> None:
        """Write *config* as pretty-printed JSON to the config file."""
        with self.get_config_file().open("wt") as tf:
            tf.write(json.dumps(config, indent=4))

    def load_config(self) -> str:
        """Return the raw JSON text of the config file.

        :raises FileNotFoundError: If no config has been saved yet.
        """
        with self.get_config_file().open("rt") as tf:
            return tf.read()

    def save_model(self, state_dicts: Union[Dict[str, torch.Tensor], List]) -> None:
        """Save model state dict(s) for the current version.

        :param state_dicts: A single state dict when ``self.n_models == 1``,
            otherwise a list of state dicts (one per model).
        """
        paths = self.get_model_file(self.version)
        if self.n_models == 1:
            # Single-model case: the argument IS the state dict.
            torch.save(state_dicts, paths[0])
            return
        for path, state_dict in zip(paths, state_dicts):
            torch.save(state_dict, path)

    def load_model(self) -> Union[Dict[str, torch.Tensor], List]:
        """Load the model state dict(s) for the current version.

        :return: A single state dict when ``self.n_models == 1``, otherwise a
            list of state dicts in the same order as they were saved.
        :raises FileNotFoundError: If a model file is missing.
        """
        paths = self.get_model_file(self.version)
        if self.n_models == 1:
            return torch.load(paths[0])
        return [torch.load(path) for path in paths]

    def save_version(self, version: int) -> None:
        """Durably write *version* to the version file."""
        with self.get_version_file().open("wt") as tf:
            tf.write(f"{version}\n")
            tf.flush()
            # fsync so the version record survives a crash right after the write.
            os.fsync(tf.fileno())

    def load_version(self) -> int:
        """Return the version stored on disk, or 0 if the file is absent/empty."""
        try:
            with self.get_version_file().open("rt") as tf:
                version_string = tf.read().strip()
        except FileNotFoundError:
            return 0
        return int(version_string) if version_string else 0

    def append_stats(self, stats: Dict) -> None:
        """Append *stats* as one JSON line to the training-stats file."""
        with self.get_stats_file().open("at") as tf:
            tf.write(f"{json.dumps(stats)}\n")

    def load_stats(self) -> Generator[str, None, None]:
        """Yield the training-stats file line by line (each line is one JSON object).

        :raises FileNotFoundError: If no stats have been appended yet.
        """
        with self.get_stats_file().open("rt") as tf:
            for line in tf:
                yield line

    def write_new_version(
        self,
        config: OrderedDict,
        state_dicts: Union[List, OrderedDict],
        epoch_stats: Dict,
    ) -> None:
        """Write a complete new checkpoint: config (first time), version, stats, model.

        :param config: Training configuration; saved only before the first version.
        :param state_dicts: Model state dict(s), forwarded to :meth:`save_model`.
        :param epoch_stats: Per-epoch metrics; must contain an ``"Epoch"`` key,
            which becomes the new version number.
        """
        if self.version == 0:
            self.save_config(config)
        # Advance the version; remember the previous one so it can be pruned later.
        self.old_version = self.version
        self.version = epoch_stats["Epoch"]
        self.save_version(self.version)
        # Round float metrics to 5 decimals for compact stats lines; build a
        # copy so the caller's dict is not mutated in place.
        formatted = {
            key: round(value, 5) if isinstance(value, float) else value
            for key, value in epoch_stats.items()
        }
        self.append_stats(formatted)
        self.save_model(state_dicts)

    def remove_old_version(self) -> None:
        """Delete the model files of the previous version, if they exist."""
        for path in self.get_model_file(self.old_version):
            try:
                path.unlink()
            except FileNotFoundError:
                # Already gone (e.g. first checkpoint) - nothing to remove.
                pass