# utils.py
import torch
import torch.nn.functional as F
import numpy as np
import logging
from PIL import Image
from losses import *
"""
Define task metrics, loss functions and model trainer here.
"""
def randrot(img):
    # Random rotation by a multiple of 90 degrees (mode 3 leaves the image unchanged).
    mode = np.random.randint(0, 4)
    return rot(img, mode)

def randfilp(img):
    # Random vertical or horizontal flip (mode 2 leaves the image unchanged).
    mode = np.random.randint(0, 3)
    return flip(img, mode)

def rot(img, rot_mode):
    # Rotate by 90 (mode 0), 180 (mode 1) or 270 (mode 2) degrees on the last two dims.
    if rot_mode == 0:
        img = img.transpose(-2, -1)
        img = img.flip(-2)
    elif rot_mode == 1:
        img = img.flip(-2)
        img = img.flip(-1)
    elif rot_mode == 2:
        img = img.flip(-2)
        img = img.transpose(-2, -1)
    return img

def flip(img, flip_mode):
    # Flip along the second-to-last dim (mode 0) or the last dim (mode 1).
    if flip_mode == 0:
        img = img.flip(-2)
    elif flip_mode == 1:
        img = img.flip(-1)
    return img
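# Example usage (a minimal sketch; the tensors below are hypothetical): one way to keep paired
# modalities aligned is to concatenate them along the channel dim and augment them together,
# since rot/flip only touch the last two dims.
#   vis, ir = torch.rand(1, 3, 256, 256), torch.rand(1, 1, 256, 256)
#   stacked = randrot(randfilp(torch.cat([vis, ir], dim=1)))
#   vis_aug, ir_aug = stacked[:, :3], stacked[:, 3:]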
def RGB2YCrCb(rgb_image):
    """
    Convert an RGB image to the YCrCb colour space.
    Used for colour-space conversion of intermediate results; rgb_image is expected to
    have shape [B, C, H, W] with values in [0, 1].
    :param rgb_image: image tensor in RGB format
    :return: Y, Cb, Cr
    """
    R = rgb_image[:, 0:1]
    G = rgb_image[:, 1:2]
    B = rgb_image[:, 2:3]
    Y = 0.299 * R + 0.587 * G + 0.114 * B
    Cr = (R - Y) * 0.713 + 0.5
    Cb = (B - Y) * 0.564 + 0.5
    Y = Y.clamp(0.0, 1.0)
    Cr = Cr.clamp(0.0, 1.0).detach()
    Cb = Cb.clamp(0.0, 1.0).detach()
    return Y, Cb, Cr
def YCbCr2RGB(Y, Cb, Cr):
    """
    Convert YCbCr channels back to an RGB image.
    :param Y: luma channel, shape [B, 1, H, W]
    :param Cb: blue-difference chroma channel
    :param Cr: red-difference chroma channel
    :return: RGB image tensor, shape [B, 3, H, W]
    """
    ycrcb = torch.cat([Y, Cr, Cb], dim=1)
    B, C, H, W = ycrcb.shape
    im_flat = ycrcb.transpose(1, 3).transpose(1, 2).reshape(-1, 3)
    mat = torch.tensor([[1.0, 1.0, 1.0], [1.403, -0.714, 0.0], [0.0, -0.344, 1.773]]).to(Y.device)
    bias = torch.tensor([0.0, -0.5, -0.5]).to(Y.device)
    temp = (im_flat + bias).mm(mat)
    out = temp.reshape(B, H, W, C).transpose(1, 3).transpose(2, 3)
    out = out.clamp(0, 1.0)
    return out
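# Sanity check (a minimal sketch, not part of the original API): RGB2YCrCb and YCbCr2RGB
# should approximately round-trip for inputs in [0, 1].
#   rgb = torch.rand(2, 3, 64, 64)
#   y, cb, cr = RGB2YCrCb(rgb)
#   print((YCbCr2RGB(y, cb, cr) - rgb).abs().max())   # small, but not exactly zero (clamping, rounded coefficients)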
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def logger_config(log_path, logging_name):
    '''
    Configure logging.
    :param log_path: path of the output log file
    :param logging_name: name of the logger (any string)
    :return: configured logger
    '''
    # logger is the log object, handler writes to the log file, and console mirrors output
    # to the terminal (without console, messages only appear in the log file).
    # Get a named logger.
    logger = logging.getLogger(logging_name)
    # First-level filter: pass DEBUG and above on to the handlers.
    logger.setLevel(level=logging.DEBUG)
    # File handler with its own level (second-level filter).
    handler = logging.FileHandler(log_path, encoding='UTF-8')
    handler.setLevel(logging.INFO)
    # Create and set the record format for the file handler.
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    # Console (stream) handler with its own level (second-level filter).
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    # Attach both handlers to the logger.
    logger.addHandler(handler)
    logger.addHandler(console)
    return logger
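# Example usage (a minimal sketch; the file name 'train.log' is hypothetical):
#   logger = logger_config('train.log', 'trainer')
#   logger.info('epoch 1 done')     # written to train.log and echoed to the console
#   logger.debug('batch details')   # console only, since the file handler level is INFO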
def model_fit(x_pred, x_output, task_type):
    device = x_pred.device
    # Binary mask to exclude undefined (all-zero) pixels from the loss.
    binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(device)
    if task_type == 'semantic':
        # semantic loss: pixel-wise cross entropy on log-probabilities
        loss = F.nll_loss(x_pred, x_output, ignore_index=-1)
    elif task_type == 'depth':
        # depth loss: L1 norm over defined pixels
        loss = torch.sum(torch.abs(x_pred - x_output) * binary_mask) / torch.nonzero(binary_mask, as_tuple=False).size(0)
    elif task_type == 'normal':
        # normal loss: 1 - mean dot product over defined pixels
        loss = 1 - torch.sum((x_pred * x_output) * binary_mask) / torch.nonzero(binary_mask, as_tuple=False).size(0)
    else:
        raise ValueError('Unknown task type: {}'.format(task_type))
    return loss
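# Example usage (a minimal sketch with hypothetical shapes): for 'semantic', x_pred is expected
# to contain log-probabilities (e.g. F.log_softmax output) and x_output integer class labels.
#   logits = torch.randn(4, 9, 32, 32)
#   labels = torch.randint(0, 9, (4, 32, 32))
#   loss = model_fit(F.log_softmax(logits, dim=1), labels, 'semantic')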
# New mIoU and Acc. formula: accumulate every pixel and average across all pixels in all images
class ConfMatrix(object):
def __init__(self, num_classes):
self.num_classes = num_classes
self.mat = None
def update(self, pred, target):
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device)
with torch.no_grad():
k = (target >= 0) & (target < n)
inds = n * target[k].to(torch.int64) + pred[k]
self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
def get_metrics(self):
h = self.mat.float()
acc = torch.diag(h).sum() / h.sum()
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
return torch.mean(iu).item(), acc.item()
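# Example usage (a minimal sketch; pred/target are hypothetical tensors): accumulate batch by
# batch, then read mIoU and overall accuracy at the end of validation.
#   conf = ConfMatrix(num_classes=9)
#   conf.update(pred.argmax(1).flatten(), target.flatten())
#   miou, acc = conf.get_metrics()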
def get_palette():
unlabelled = [0, 0, 0]
car = [64, 0, 128]
person = [64, 64, 0]
bike = [0, 128, 192]
curve = [0, 0, 192]
car_stop = [128, 128, 0]
guardrail = [64, 64, 128]
color_cone = [192, 128, 128]
bump = [192, 64, 0]
palette = np.array(
[
unlabelled,
car,
person,
bike,
curve,
car_stop,
guardrail,
color_cone,
bump,
]
)
return palette
def seg_visualize(predictions, save_name):
    palette = get_palette()
    pred = predictions[0].data.cpu().numpy()
    img = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
    # +1 so the highest predicted class id is also coloured (range is right-exclusive).
    for cid in range(1, int(predictions.max()) + 1):
        img[pred == cid] = palette[cid]
    img = Image.fromarray(np.uint8(img))
    img.save(save_name)
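# Example usage (a minimal sketch; the output path is hypothetical):
#   seg_map = torch.randint(0, 9, (1, 480, 640))   # one predicted label map, 9 classes as in get_palette()
#   seg_visualize(seg_map, 'seg_vis.png')          # saves a colour-coded PNG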
"""
confusionMetric # 注意:此处横着代表预测值,竖着代表真实值,与之前介绍的相反
P\L P N
P TP FP
N FN TN
"""
class SegmentationMetric(object):
    def __init__(self, numClass, device):
        self.numClass = numClass
        self.device = device
        self.confusionMatrix = torch.zeros((self.numClass,) * 2).to(device)  # empty confusion matrix

    def pixelAccuracy(self):
        # Overall pixel accuracy: fraction of correctly classified pixels.
        # PA = acc = (TP + TN) / (TP + TN + FP + FN)
        acc = torch.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
        # acc = acc.item()
        return acc
    def classPixelAccuracy(self):
        # Per-class pixel accuracy (more precisely, this is precision).
        # acc = TP / (TP + FP)
        classAcc = torch.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1)
        # classAcc = classAcc.item()
        return classAcc  # per-class values, e.g. [0.90, 0.80, 0.96] for classes 1, 2, 3

    def meanPixelAccuracy(self):
        """
        Mean Pixel Accuracy (MPA): a simple refinement of PA. Compute the fraction of
        correctly classified pixels per class, then average over classes.
        :return: mean per-class accuracy
        """
        classAcc = self.classPixelAccuracy()
        meanAcc = classAcc[classAcc < float('inf')].mean()  # like np.nanmean: classes with no pixels are excluded
        # meanAcc = meanAcc.item()
        return meanAcc  # a single value, e.g. nanmean([0.90, 0.80, 0.96, nan, nan]) = (0.90 + 0.80 + 0.96) / 3 ≈ 0.89
    def IntersectionOverUnion(self):
        # Intersection = TP, Union = TP + FP + FN
        # IoU = TP / (TP + FP + FN)
        intersection = torch.diag(self.confusionMatrix)  # diagonal elements, one per class
        union = torch.sum(self.confusionMatrix, axis=1) + torch.sum(self.confusionMatrix, axis=0) - torch.diag(
            self.confusionMatrix)  # axis=1 sums over rows, axis=0 over columns
        IoU = intersection / union  # per-class IoU
        # IoU = [a.item() for a in IoU]
        return IoU

    def meanIntersectionOverUnion(self):
        IoU = self.IntersectionOverUnion()
        mIoU = IoU[IoU < float('inf')].mean()  # average IoU over classes
        # mIoU = mIoU.item()
        return mIoU
    def genConfusionMatrix(self, imgPredict, imgLabel, ignore_labels):
        """
        Compute the confusion matrix, like the fast_hist() function in FCN's score.py.
        :param imgPredict: predicted label map
        :param imgLabel: ground-truth label map
        :return: confusion matrix
        """
        # Remove unlabeled pixels from both the ground truth and the prediction.
        mask = (imgLabel >= 0) & (imgLabel < self.numClass)
        for IgLabel in ignore_labels:
            mask &= (imgLabel != IgLabel)
        label = self.numClass * imgLabel[mask] + imgPredict[mask]
        count = torch.bincount(label, minlength=self.numClass ** 2)
        confusionMatrix = count.view(self.numClass, self.numClass)
        # print(confusionMatrix)
        return confusionMatrix
    def Frequency_Weighted_Intersection_over_Union(self):
        """
        FWIoU (frequency-weighted IoU): a refinement of mIoU that weights each class
        by how frequently it occurs.
        FWIoU = [(TP + FN) / (TP + FP + TN + FN)] * [TP / (TP + FP + FN)]
        """
        freq = torch.sum(self.confusionMatrix, axis=1) / torch.sum(self.confusionMatrix)
        iu = torch.diag(self.confusionMatrix) / (
                torch.sum(self.confusionMatrix, axis=1) + torch.sum(self.confusionMatrix, axis=0) -
                torch.diag(self.confusionMatrix))
        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        # FWIoU = FWIoU.item()
        return FWIoU
    def addBatch(self, imgPredict, imgLabel, ignore_labels):
        assert imgPredict.shape == imgLabel.shape
        with torch.no_grad():
            self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel, ignore_labels)  # accumulate the confusion matrix
        return self.confusionMatrix

    def reset(self):
        self.confusionMatrix = torch.zeros((self.numClass, self.numClass)).to(self.device)
# Quick self-test
if __name__ == '__main__':
    imgPredict = torch.tensor([[0, 1, 2], [2, 1, 1]]).long()  # can be replaced by a real prediction map
    imgLabel = torch.tensor([[0, 1, 255], [1, 1, 2]]).long()  # can be replaced by a real label map
    ignore_labels = [255]
    metric = SegmentationMetric(3, imgPredict.device)  # 3 classes in total; class 0 also counts as a class
    hist = metric.addBatch(imgPredict, imgLabel, ignore_labels)
    pa = metric.pixelAccuracy()
    cpa = metric.classPixelAccuracy()
    mpa = metric.meanPixelAccuracy()
    IoU = metric.IntersectionOverUnion()
    mIoU = metric.meanIntersectionOverUnion()
    print('hist is :\n', hist)
    print('PA is : %f' % pa)
    print('cPA is :', cpa)  # per-class values
    print('mPA is : %f' % mpa)
    print('IoU is : ', IoU)
    print('mIoU is : ', mIoU)