diff --git a/tf_pwa/model/model.py b/tf_pwa/model/model.py index b8130a92..b50fbaea 100644 --- a/tf_pwa/model/model.py +++ b/tf_pwa/model/model.py @@ -12,6 +12,7 @@ from ..data import ( EvalLazy, data_merge, + data_replace, data_shape, data_split, split_generator, @@ -740,8 +741,8 @@ def nll_grad_hessian( mc_weight = tf.convert_to_tensor( [mc_weight] * data_shape(mcdata), dtype="float64" ) - data_i = {**data, "weight": weight} - mcdata_i = {**mcdata, "weight": mc_weight} + data_i = data_replace(data, "weight", weight) + mcdata_i = data_replace(mcdata, "weight", mc_weight) return self.model.nll_grad_hessian(data_i, mcdata_i, batch=batch) def set_params(self, var):