Skip to content

Commit

Permalink
Made code more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
laclouis5 committed Sep 2, 2019
1 parent c361951 commit 5cb10a2
Showing 1 changed file with 4 additions and 14 deletions.
18 changes: 4 additions & 14 deletions training/train_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def get_lr_multipliers(model):
return lr_mult


def get_loss_funcs():
def get_loss_funcs(nb_stages=6):
"""
Euclidean loss as implemented in caffe
https://github.com/BVLC/caffe/blob/master/src/caffe/layers/euclidean_loss_layer.cpp
Expand All @@ -129,19 +129,9 @@ def get_loss_funcs():
def _eucl_loss(x, y):
return K.sum(K.square(x - y)) / batch_size / 2

losses = {}
losses["weight_stage1_L1"] = _eucl_loss
losses["weight_stage1_L2"] = _eucl_loss
losses["weight_stage2_L1"] = _eucl_loss
losses["weight_stage2_L2"] = _eucl_loss
losses["weight_stage3_L1"] = _eucl_loss
losses["weight_stage3_L2"] = _eucl_loss
losses["weight_stage4_L1"] = _eucl_loss
losses["weight_stage4_L2"] = _eucl_loss
losses["weight_stage5_L1"] = _eucl_loss
losses["weight_stage5_L2"] = _eucl_loss
losses["weight_stage6_L1"] = _eucl_loss
losses["weight_stage6_L2"] = _eucl_loss
keys = ["weight_stage{}_L{}".format(stage+1, L+1) for stage in range(nb_stages) for L in range(2)]

losses = dict.fromkeys(keys, _eucl_loss)

return losses

Expand Down

0 comments on commit 5cb10a2

Please sign in to comment.