-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
104 lines (81 loc) · 4.21 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#! /usr/bin/env python
import argparse
import os
import numpy as np
from preprocessing import parse_annotation
from frontend import YOLO
import json
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3" #4 GPU
argparser = argparse.ArgumentParser(
description='Train and validate YOLO_v2 model on any dataset')
argparser.add_argument(
'-c',
'--conf',
help='path to configuration file')
def _main_(args):
config_path = args.conf
with open(config_path) as config_buffer:
config = json.loads(config_buffer.read())
os.environ["CUDA_VISIBLE_DEVICES"]=config['env']['gpu']
gpus = max(1, len(config['env']['gpu'].split(",")))
#print("{}").format(config)
###############################
# Parse the annotations
###############################
# parse annotations of the training set
train_imgs, train_labels = parse_annotation(config['train']['train_annot_folder'],
config['train']['train_image_folder'],
config['model']['labels'], config['train']['image_ext_name'], config['train']['image_prefix'])
# parse annotations of the validation set, if any, otherwise split the training set
if os.path.exists(config['valid']['valid_annot_folder']):
valid_imgs, valid_labels = parse_annotation(config['valid']['valid_annot_folder'],
config['valid']['valid_image_folder'],
config['model']['labels'], config['train']['image_ext_name'], config['train']['image_prefix'])
else:
train_valid_split = int(0.8*len(train_imgs))
np.random.shuffle(train_imgs)
valid_imgs = train_imgs[train_valid_split:]
train_imgs = train_imgs[:train_valid_split]
overlap_labels = set(config['model']['labels']).intersection(set(train_labels.keys()))
print ('Seen labels:{}'.format(train_labels))
print ('Given labels:{}'.format(config['model']['labels']))
print ('Overlap labels:{}'.format(overlap_labels))
if len(overlap_labels) < len(config['model']['labels']):
print ('Some labels have no images! Please revise the list of labels in the config.json file!')
return
###############################
# Construct the model
###############################
yolo = YOLO(architecture = config['model']['architecture'],
input_size = config['model']['input_size'],
labels = config['model']['labels'],
max_box_per_image = config['model']['max_box_per_image'],
anchors = config['model']['anchors']
gpus = gpus)
###############################
# Load the pretrained weights (if any)
###############################
if os.path.exists(config['train']['pretrained_weights']):
print ("Loading pre-trained weights in {}".format(config['train']['pretrained_weights']))
yolo.load_weights(config['train']['pretrained_weights'])
###############################
# Start the training process
###############################
yolo.train(train_imgs = train_imgs,
valid_imgs = valid_imgs,
train_times = config['train']['train_times'],
valid_times = config['valid']['valid_times'],
nb_epoch = config['train']['nb_epoch'],
learning_rate = config['train']['learning_rate'],
batch_size = config['train']['batch_size'],
warmup_bs = config['train']['warmup_batches'],
object_scale = config['train']['object_scale'],
no_object_scale = config['train']['no_object_scale'],
coord_scale = config['train']['coord_scale'],
class_scale = config['train']['class_scale'],
saved_weights_name = config['train']['saved_weights_name'],
debug = config['train']['debug'])
if __name__ == '__main__':
args = argparser.parse_args()
_main_(args)