-- EPECriterion.lua
require 'nn'
require 'cunn'

-- End-point-error (EPE) criterion for 2-channel flow fields: the per-pixel
-- L2 norm of (input - target) is computed and then penalised against zero
-- with a wrapped scalar criterion (SmoothL1, MSE, or Abs).
local EPECriterion, parent = torch.class('nn.EPECriterion', 'nn.Criterion')

function EPECriterion:__init(scales, criterion)
   parent.__init(self)
   self.EPE = torch.Tensor()
   -- Select the criterion applied to the EPE map (scales is accepted but unused here).
   if criterion == 'SmoothL1' then
      self.criterion = nn.SmoothL1Criterion():cuda()
   elseif criterion == 'MSE' then
      self.criterion = nn.MSECriterion():cuda()
   else -- default: AbsCriterion
      self.criterion = nn.AbsCriterion():cuda()
   end
end
function EPECriterion:updateOutput(input, target)
   local diffMap = input - target
   assert(input:nDimension() == 4 or input:nDimension() == 3)
   -- L2 norm over the flow-channel dimension: dim 2 for batched input
   -- (B x 2 x H x W), dim 1 for a single sample (2 x H x W).
   if input:nDimension() == 4 then
      self.EPE = diffMap:norm(2, 2)
   else
      self.EPE = diffMap:norm(2, 1)
   end
   -- Compare the EPE map against an all-zero target of the same size.
   self.zeroEPE = torch.zeros(self.EPE:size()):cuda()
   self.output = self.criterion:forward(self.EPE, self.zeroEPE)
   return self.output
end
function EPECriterion:updateGradInput(input, target)
   self.gradInput = input - target
   -- Chain rule: d(EPE)/d(diff_c) = diff_c / EPE, so each channel of the
   -- difference map is scaled by (dLoss/dEPE) / EPE.
   local gradOutput = torch.cdiv(self.criterion:backward(self.EPE, self.zeroEPE), self.EPE)
   assert(self.gradInput:nDimension() == 4 or self.gradInput:nDimension() == 3)
   if self.gradInput:nDimension() == 4 then
      self.gradInput[{{}, 1}]:cmul(gradOutput)
      self.gradInput[{{}, 2}]:cmul(gradOutput)
   else
      self.gradInput[1]:cmul(gradOutput)
      self.gradInput[2]:cmul(gradOutput)
   end
   return self.gradInput
end
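
-- A minimal usage sketch (not part of the module): assuming a CUDA setup with
-- cunn available and predicted/ground-truth flow tensors of shape
-- (batch x 2 x H x W). The tensor names and shapes below are illustrative only.
--
--   local crit = nn.EPECriterion(nil, 'SmoothL1')
--   local pred = torch.CudaTensor(8, 2, 64, 64):uniform()
--   local gt   = torch.CudaTensor(8, 2, 64, 64):uniform()
--   local loss = crit:forward(pred, gt)
--   local grad = crit:backward(pred, gt)   -- same shape as pred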