From 4724379b5551d31e7bf0c7181dbe9308a45e4933 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Tue, 1 Aug 2023 10:24:04 +0800 Subject: [PATCH] fixed: hessian with lazycall --- tf_pwa/model/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tf_pwa/model/model.py b/tf_pwa/model/model.py index b8130a92..b50fbaea 100644 --- a/tf_pwa/model/model.py +++ b/tf_pwa/model/model.py @@ -12,6 +12,7 @@ from ..data import ( EvalLazy, data_merge, + data_replace, data_shape, data_split, split_generator, @@ -740,8 +741,8 @@ def nll_grad_hessian( mc_weight = tf.convert_to_tensor( [mc_weight] * data_shape(mcdata), dtype="float64" ) - data_i = {**data, "weight": weight} - mcdata_i = {**mcdata, "weight": mc_weight} + data_i = data_replace(data, "weight", weight) + mcdata_i = data_replace(mcdata, "weight", mc_weight) return self.model.nll_grad_hessian(data_i, mcdata_i, batch=batch) def set_params(self, var):