Skip to content

Commit

Permalink
feat: resolution in cfit
Browse files Browse the repository at this point in the history
feat: plot fit results with resolution
  • Loading branch information
jiangyi15 committed Jul 23, 2023
1 parent c920b73 commit ecd4f40
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 29 deletions.
26 changes: 19 additions & 7 deletions tf_pwa/config_loader/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,19 @@ def get_phsp_noeff(self):
)
return self.get_data("phsp")[0]

def get_phsp_plot(self):
if "phsp_plot" in self.config["data"]:
assert len(self.config["data"]["phsp_plot"]) == len(
self.config["data"]["phsp"]
def get_phsp_plot(self, tail=""):
if "phsp_plot" + tail in self.config["data"]:
assert len(self.config["data"]["phsp_plot" + tail]) == len(
self.config["data"]["phsp" + tail]
)
return self.get_data("phsp_plot")
return self.get_data("phsp")
return self.get_data("phsp_plot" + tail)
return self.get_data("phsp" + tail)

def get_data_rec(self, name):
ret = self.get_data(name + "_rec")
if ret is None:
ret = self.get_data(name)
return ret

def get_decay(self, full=True):
if full:
Expand Down Expand Up @@ -481,7 +487,13 @@ def _get_model(self, vm=None, name=""):
)
else:
model.append(
Model_cfit(amp, wb, bg_function, eff_function)
Model_cfit(
amp,
wb,
bg_function,
eff_function,
resolution_size=self.resolution_size,
)
)
elif "inmc" in self.config["data"]:
float_wmc = self.config["data"].get(
Expand Down
1 change: 0 additions & 1 deletion tf_pwa/config_loader/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,6 @@ def get_data(self, idx) -> list:
elif idx != "phsp_noeff":
assert self._Ngroup == len(ret), "not the same data group"
bg_value = self.dic.get(idx + "_bg_value", None)
print(len(ret), files, weights, charge)
if bg_value is not None:
if isinstance(bg_value, str):
bg_value = [bg_value]
Expand Down
52 changes: 40 additions & 12 deletions tf_pwa/config_loader/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def plot_partial_wave(
res=None,
save_root=False,
chains_id_method=None,
phsp_rec=None,
cut_function=lambda x: 1,
**kwargs
):
"""
Expand Down Expand Up @@ -265,16 +267,19 @@ def plot_partial_wave(
os.makedirs(path, exist_ok=True)

if data is None:
data = self.get_data("data")
bg = self.get_data("bg")
data = self.get_data_rec("data")
bg = self.get_data_rec("bg")
phsp = self.get_phsp_plot()
phsp_rec = self.get_phsp_plot("_rec")
phsp_rec = phsp if phsp_rec is None else phsp_rec
if bg is None:
if self.config["data"].get("model", "auto") == "cfit":
bg = _get_cfit_bg(self, data, phsp)
else:
bg = [bg] * len(data)
if self.config["data"].get("model", "auto") == "cfit":
phsp = _get_cfit_eff_phsp(self, phsp)
phsp_rec = _get_cfit_eff_phsp(self, phsp_rec)
amp = self.get_amplitude()
self._Ngroup = len(data)
ws_bkg = [
Expand Down Expand Up @@ -320,6 +325,8 @@ def plot_partial_wave(
chain_property,
save_root=save_root,
res=res,
phsp_rec=phsp_rec[0],
cut_function=cut_function,
**kwargs,
)
self._plot_partial_wave(
Expand Down Expand Up @@ -349,6 +356,8 @@ def plot_partial_wave(
plot_var_dic,
chain_property,
save_root=save_root,
phsp_rec=phsp_rec[i],
cut_function=cut_function,
**kwargs,
)
self._plot_partial_wave(
Expand Down Expand Up @@ -378,6 +387,8 @@ def plot_partial_wave(
chain_property,
save_root=save_root,
res=res,
phsp_rec=phsp_rec[i],
cut_function=cut_function,
**kwargs,
)
# self._plot_partial_wave(data_dict, phsp_dict, bg_dict, path+'d{}_'.format(i), plot_var_dic, chain_property, **kwargs)
Expand Down Expand Up @@ -447,20 +458,29 @@ def _cal_partial_wave(
res=None,
batch=65000,
ref_amp=None,
phsp_rec=None,
cut_function=lambda x: 1,
**kwargs
):
data_dict = {}
phsp_dict = {}
bg_dict = {}
phsp_rec = phsp if phsp_rec is None else phsp_rec

resolution_size_phsp = data_shape(phsp) // data_shape(phsp_rec)
sr = lambda w: np.sum(
np.reshape(data_to_numpy(w), (-1, resolution_size_phsp)), axis=-1
)

with amp.temp_params(params):
weights_i = [amp(i) for i in data_split(phsp, batch)]
weight_phsp = data_merge(*weights_i)
phsp_origin_w = phsp.get("weight", 1.0) * phsp.get("eff_value", 1.0)
total_weight = weight_phsp * phsp_origin_w
total_weight = sr(weight_phsp * phsp_origin_w)
if ref_amp is not None:
weights_i_ref = [ref_amp(i) for i in data_split(phsp, batch)]
weight_phsp_ref = data_merge(*weights_i_ref)
total_weight_ref = weight_phsp_ref * phsp_origin_w
total_weight_ref = sr(weight_phsp_ref * phsp_origin_w)
data_weight = data.get("weight", None)
if data_weight is None:
n_data = data_shape(data)
Expand Down Expand Up @@ -492,17 +512,23 @@ def _cal_partial_wave(
amp.set_used_res(used_res)

data_weights = data.get("weight", np.ones((data_shape(data),)))
data_dict["data_weights"] = data_weights
data_dict["data_weights"] = cut_function(data) * data_weights
phsp_weights = total_weight * norm_frac
phsp_dict["MC_total_fit"] = phsp_weights # MC total weight
phsp_dict["MC_total_fit"] = (
cut_function(phsp_rec) * phsp_weights
) # MC total weight
if ref_amp is not None:
phsp_dict["MC_total_fit_ref"] = total_weight_ref * norm_frac_ref
phsp_dict["MC_total_fit_ref"] = (
cut_function(phsp_rec) * total_weight_ref * norm_frac_ref
)
if bg is not None:
if isinstance(w_bkg, float):
bg_weight = [w_bkg] * data_shape(bg)
else:
bg_weight = -w_bkg
bg_dict["sideband_weights"] = bg_weight # sideband weight
bg_dict["sideband_weights"] = (
cut_function(bg) * bg_weight
) # sideband weight
for i, name_i, label, _ in chain_property:
weight_i = (
weights[i]
Expand All @@ -511,9 +537,11 @@ def _cal_partial_wave(
* phsp.get("weight", 1.0)
* phsp.get("eff_value", 1.0)
)
phsp_dict[
"MC_{0}_{1}_fit".format(i, name_i)
] = weight_i # MC partial weight
phsp_dict["MC_{0}_{1}_fit".format(i, name_i)] = cut_function(
phsp_rec
) * sr(
weight_i
) # MC partial weight
for name in plot_var_dic:
idx = plot_var_dic[name]["idx"]
trans = lambda x: np.reshape(plot_var_dic[name]["trans"](x), (-1,))
Expand All @@ -530,7 +558,7 @@ def _cal_partial_wave(
data_dict[name + "_PZ"] = p4[3]
data_dict[name] = data_i # data variable

phsp_i = trans(data_index(phsp, idx))
phsp_i = trans(data_index(phsp_rec, idx))
phsp_dict[name + "_MC"] = phsp_i # MC

if bg is not None:
Expand Down
8 changes: 6 additions & 2 deletions tf_pwa/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,14 +550,18 @@ def _check_nan(dat, head):
if isinstance(dat, tuple):
return tuple(
[
data_struct(data_i, head + [i])
_check_nan(data_i, head + [i])
for i, data_i in enumerate(dat)
]
)
if np.any(tf.math.is_nan(dat)):
if no_raise:
return False
raise ValueError("nan in data[{}]".format(head))
raise ValueError(
"nan in data[{}], idx:{}".format(
head, tf.where(tf.math.is_nan(dat))
)
)
return True

return _check_nan(data, head_keys)
18 changes: 13 additions & 5 deletions tf_pwa/model/cfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ def f_eff(data):


class Model_cfit(Model):
def __init__(self, amp, w_bkg=0.001, bg_f=None, eff_f=None):
super().__init__(amp, w_bkg)
def __init__(
self, amp, w_bkg=0.001, bg_f=None, eff_f=None, resolution_size=1
):
super().__init__(amp, w_bkg, resolution_size)
if bg_f is None:
bg_f = get_function("default_bg")
elif isinstance(bg_f, str):
Expand Down Expand Up @@ -61,8 +63,8 @@ def nll(
"""
data, weight = self.get_weight_data(data, weight)
sw = tf.reduce_sum(weight)
sig_data = self.sig(data)
bg_data = self.bg(data)
sig_data = self.sum_resolution(self.sig(data))
bg_data = self.sum_resolution(self.bg(data))
if mc_weight is None:
int_mc = tf.reduce_mean(self.sig(mcdata))
int_bg = tf.reduce_mean(self.bg(mcdata))
Expand Down Expand Up @@ -107,7 +109,12 @@ def prob(x):
) / v_int_sig + self.w_bkg * self.bg(x) / v_int_bg

ll, g_ll = sum_gradient(
prob, data, var + [v_int_sig, v_int_bg], weight, trans=clip_log
prob,
data,
var + [v_int_sig, v_int_bg],
weight,
trans=clip_log,
resolution_size=self.resolution_size,
)
g_ll_sig, g_ll_bg = g_ll[-2], g_ll[-1]
g = [
Expand Down Expand Up @@ -177,6 +184,7 @@ def prob(x):
var + [v_int_sig, v_int_bg],
weight=split_generator(weight, batch),
trans=clip_log,
resolution_size=self.resolution_size,
)

n_var = len(var)
Expand Down
7 changes: 5 additions & 2 deletions tf_pwa/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,14 +302,17 @@ def __init__(self, signal, resolution_size=1):
self.vm = signal.vm
self.resolution_size = resolution_size

def sum_resolution(self, w):
w = tf.reshape(w, (-1, self.resolution_size))
return tf.reduce_sum(w, axis=-1)

def nll(self, data, mcdata):
"""Negative log-Likelihood"""
weight = data.get("weight", tf.ones((data_shape(data),)))
sw = tf.reduce_sum(weight)
rw = tf.reshape(weight, (-1, self.resolution_size))
amp_s2 = self.signal(data) * weight
amp_s2 = tf.reshape(amp_s2, (-1, self.resolution_size))
amp_s2 = tf.reduce_sum(amp_s2, axis=-1)
amp_s2 = self.sum_resolution(amp_s2)
weight = tf.reduce_sum(rw, axis=-1)
ln_data = clip_log(amp_s2 / weight)
mc_weight = mcdata.get("weight", tf.ones((data_shape(mcdata),)))
Expand Down

0 comments on commit ecd4f40

Please sign in to comment.