diff --git a/tf_pwa/model/model.py b/tf_pwa/model/model.py index 9c5412d..c9cf04d 100644 --- a/tf_pwa/model/model.py +++ b/tf_pwa/model/model.py @@ -327,7 +327,9 @@ def nll(self, data, mcdata): amp_s2 = self.signal(data) * weight amp_s2 = self.sum_resolution(amp_s2) weight = tf.reduce_sum(rw, axis=-1) - dom_weight = tf.where(weight == 0, 1.0, weight) + dom_weight = tf.where( + weight == 0, tf.constant(1.0, dtype=weight.dtype), weight + ) ln_data = clip_log(amp_s2 / dom_weight) mc_weight = mcdata.get("weight", tf.ones((data_shape(mcdata),))) int_mc = tf.reduce_sum(