-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
169 lines (151 loc) · 6.17 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import torch
import torchvision
from datetime import datetime
from dcgan import NetG, NetD
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from config import Config, GetLastParams
from torch.utils.tensorboard import SummaryWriter
from torch.nn import BCELoss
def train(config):
    """Train the DCGAN generator and discriminator described by ``config``.

    The function reads the following attributes from ``config``: ``device``,
    ``data_path``, ``batch_size``, ``num_works``, ``gen_lr``, ``disc_lr``,
    ``pre_gen_path``, ``pre_disc_path``, ``autoLoad``, ``noise_dim``,
    ``epochs`` and ``modelSave``.

    Side effects:
        * TensorBoard event files are written under ``runs/<start_time>/``.
        * State dicts are written under ``checkpoints/<start_time>/`` after
          every epoch (``last_*``) and every 50th epoch (``epoch<N>_*``).
        * When ``config.modelSave`` is truthy, the full modules are pickled
          under ``models/<start_time>/`` after every epoch.

    Args:
        config: Configuration object exposing the attributes listed above.

    Returns:
        None.
    """
    device = config.device

    # Convert images to tensors and normalise each of the three channels
    # with mean 0.5 and std 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Load the dataset from an image folder.
    dataset = datasets.ImageFolder(config.data_path, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_works,
        drop_last=True,
    )

    # Instantiate the networks.
    gen = NetG(config).to(device)
    disc = NetD(config).to(device)

    # Optimisers and loss.
    opt_gen = Adam(gen.parameters(), config.gen_lr)
    opt_disc = Adam(disc.parameters(), config.disc_lr)
    criterion = BCELoss().to(device)

    # Load pre-trained weights when the configured files exist.
    # map_location lets weights saved on one device load on another.
    if os.path.exists(config.pre_gen_path):
        gen.load_state_dict(
            torch.load(config.pre_gen_path, map_location=device)
        )
        print("生成器预训练参数加载成功!")
    if os.path.exists(config.pre_disc_path):
        disc.load_state_dict(
            torch.load(config.pre_disc_path, map_location=device)
        )
        print("判别器预训练参数加载成功!")

    # Optionally resume from the most recent checkpoint directory.
    # Uses the ``config`` argument (not a module-level global) so the
    # function also works when imported and called from other modules.
    if config.autoLoad:
        last_params_dir = GetLastParams()
        if os.path.exists(last_params_dir):
            gen.load_state_dict(
                torch.load(
                    last_params_dir + "/last_gen_params.pt",
                    map_location=device,
                )
            )
            disc.load_state_dict(
                torch.load(
                    last_params_dir + "/last_disc_params.pt",
                    map_location=device,
                )
            )
            print(f"自动加载参数:{last_params_dir}!")
        else:
            print("找不到参数! 将从零开始训练")

    # Create the output directories for this run. makedirs also creates the
    # parent folder, which a plain mkdir would fail on.
    start_time = datetime.now().strftime("%Y-%m%d-%H%M")
    checkpoint_dir = f"checkpoints/{start_time}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    model_dir = f"models/{start_time}"
    if config.modelSave:
        os.makedirs(model_dir, exist_ok=True)

    # Fixed noise so the logged fake images are comparable across epochs.
    fixed_noise = torch.randn(
        config.batch_size, config.noise_dim, 1, 1
    ).to(device)
    writer_fake = SummaryWriter(f"runs/{start_time}/imgs/fake")
    writer_real = SummaryWriter(f"runs/{start_time}/imgs/real")
    writer_disc = SummaryWriter(f"runs/{start_time}/graphs/disc")
    writer_gen = SummaryWriter(f"runs/{start_time}/graphs/gen")
    step = 0

    try:
        for epoch in range(config.epochs):
            # Per-epoch running totals.
            lossD_all = 0
            lossG_all = 0
            batch_num = 0

            for batch_idx, (real_img, _) in enumerate(dataloader):
                batch_num += 1
                real = real_img.to(device)
                batch_size = real.shape[0]

                # ---- Train the discriminator ----
                # 1. generate a batch of fake images from random noise
                # 2. score the real batch and the fake batch
                # 3. average the two BCE losses and update the discriminator
                noise = torch.randn(
                    batch_size, config.noise_dim, 1, 1
                ).to(device)
                fake = gen(noise)

                # Record the generator graph once, on the very first batch.
                if epoch == 0 and batch_idx == 0:
                    writer_gen.add_graph(gen, noise)
                    writer_gen.close()

                disc_real_labs = disc(real).view(-1)
                # Detach so the discriminator update does not back-propagate
                # into the generator; this removes the need for retain_graph.
                disc_fake_labs = disc(fake.detach()).view(-1)

                # Record the discriminator graph once, on the first batch.
                if epoch == 0 and batch_idx == 0:
                    writer_disc.add_graph(disc, real)
                    writer_disc.close()

                lossD_real = criterion(
                    disc_real_labs, torch.ones_like(disc_real_labs)
                )
                lossD_fake = criterion(
                    disc_fake_labs, torch.zeros_like(disc_fake_labs)
                )
                lossD = (lossD_real + lossD_fake) / 2
                lossD_all += lossD.item()

                disc.zero_grad()  # clear gradients from the previous step
                lossD.backward()
                opt_disc.step()

                # ---- Train the generator ----
                # 1. re-score the fake batch with the updated discriminator
                # 2. compute the loss against the "real" label
                # 3. update the generator
                output = disc(fake).view(-1)
                lossG = criterion(output, torch.ones_like(output))
                lossG_all += lossG.item()

                gen.zero_grad()
                lossG.backward()
                opt_gen.step()

                # ---- TensorBoard image logging (first batch per epoch) ----
                if batch_idx == 0:
                    with torch.no_grad():
                        # The reshape hard-codes 3x96x96 images.
                        fake = gen(fixed_noise).reshape(-1, 3, 96, 96)
                        data = real.reshape(-1, 3, 96, 96)
                        img_grid_fake = torchvision.utils.make_grid(
                            fake, normalize=True
                        )
                        img_grid_real = torchvision.utils.make_grid(
                            data, normalize=True
                        )
                        writer_fake.add_image(
                            "Fake Images", img_grid_fake, global_step=step
                        )
                        writer_real.add_image(
                            "Real Images", img_grid_real, global_step=step
                        )
                    step += 1

            # Average losses over the epoch; guard against a loader that
            # yielded no batches (e.g. dataset smaller than one batch).
            denom = max(batch_num, 1)
            lossD_all = lossD_all / denom
            lossG_all = lossG_all / denom

            # ---- Persist models and parameters ----
            if config.modelSave:
                torch.save(gen, f"{model_dir}/gen_model.pt")
                torch.save(disc, f"{model_dir}/disc_model.pt")
            if (epoch + 1) % 50 == 0:
                torch.save(
                    gen.state_dict(),
                    f"{checkpoint_dir}/epoch{epoch + 1}_gen_params.pt",
                )
                torch.save(
                    disc.state_dict(),
                    f"{checkpoint_dir}/epoch{epoch + 1}_disc_params.pt",
                )
            # Always keep the latest parameters.
            torch.save(gen.state_dict(), f"{checkpoint_dir}/last_gen_params.pt")
            torch.save(
                disc.state_dict(), f"{checkpoint_dir}/last_disc_params.pt"
            )
            print(
                f"Epoch:{epoch + 1}, LossD:{lossD_all:.4f}, LossG:{lossG_all:.4f}"
            )
    finally:
        # Flush and release the image writers even if training is interrupted.
        writer_fake.close()
        writer_real.close()
if __name__ == "__main__":
    # Script entry point: build the run configuration and start training.
    # The module-level name ``cfg`` is kept as-is.
    cfg = Config()
    train(cfg)