Skip to content

Commit

Permalink
Merge pull request #153 from jonas-eschle/patch-1
Browse files Browse the repository at this point in the history
dtype aligns in nll
  • Loading branch information
jiangyi15 authored Oct 16, 2024
2 parents 07d25ce + da127f2 commit dc0326f
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tf_pwa/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,9 @@ def nll(self, data, mcdata):
amp_s2 = self.signal(data) * weight
amp_s2 = self.sum_resolution(amp_s2)
weight = tf.reduce_sum(rw, axis=-1)
dom_weight = tf.where(weight == 0, 1.0, weight)
dom_weight = tf.where(
weight == 0, tf.constant(1.0, dtype=weight.dtype), weight
)
ln_data = clip_log(amp_s2 / dom_weight)
mc_weight = mcdata.get("weight", tf.ones((data_shape(mcdata),)))
int_mc = tf.reduce_sum(
Expand Down

0 comments on commit dc0326f

Please sign in to comment.