-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
98 lines (77 loc) · 2.46 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import json
import os
from os.path import join
import argparse
import keras
import keras.backend as K
from dotmap import DotMap
from keras.callbacks import TensorBoard, ModelCheckpoint, LearningRateScheduler
# import sys
# sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from HoverNet import *
from loss import *
from datasets import Kumar
def scheduler(epoch, lr):
    """Step-decay learning-rate schedule for Keras' LearningRateScheduler.

    Multiplies the learning rate by 0.1 once every 60 epochs (epoch 60,
    120, ...); otherwise leaves it unchanged.

    Args:
        epoch: current epoch index (0-based).
        lr: learning rate carried over from the previous epoch.

    Returns:
        The learning rate to use for this epoch.
    """
    decay_due = epoch > 0 and epoch % 60 == 0
    return lr * 0.1 if decay_due else lr
if __name__ == '__main__':
    # Parse training hyper-parameters from the command line. All flags are
    # required: previously a missing --gid/--lr/... surfaced later as an
    # opaque TypeError instead of a clear argparse error.
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_id', required=True,
                        help='experiment id, used in the log directory name')
    parser.add_argument('--gid', required=True, help='gpu id')
    parser.add_argument('--ds', required=True, help='dataset name')
    parser.add_argument('--lr', required=True, help='learning rate')
    parser.add_argument('--eps', required=True, help='epochs')
    parser.add_argument('--bs', required=True, help='batch size')
    args = parser.parse_args()

    with open('configs.json', 'r') as f:
        config = json.load(f)
    config = DotMap(config)

    # Restrict TensorFlow/Keras to the requested GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gid

    # e.g. logs/kumar/v1.0 — create it up front so ModelCheckpoint and
    # TensorBoard do not fail on a missing directory at the first save.
    ckps_dir = join('logs', args.ds, args.exp_id)
    os.makedirs(ckps_dir, exist_ok=True)

    callbacks = [
        # Save weights after every epoch, tagged with epoch and val loss.
        ModelCheckpoint(
            filepath=join(ckps_dir, "model_{epoch:02d}_{val_loss:.4f}.hdf5"),
            monitor='val_loss',
            save_best_only=False,
            save_weights_only=True),
        TensorBoard(log_dir=ckps_dir,
                    write_graph=False),
        # LearningRateScheduler(scheduler),  # enable for x0.1 decay every 60 epochs
    ]

    input_size = config.model.input_size
    input_chnl = config.model.input_chnl
    lr = float(args.lr)
    eps = int(args.eps)

    # Build and compile the network: the 'np' head is trained with
    # categorical cross-entropy, the 'hv' head with the gmse(5) loss.
    net = hvnet((input_size, input_size), input_chnl)
    net.compile(loss={'np': cce, 'hv': gmse(5)},
                optimizer=Adam(lr),
                )  # metrics={"np":[cce, soft_dice], "hv":[mse, gmse]}

    # Point the data loaders at <data_dir>/<dataset> with the requested
    # batch size.
    config.data_dir = join(config.data_dir, args.ds)
    config.train.batch_size = int(args.bs)
    train_gen = Kumar(config, 'train')
    valid_gen = Kumar(config, 'valid')

    # NOTE(review): fit_generator is deprecated in modern Keras (use fit);
    # kept here for compatibility with the Keras version this repo pins.
    net.fit_generator(train_gen,
                      steps_per_epoch=len(train_gen),
                      epochs=eps,
                      validation_data=valid_gen,
                      validation_steps=len(valid_gen),
                      callbacks=callbacks)