-
Notifications
You must be signed in to change notification settings - Fork 46
/
main.py
143 lines (123 loc) · 4.25 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
import torch.nn as nn
import torch.optim as optim
import tensorboardX
import os
import random
import numpy as np
from train import train_epoch
from torch.utils.data import DataLoader
from validation import val_epoch
from opts import parse_opts
from model import generate_model
from torch.optim import lr_scheduler
from dataset import get_training_set, get_validation_set
from mean import get_mean, get_std
from spatial_transforms import (
Compose, Normalize, Scale, CenterCrop, CornerCrop, MultiScaleCornerCrop,
MultiScaleRandomCrop, RandomHorizontalFlip, ToTensor)
from temporal_transforms import LoopPadding, TemporalRandomCrop
from target_transforms import ClassLabel, VideoID
from target_transforms import Compose as TargetCompose
def resume_model(opt, model, optimizer):
    """Restore model and optimizer state from the checkpoint at ``opt.resume_path``.

    Args:
        opt: parsed options namespace; only ``opt.resume_path`` is read here.
        model: module whose weights are loaded (in place) from the checkpoint.
        optimizer: optimizer whose state is loaded (in place) from the checkpoint.

    Returns:
        int: the epoch to resume training from (checkpoint epoch + 1).
    """
    # map_location='cpu' lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine; load_state_dict then copies tensors onto whatever
    # device the model/optimizer parameters already live on.
    checkpoint = torch.load(opt.resume_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print("Model Restored from Epoch {}".format(checkpoint['epoch']))
    return checkpoint['epoch'] + 1
def get_loaders(opt, clip_length=16):
    """Build the train and validation DataLoaders.

    Args:
        opt: parsed options; reads ``norm_value``, ``mean_dataset``,
            ``no_mean_norm``, ``std_norm``, ``sample_size``, ``batch_size``
            and ``num_workers``, and writes ``opt.mean`` as a side effect.
        clip_length: number of frames per sampled clip. Defaults to 16,
            the value previously hard-coded in both temporal transforms.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    opt.mean = get_mean(opt.norm_value, dataset=opt.mean_dataset)
    # Pick the normalisation scheme: identity, mean-only, or mean+std.
    if opt.no_mean_norm and not opt.std_norm:
        norm_method = Normalize([0, 0, 0], [1, 1, 1])
    elif not opt.std_norm:
        norm_method = Normalize(opt.mean, [1, 1, 1])
    else:
        norm_method = Normalize(opt.mean, opt.std)

    # --- training loader: random temporal crop, shuffled batches ---
    spatial_transform = Compose([
        Scale((opt.sample_size, opt.sample_size)),
        ToTensor(opt.norm_value), norm_method
    ])
    temporal_transform = TemporalRandomCrop(clip_length)
    target_transform = ClassLabel()
    training_data = get_training_set(opt, spatial_transform,
                                     temporal_transform, target_transform)
    train_loader = DataLoader(
        training_data,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        pin_memory=True)

    # --- validation loader: deterministic (loop-padded clips, no shuffle) ---
    spatial_transform = Compose([
        Scale((opt.sample_size, opt.sample_size)),
        ToTensor(opt.norm_value), norm_method
    ])
    target_transform = ClassLabel()
    temporal_transform = LoopPadding(clip_length)
    validation_data = get_validation_set(
        opt, spatial_transform, temporal_transform, target_transform)
    val_loader = DataLoader(
        validation_data,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.num_workers,
        pin_memory=True)
    return train_loader, val_loader
def main_worker():
    """Entry point: parse options, build model and loaders, then train/validate.

    Side effects: writes TensorBoard logs to ``tf_logs/`` and checkpoints to
    ``snapshots/`` every ``opt.save_interval`` epochs.
    """
    opt = parse_opts()
    print(opt)

    # Fix random seeds for reproducibility.
    seed = 1
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # CUDA for PyTorch
    device = torch.device(f"cuda:{opt.gpu}" if opt.use_cuda else "cpu")

    # TensorBoard writer
    summary_writer = tensorboardX.SummaryWriter(log_dir='tf_logs')

    # Model and data
    model = generate_model(opt, device)
    train_loader, val_loader = get_loaders(opt)

    # Optimizer and loss
    optimizer = torch.optim.Adam(
        model.parameters(), lr=opt.lr_rate, weight_decay=opt.weight_decay)
    # scheduler = lr_scheduler.ReduceLROnPlateau(
    #     optimizer, 'min', patience=opt.lr_patience)
    criterion = nn.CrossEntropyLoss()

    # Resume from a checkpoint if one was given.
    if opt.resume_path:
        start_epoch = resume_model(opt, model, optimizer)
    else:
        start_epoch = 1

    # Make sure the checkpoint directory exists before the first torch.save,
    # which would otherwise raise FileNotFoundError.
    os.makedirs('snapshots', exist_ok=True)

    for epoch in range(start_epoch, opt.n_epochs + 1):
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, epoch, opt.log_interval, device)
        val_loss, val_acc = val_epoch(model, val_loader, criterion, device)

        # Log every epoch. Previously these writes were nested inside the
        # save-interval check, leaving gaps in the TensorBoard curves.
        summary_writer.add_scalar(
            'losses/train_loss', train_loss, global_step=epoch)
        summary_writer.add_scalar(
            'losses/val_loss', val_loss, global_step=epoch)
        summary_writer.add_scalar(
            'acc/train_acc', train_acc * 100, global_step=epoch)
        summary_writer.add_scalar(
            'acc/val_acc', val_acc * 100, global_step=epoch)

        # Save a checkpoint every save_interval epochs.
        if epoch % opt.save_interval == 0:
            state = {'epoch': epoch,
                     'state_dict': model.state_dict(),
                     'optimizer_state_dict': optimizer.state_dict()}
            torch.save(state, os.path.join(
                'snapshots', f'{opt.model}-Epoch-{epoch}-Loss-{val_loss}.pth'))
            print("Epoch {} model saved!\n".format(epoch))
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main_worker()