From 5cb10a27028ed1e70488ab2cc6489a16ba3537a3 Mon Sep 17 00:00:00 2001 From: Louis Lac Date: Mon, 2 Sep 2019 17:43:02 +0200 Subject: [PATCH] Made code more generic --- training/train_pose.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/training/train_pose.py b/training/train_pose.py index 0ef6ab7..63dd809 100644 --- a/training/train_pose.py +++ b/training/train_pose.py @@ -120,7 +120,7 @@ def get_lr_multipliers(model): return lr_mult -def get_loss_funcs(): +def get_loss_funcs(nb_stages=6): """ Euclidean loss as implemented in caffe https://github.com/BVLC/caffe/blob/master/src/caffe/layers/euclidean_loss_layer.cpp @@ -129,19 +129,9 @@ def get_loss_funcs(): def _eucl_loss(x, y): return K.sum(K.square(x - y)) / batch_size / 2 - losses = {} - losses["weight_stage1_L1"] = _eucl_loss - losses["weight_stage1_L2"] = _eucl_loss - losses["weight_stage2_L1"] = _eucl_loss - losses["weight_stage2_L2"] = _eucl_loss - losses["weight_stage3_L1"] = _eucl_loss - losses["weight_stage3_L2"] = _eucl_loss - losses["weight_stage4_L1"] = _eucl_loss - losses["weight_stage4_L2"] = _eucl_loss - losses["weight_stage5_L1"] = _eucl_loss - losses["weight_stage5_L2"] = _eucl_loss - losses["weight_stage6_L1"] = _eucl_loss - losses["weight_stage6_L2"] = _eucl_loss + keys = ["weight_stage{}_L{}".format(stage+1, L+1) for stage in range(nb_stages) for L in range(2)] + + losses = dict.fromkeys(keys, _eucl_loss) return losses