diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index cd88908a..301817e0 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -196,13 +196,19 @@ def get_phsp_noeff(self): ) return self.get_data("phsp")[0] - def get_phsp_plot(self): - if "phsp_plot" in self.config["data"]: - assert len(self.config["data"]["phsp_plot"]) == len( - self.config["data"]["phsp"] + def get_phsp_plot(self, tail=""): + if "phsp_plot" + tail in self.config["data"]: + assert len(self.config["data"]["phsp_plot" + tail]) == len( + self.config["data"]["phsp" + tail] ) - return self.get_data("phsp_plot") - return self.get_data("phsp") + return self.get_data("phsp_plot" + tail) + return self.get_data("phsp" + tail) + + def get_data_rec(self, name): + ret = self.get_data(name + "_rec") + if ret is None: + ret = self.get_data(name) + return ret def get_decay(self, full=True): if full: @@ -481,7 +487,13 @@ def _get_model(self, vm=None, name=""): ) else: model.append( - Model_cfit(amp, wb, bg_function, eff_function) + Model_cfit( + amp, + wb, + bg_function, + eff_function, + resolution_size=self.resolution_size, + ) ) elif "inmc" in self.config["data"]: float_wmc = self.config["data"].get( diff --git a/tf_pwa/config_loader/data.py b/tf_pwa/config_loader/data.py index f3b79dba..7dbb9004 100644 --- a/tf_pwa/config_loader/data.py +++ b/tf_pwa/config_loader/data.py @@ -377,7 +377,6 @@ def get_data(self, idx) -> list: elif idx != "phsp_noeff": assert self._Ngroup == len(ret), "not the same data group" bg_value = self.dic.get(idx + "_bg_value", None) - print(len(ret), files, weights, charge) if bg_value is not None: if isinstance(bg_value, str): bg_value = [bg_value] diff --git a/tf_pwa/config_loader/plot.py b/tf_pwa/config_loader/plot.py index 7263c65f..43e1c39b 100644 --- a/tf_pwa/config_loader/plot.py +++ b/tf_pwa/config_loader/plot.py @@ -227,6 +227,8 @@ def plot_partial_wave( res=None, save_root=False, chains_id_method=None, + phsp_rec=None, + cut_function=lambda x: 1, **kwargs ): """ @@ -265,9 +267,11 @@ def plot_partial_wave( os.makedirs(path, exist_ok=True) if data is None: - data = self.get_data("data") - bg = self.get_data("bg") + data = self.get_data_rec("data") + bg = self.get_data_rec("bg") phsp = self.get_phsp_plot() + phsp_rec = self.get_phsp_plot("_rec") + phsp_rec = phsp if phsp_rec is None else phsp_rec if bg is None: if self.config["data"].get("model", "auto") == "cfit": bg = _get_cfit_bg(self, data, phsp) @@ -275,6 +279,7 @@ def plot_partial_wave( bg = [bg] * len(data) if self.config["data"].get("model", "auto") == "cfit": phsp = _get_cfit_eff_phsp(self, phsp) + phsp_rec = _get_cfit_eff_phsp(self, phsp_rec) amp = self.get_amplitude() self._Ngroup = len(data) ws_bkg = [ @@ -320,6 +325,8 @@ def plot_partial_wave( chain_property, save_root=save_root, res=res, + phsp_rec=phsp_rec[0], + cut_function=cut_function, **kwargs, ) self._plot_partial_wave( @@ -349,6 +356,8 @@ def plot_partial_wave( plot_var_dic, chain_property, save_root=save_root, + phsp_rec=phsp_rec[i], + cut_function=cut_function, **kwargs, ) self._plot_partial_wave( @@ -378,6 +387,8 @@ def plot_partial_wave( chain_property, save_root=save_root, res=res, + phsp_rec=phsp_rec[i], + cut_function=cut_function, **kwargs, ) # self._plot_partial_wave(data_dict, phsp_dict, bg_dict, path+'d{}_'.format(i), plot_var_dic, chain_property, **kwargs) @@ -447,20 +458,29 @@ def _cal_partial_wave( res=None, batch=65000, ref_amp=None, + phsp_rec=None, + cut_function=lambda x: 1, **kwargs ): data_dict = {} phsp_dict = {} bg_dict = {} + phsp_rec = phsp if phsp_rec is None else phsp_rec + + resolution_size_phsp = data_shape(phsp) // data_shape(phsp_rec) + sr = lambda w: np.sum( + np.reshape(data_to_numpy(w), (-1, resolution_size_phsp)), axis=-1 + ) + with amp.temp_params(params): weights_i = [amp(i) for i in data_split(phsp, batch)] weight_phsp = data_merge(*weights_i) phsp_origin_w = phsp.get("weight", 1.0) * phsp.get("eff_value", 1.0) - total_weight = weight_phsp * phsp_origin_w + total_weight = sr(weight_phsp * phsp_origin_w) if ref_amp is not None: weights_i_ref = [ref_amp(i) for i in data_split(phsp, batch)] weight_phsp_ref = data_merge(*weights_i_ref) - total_weight_ref = weight_phsp_ref * phsp_origin_w + total_weight_ref = sr(weight_phsp_ref * phsp_origin_w) data_weight = data.get("weight", None) if data_weight is None: n_data = data_shape(data) @@ -492,17 +512,23 @@ def _cal_partial_wave( amp.set_used_res(used_res) data_weights = data.get("weight", np.ones((data_shape(data),))) - data_dict["data_weights"] = data_weights + data_dict["data_weights"] = cut_function(data) * data_weights phsp_weights = total_weight * norm_frac - phsp_dict["MC_total_fit"] = phsp_weights # MC total weight + phsp_dict["MC_total_fit"] = ( + cut_function(phsp_rec) * phsp_weights + ) # MC total weight if ref_amp is not None: - phsp_dict["MC_total_fit_ref"] = total_weight_ref * norm_frac_ref + phsp_dict["MC_total_fit_ref"] = ( + cut_function(phsp_rec) * total_weight_ref * norm_frac_ref + ) if bg is not None: if isinstance(w_bkg, float): bg_weight = [w_bkg] * data_shape(bg) else: bg_weight = -w_bkg - bg_dict["sideband_weights"] = bg_weight # sideband weight + bg_dict["sideband_weights"] = ( + cut_function(bg) * bg_weight + ) # sideband weight for i, name_i, label, _ in chain_property: weight_i = ( weights[i] @@ -511,9 +537,11 @@ def _cal_partial_wave( * phsp.get("weight", 1.0) * phsp.get("eff_value", 1.0) ) - phsp_dict[ - "MC_{0}_{1}_fit".format(i, name_i) - ] = weight_i # MC partial weight + phsp_dict["MC_{0}_{1}_fit".format(i, name_i)] = cut_function( + phsp_rec + ) * sr( + weight_i + ) # MC partial weight for name in plot_var_dic: idx = plot_var_dic[name]["idx"] trans = lambda x: np.reshape(plot_var_dic[name]["trans"](x), (-1,)) @@ -530,7 +558,7 @@ def _cal_partial_wave( data_dict[name + "_PZ"] = p4[3] data_dict[name] = data_i # data variable - phsp_i = trans(data_index(phsp, idx)) + phsp_i = trans(data_index(phsp_rec, idx)) phsp_dict[name + "_MC"] = phsp_i # MC if bg is not None: diff --git a/tf_pwa/data.py b/tf_pwa/data.py index 5fb0bd27..a36ccf2e 100644 --- a/tf_pwa/data.py +++ b/tf_pwa/data.py @@ -550,14 +550,18 @@ def _check_nan(dat, head): if isinstance(dat, tuple): return tuple( [ - data_struct(data_i, head + [i]) + _check_nan(data_i, head + [i]) for i, data_i in enumerate(dat) ] ) if np.any(tf.math.is_nan(dat)): if no_raise: return False - raise ValueError("nan in data[{}]".format(head)) + raise ValueError( + "nan in data[{}], idx:{}".format( + head, tf.where(tf.math.is_nan(dat)) + ) + ) return True return _check_nan(data, head_keys) diff --git a/tf_pwa/model/cfit.py b/tf_pwa/model/cfit.py index 7b3d5a8a..11f00b59 100644 --- a/tf_pwa/model/cfit.py +++ b/tf_pwa/model/cfit.py @@ -20,8 +20,10 @@ def f_eff(data): class Model_cfit(Model): - def __init__(self, amp, w_bkg=0.001, bg_f=None, eff_f=None): - super().__init__(amp, w_bkg) + def __init__( + self, amp, w_bkg=0.001, bg_f=None, eff_f=None, resolution_size=1 + ): + super().__init__(amp, w_bkg, resolution_size) if bg_f is None: bg_f = get_function("default_bg") elif isinstance(bg_f, str): @@ -61,8 +63,8 @@ def nll( """ data, weight = self.get_weight_data(data, weight) sw = tf.reduce_sum(weight) - sig_data = self.sig(data) - bg_data = self.bg(data) + sig_data = self.sum_resolution(self.sig(data)) + bg_data = self.sum_resolution(self.bg(data)) if mc_weight is None: int_mc = tf.reduce_mean(self.sig(mcdata)) int_bg = tf.reduce_mean(self.bg(mcdata)) @@ -107,7 +109,12 @@ def prob(x): ) / v_int_sig + self.w_bkg * self.bg(x) / v_int_bg ll, g_ll = sum_gradient( - prob, data, var + [v_int_sig, v_int_bg], weight, trans=clip_log + prob, + data, + var + [v_int_sig, v_int_bg], + weight, + trans=clip_log, + resolution_size=self.resolution_size, ) g_ll_sig, g_ll_bg = g_ll[-2], g_ll[-1] g = [ @@ -177,6 +184,7 @@ def prob(x): var + [v_int_sig, v_int_bg], weight=split_generator(weight, batch), trans=clip_log, + resolution_size=self.resolution_size, ) n_var = len(var) diff --git a/tf_pwa/model/model.py b/tf_pwa/model/model.py index d3ceadd4..96a4aa02 100644 --- a/tf_pwa/model/model.py +++ b/tf_pwa/model/model.py @@ -302,14 +302,17 @@ def __init__(self, signal, resolution_size=1): self.vm = signal.vm self.resolution_size = resolution_size + def sum_resolution(self, w): + w = tf.reshape(w, (-1, self.resolution_size)) + return tf.reduce_sum(w, axis=-1) + def nll(self, data, mcdata): """Negative log-Likelihood""" weight = data.get("weight", tf.ones((data_shape(data),))) sw = tf.reduce_sum(weight) rw = tf.reshape(weight, (-1, self.resolution_size)) amp_s2 = self.signal(data) * weight - amp_s2 = tf.reshape(amp_s2, (-1, self.resolution_size)) - amp_s2 = tf.reduce_sum(amp_s2, axis=-1) + amp_s2 = self.sum_resolution(amp_s2) weight = tf.reduce_sum(rw, axis=-1) ln_data = clip_log(amp_s2 / weight) mc_weight = mcdata.get("weight", tf.ones((data_shape(mcdata),)))