Skip to content

Commit

Permalink
fixed: hessian with lazycall
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangyi15 committed Aug 1, 2023
1 parent d88a8f5 commit 4724379
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tf_pwa/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..data import (
EvalLazy,
data_merge,
data_replace,
data_shape,
data_split,
split_generator,
Expand Down Expand Up @@ -740,8 +741,8 @@ def nll_grad_hessian(
mc_weight = tf.convert_to_tensor(
[mc_weight] * data_shape(mcdata), dtype="float64"
)
data_i = {**data, "weight": weight}
mcdata_i = {**mcdata, "weight": mc_weight}
data_i = data_replace(data, "weight", weight)
mcdata_i = data_replace(mcdata, "weight", mc_weight)
return self.model.nll_grad_hessian(data_i, mcdata_i, batch=batch)

def set_params(self, var):
Expand Down

0 comments on commit 4724379

Please sign in to comment.