-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_attention.py
190 lines (165 loc) · 9.25 KB
/
train_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from ResNet_Attention import ResNetAttention
from dataloader import data_for_run
import numpy as np
from torch.optim import lr_scheduler
import argparse
import datetime
# ---- Hyper-parameters ----
# (the DCREASING_* spelling is kept: the training loop references these names)
BATCH_SIZE: int = 256               # batch size passed to data_for_run
EPOCH: int = 40                     # number of training epochs
LEARN_STEP: float = 0.001           # initial learning rate for Adam
FloodingDepth: float = 0.0001       # flooding regularization level b
DECREASING_LEARN_STEP: bool = True  # enable step-wise learning-rate decay
DCREASING_STEP_SIZE: int = 4        # scheduler steps between decays (StepLR step_size)
DCREASING_GAMMA: float = 0.6        # multiplicative decay factor (StepLR gamma)
L2_DECAY: float = 1e-3              # L2 penalty (Adam weight_decay)
# --- Module-level setup (executes on import, not only under __main__) ---
writer = SummaryWriter()  # TensorBoard writer used for the train/test scalars
train_dataloader, test_dataloader = data_for_run(BATCH_SIZE)

### Select the compute device and report which one was found.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("not found cuda")
else:
    device = torch.device("cuda")
    print("cuda available")
# NOTE(review): `device` is not referenced by the visible training code, which
# calls .cuda() directly -- confirm whether it is still needed.

### Instantiate the model.
# NOTE(review): the __main__ block rebinds `model` to ResNetAttention(in_channels=4),
# so this default-configured instance looks unused when run as a script; kept in
# case an importer relies on the module-level `model` -- confirm.
model = ResNetAttention()
###training
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is 0."""
    return numerator / denominator if denominator else float("nan")


def _ecc_batch_counts(outputs, labels):
    """Count eccDNA predictions for one batch.

    Class index 0 is eccDNA: a one-hot label (1, 0) has ``argmax(1) == 0``.

    Args:
        outputs: model outputs, one row per sample.
        labels: one-hot float targets, one row per sample.

    Returns:
        (hits, positives): ``positives`` is the number of samples whose true
        class is 0; ``hits`` is how many of those ``outputs.argmax(1)`` also
        assigns to class 0.
    """
    is_ecc = labels.argmax(1) == 0
    predicted_ecc = outputs.argmax(1) == 0
    hits = int((is_ecc & predicted_ecc).sum().item())
    positives = int(is_ecc.sum().item())
    return hits, positives


def _roc_batch_counts(outputs, labels, roc_smooth):
    """Per-threshold confusion counts for one batch (ROC curve points).

    A sample's score is its softmax probability for class 0 (eccDNA). For each
    threshold ``k / roc_smooth`` (k = 0 .. roc_smooth-1) a sample is called
    eccDNA when ``score >= threshold``.

    Returns:
        Four lists of length ``roc_smooth``: TP, P, FP, N. P counts samples
        whose true class is 0, N those whose true class is 1.
    """
    # Compare in float64 so thresholds equal the Python-float k / roc_smooth.
    scores = torch.nn.functional.softmax(outputs, dim=-1)[:, 0].double()
    true_idx = labels.argmax(1)
    is_pos = true_idx == 0
    is_neg = true_idx == 1
    thresholds = torch.arange(roc_smooth, dtype=torch.float64, device=scores.device) / roc_smooth
    # (roc_smooth, batch) boolean matrix: row k marks samples passing threshold k.
    called_pos = scores.unsqueeze(0) >= thresholds.unsqueeze(1)
    tp = (called_pos & is_pos).sum(dim=1).tolist()
    fp = (called_pos & is_neg).sum(dim=1).tolist()
    n_pos = int(is_pos.sum().item())
    n_neg = int(is_neg.sum().item())
    return tp, [n_pos] * roc_smooth, fp, [n_neg] * roc_smooth


def _evaluate(model, loss_method, data, roc_smooth):
    """Evaluate ``model`` on ``data`` without gradient tracking.

    Returns:
        dict with keys
          'loss'                 sum of the per-batch values of ``loss_method``
          'correct'              samples with outputs.argmax(1) == labels.argmax(1)
          'samples'              number of samples seen
          'ecc_hits', 'ecc_total'  totals of ``_ecc_batch_counts``
          'tp', 'p', 'fp', 'n'   per-threshold ROC counts (length ``roc_smooth``)
    """
    model.eval()
    result = {
        'loss': 0.0, 'correct': 0, 'samples': 0, 'ecc_hits': 0, 'ecc_total': 0,
        'tp': [0] * roc_smooth, 'p': [0] * roc_smooth,
        'fp': [0] * roc_smooth, 'n': [0] * roc_smooth,
    }
    with torch.no_grad():
        for test_DNAs, test_labels in data:
            # Float targets: CrossEntropyLoss is given per-class target vectors.
            test_labels = torch.Tensor(test_labels).float()
            if torch.cuda.is_available():
                test_DNAs = test_DNAs.cuda()
                test_labels = test_labels.cuda()
            outputs = model(test_DNAs)
            result['loss'] += loss_method(outputs, test_labels).item()
            result['correct'] += int((outputs.argmax(1) == test_labels.argmax(1)).sum().item())
            # Count real samples: the last batch may be smaller than the batch size.
            result['samples'] += test_labels.size(0)
            hits, positives = _ecc_batch_counts(outputs, test_labels)
            result['ecc_hits'] += hits
            result['ecc_total'] += positives
            batch_counts = _roc_batch_counts(outputs, test_labels, roc_smooth)
            for key, counts in zip(('tp', 'p', 'fp', 'n'), batch_counts):
                result[key] = [a + b for a, b in zip(result[key], counts)]
    return result


def _write_roc(path, tp, p, fp, n):
    """Write one "TPR FPR" line per threshold, highest threshold first.

    A rate whose denominator is 0 (no positives / no negatives in the test
    set) is written as ``nan`` instead of raising ZeroDivisionError.
    """
    lines = []
    for k in reversed(range(len(tp))):
        lines.append("{} {}\n".format(_safe_ratio(tp[k], p[k]), _safe_ratio(fp[k], n[k])))
    with open(path, 'w') as roc_file:
        roc_file.writelines(lines)


if __name__ == '__main__':
    import os  # only the script entry point needs it

    parser = argparse.ArgumentParser(description = 'test')
    parser.add_argument('--o', default = "./save", type=str, help='output dir.')
    args = parser.parse_args()
    # Create output folders up front so the first end-of-epoch save cannot fail.
    os.makedirs(args.o, exist_ok=True)
    os.makedirs('./ROC', exist_ok=True)

    model = ResNetAttention(in_channels=4)
    loss_method = torch.nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        model = model.cuda()
        loss_method = loss_method.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_STEP, weight_decay=L2_DECAY)
    scheduler = None
    if DECREASING_LEARN_STEP:  # step-decay the learning rate once per epoch
        scheduler = lr_scheduler.StepLR(optimizer, step_size=DCREASING_STEP_SIZE, gamma=DCREASING_GAMMA)

    ROC_SMOOTH = 50  # number of ROC thresholds: k / ROC_SMOOTH, k = 0 .. ROC_SMOOTH-1
    steps_per_epoch = len(train_dataloader)
    for i in range(EPOCH):
        print("-------epoch {}".format(i+1))
        model.train()
        for step, [DNAs, labels] in enumerate(train_dataloader):
            labels = torch.Tensor(labels).float()
            if torch.cuda.is_available():
                DNAs = DNAs.cuda()
                labels = labels.cuda()
            outputs = model(DNAs)
            loss = loss_method(outputs, labels)
            optimizer.zero_grad()
            # Flooding: |loss - b| + b; when loss < b the gradient sign flips,
            # pushing the training loss back up toward b.
            flood = (loss - FloodingDepth).abs() + FloodingDepth
            flood.backward()
            optimizer.step()
            train_step = steps_per_epoch * i + step + 1
            if train_step % 100 == 0:
                now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                print("train time:{}, Loss: {}".format(train_step, loss.item()), now)
                writer.add_scalar("train_loss", loss.item(), train_step)

        # Test phase
        metrics = _evaluate(model, loss_method, test_dataloader, ROC_SMOOTH)
        test_accuracy = _safe_ratio(metrics['correct'], metrics['samples'])
        ecc_accuracy = _safe_ratio(metrics['ecc_hits'], metrics['ecc_total'])
        print("test set Loss: {}".format(metrics['loss']))
        print("test set accuracy: {}".format(test_accuracy))
        print("test accuracy of eccDNA: {}".format(ecc_accuracy))
        writer.add_scalar("test_loss", metrics['loss'], i)
        writer.add_scalar("test_accuracy", test_accuracy, i)
        writer.add_scalar("test accuracy of eccDNA", ecc_accuracy, i)

        # ROC points for this epoch
        print('draw ROC')
        _write_roc('./ROC/ROC_result_'+str(i+1)+'.txt',
                   metrics['tp'], metrics['p'], metrics['fp'], metrics['n'])

        if scheduler is not None:
            scheduler.step()
        torch.save(model, "{}/module_{}.pth".format(args.o, i+1))
        torch.save(model.state_dict(), "{}/module_dict_{}.pth".format(args.o, i+1))
        print("saved epoch {}".format(i+1))
    writer.close()