-
Notifications
You must be signed in to change notification settings - Fork 14
/
CrossEntropyLoss2d.py
38 lines (31 loc) · 1.73 KB
/
CrossEntropyLoss2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Code obtained from https://discuss.pytorch.org/t/about-segmentation-loss-function/2906/6
# Comments prefixed with "ED" are mine.
import torch.nn.functional as F
import torch.nn as nn
class CrossEntropy2d(nn.Module):
    """Per-pixel cross-entropy loss for dense (segmentation-style) prediction.

    Pixels whose label is negative or equal to ``ignore_label`` are dropped
    before the loss is computed, so they contribute neither to the loss value
    nor to the gradient.
    """

    def __init__(self, size_average=True, ignore_label=255):
        """
        Args:
            size_average: if True the loss is averaged over the kept pixels,
                otherwise it is summed over them.
            ignore_label: label value that marks pixels to exclude.
        """
        super(CrossEntropy2d, self).__init__()
        self.size_average = size_average
        self.ignore_label = ignore_label

    def forward(self, predict, target, weight=None):
        """
        Args:
            predict:(n, c, h, w)
            target:(n, h, w)
            weight (Tensor, optional): a manual rescaling weight given to each class.
                                       If given, has to be a Tensor of size "nclasses"

        Returns:
            Scalar loss tensor. When no pixel survives the mask the result is
            a zero that is still connected to ``predict``'s graph, so calling
            ``backward()`` on it is safe.

        Raises:
            AssertionError: if ``target`` requires grad, or the shapes of
                ``predict`` and ``target`` are incompatible.
        """
        assert not target.requires_grad
        # ED: Assertions about size compatibility between prediction and ground truth
        assert predict.dim() == 4
        assert target.dim() == 3
        assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
        assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
        # target is 3-D, so its width lives in dim 2 (dim 3 does not exist).
        assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2))

        # ED: mask selecting every labelled pixel except the ignored ones
        target_mask = (target >= 0) & (target != self.ignore_label)
        target = target[target_mask]  # ED: applies the mask to ground truth

        # Move channels last so the (n, h, w) mask picks whole per-pixel
        # class-score rows, yielding a (num_valid_pixels, c) tensor.
        predict = predict.permute(0, 2, 3, 1)[target_mask]

        if target.numel() == 0:
            # Nothing to supervise: averaging over zero elements would give
            # NaN, so return a zero that keeps the autograd graph intact.
            return predict.sum() * 0.0

        # `reduction` replaces the deprecated `size_average` keyword.
        reduction = "mean" if self.size_average else "sum"
        loss = F.cross_entropy(predict, target, weight=weight, reduction=reduction)
        return loss