diff --git a/tf_pwa/amp/core.py b/tf_pwa/amp/core.py
index 1e83c08d..2a8bee4f 100644
--- a/tf_pwa/amp/core.py
+++ b/tf_pwa/amp/core.py
@@ -773,7 +773,7 @@ def _get_cg_matrix(self, ls):  # CG factor inside H
                     lambda_b - lambda_c,
                 )
             )
-        return tf.convert_to_tensor(ret)
+        return ret

     def get_helicity_amp(self, data, data_p, **kwargs):
         m_dep = self.get_ls_amp(data, data_p, **kwargs)
@@ -1622,6 +1622,8 @@ def add_used_chains(self, used_chains):
             self.chains_idx.append(i)

     def set_used_chains(self, used_chains):
+        if isinstance(used_chains, str):
+            used_chains = [used_chains]
         self.chains_idx = list(used_chains)
         if len(self.chains_idx) != len(self.chains):
             self.not_full = True
@@ -1704,10 +1706,18 @@ def value_and_grad(f, var):

 class AmplitudeModel(object):
     def __init__(
-        self, decay_group, name="", polar=None, vm=None, use_tf_function=False
+        self,
+        decay_group,
+        name="",
+        polar=None,
+        vm=None,
+        use_tf_function=False,
+        no_id_cached=False,
+        jit_compile=False,
     ):
         self.decay_group = decay_group
         self._name = name
+        self.no_id_cached = no_id_cached
         with variable_scope(vm) as vm:
             if polar is not None:
                 vm.polar = polar
@@ -1720,7 +1730,9 @@ def __init__(
         if use_tf_function:
             from tf_pwa.experimental.wrap_function import WrapFun

-            self.cached_fun = WrapFun(self.decay_group.sum_amp)
+            self.cached_fun = WrapFun(
+                self.decay_group.sum_amp, jit_compile=jit_compile
+            )
         else:
             self.cached_fun = self.decay_group.sum_amp

@@ -1783,7 +1795,7 @@ def trainable_variables(self):
     def __call__(self, data, cached=False):
         if isinstance(data, LazyCall):
             data = data.eval()
-        if id(data) in self.f_data:
+        if id(data) in self.f_data or self.no_id_cached:
             if not self.decay_group.not_full:
                 return self.cached_fun(data)
             else:
diff --git a/tf_pwa/cal_angle.py b/tf_pwa/cal_angle.py
index 6c898166..b7113912 100644
--- a/tf_pwa/cal_angle.py
+++ b/tf_pwa/cal_angle.py
@@ -56,6 +56,7 @@
 from .angle import SU2M, EulerAngle, LorentzVector, Vector3, _epsilon
 from .config import get_config
 from .data import (
+    HeavyCall,
     LazyCall,
     data_index,
     data_merge,
@@ -261,8 +262,8 @@ def cal_single_boost(data, decay_chain: DecayChain) -> dict:
 def cal_helicity_angle(
     data: dict,
     decay_chain: DecayChain,
-    base_z=np.array([[0.0, 0.0, 1.0]]),
-    base_x=np.array([[1.0, 0.0, 0.0]]),
+    base_z=np.array([0.0, 0.0, 1.0]),
+    base_x=np.array([1.0, 0.0, 0.0]),
 ) -> dict:
     """
     Calculate helicity angle for A -> B + C: :math:`\\theta_{B}^{A}, \\phi_{B}^{A}` from momentum.
@@ -276,7 +277,6 @@

     # print(decay_chain, part_data)
     part_data = cal_chain_boost(data, decay_chain)
-    # print(decay_chain , part_data)
     # calculate angle and base x,z axis from mother particle rest frame momentum and base axis
     set_x = {decay_chain.top: base_x}
     set_z = {decay_chain.top: base_z}
@@ -405,6 +405,7 @@ def cal_angle_from_particle(
     r_boost=True,
     final_rest=True,
     align_ref=None,  # "center_mass",
+    only_left_angle=False,
 ):
     """
     Calculate helicity angle for particle momentum, add aligned angle.
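Editor's note: the `AmplitudeModel` changes above add two switches around the cached amplitude path. `no_id_cached=True` makes `__call__` take the `cached_fun` branch regardless of the `id(data)` bookkeeping, and `jit_compile` is only forwarded to `WrapFun` when `use_tf_function` is set. A minimal sketch of how they are meant to combine; the `decay_group` and `data` objects are assumed to exist (e.g. built by a `ConfigLoader`), this is not part of the patch:

```python
# Sketch, assuming an existing DecayGroup `decay_group` and angle data `data`.
from tf_pwa.amp.core import AmplitudeModel

amp = AmplitudeModel(
    decay_group,
    use_tf_function=True,  # wrap decay_group.sum_amp with WrapFun
    jit_compile=True,      # forwarded to tf.function(..., jit_compile=True)
    no_id_cached=True,     # always reuse the wrapped function in __call__
)
intensity = amp(data)      # CalAngleData-style dict in, total weight out
```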
@@ -422,7 +423,7 @@ def cal_angle_from_particle(
     # get base z axis
     p4 = data[decay_group.top]["p"]
     p3 = LorentzVector.vect(p4)
-    base_z = np.array([[0.0, 0.0, 1.0]]) + tf.zeros_like(p3)
+    base_z = np.array([0.0, 0.0, 1.0]) + tf.zeros_like(p3)
     if random_z:
         p3_norm = Vector3.norm(p3)
         mask = tf.expand_dims(p3_norm < 1e-5, -1)
@@ -474,6 +475,10 @@
             # ang = AlignmentAngle.angle_px_px(z1, x1, z2, x2)
             part_data[i]["aligned_angle"] = ang
     ret = data_strip(decay_data, ["r_matrix", "b_matrix", "x", "z"])
+    if only_left_angle:
+        for i in ret:
+            for j in ret[i]:
+                del ret[i][j][j.outs[1]]["ang"]
     return ret


@@ -629,6 +634,7 @@
     random_z=False,
     batch=65000,
     align_ref=None,
+    only_left_angle=False,
 ) -> CalAngleData:
     """
     Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
@@ -646,6 +652,7 @@
             r_boost,
             random_z,
             align_ref=align_ref,
+            only_left_angle=only_left_angle,
         )
     ret = []
     for i in split_generator(p, batch):
@@ -658,6 +665,7 @@
                 r_boost,
                 random_z,
                 align_ref=align_ref,
+                only_left_angle=only_left_angle,
             )
         )
     return data_merge(*ret)
@@ -707,11 +715,20 @@
     random_z=False,
     batch=65000,
     align_ref=None,
+    only_left_angle=False,
 ) -> CalAngleData:
     ret = []
     id_particles = decs.identical_particles
     data = cal_angle_from_momentum_base(
-        p, decs, using_topology, center_mass, r_boost, random_z, batch
+        p,
+        decs,
+        using_topology,
+        center_mass,
+        r_boost,
+        random_z,
+        batch,
+        align_ref=align_ref,
+        only_left_angle=only_left_angle,
     )
     if id_particles is None or len(id_particles) == 0:
         return data
@@ -727,6 +744,7 @@
             random_z,
             batch,
             align_ref=align_ref,
+            only_left_angle=only_left_angle,
         )
     return data

@@ -740,6 +758,7 @@
     random_z=False,
     batch=65000,
     align_ref=None,
+    only_left_angle=False,
 ) -> CalAngleData:
     """
     Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
@@ -750,13 +769,15 @@
     """
     if isinstance(p, LazyCall):
         return LazyCall(
-            cal_angle_from_momentum,
+            HeavyCall(cal_angle_from_momentum),
            p,
             decs=decs,
             using_topology=using_topology,
             center_mass=center_mass,
             r_boost=r_boost,
             random_z=random_z,
+            align_ref=align_ref,
+            only_left_angle=only_left_angle,
             batch=batch,
         )
     ret = []
@@ -771,6 +792,7 @@
         random_z,
         batch,
         align_ref=align_ref,
+        only_left_angle=only_left_angle,
     )
     if cp_particles is None or len(cp_particles) == 0:
         return data
@@ -785,6 +807,7 @@
             random_z,
             batch,
             align_ref=align_ref,
+            only_left_angle=only_left_angle,
         )
     return data

@@ -797,6 +820,7 @@
     r_boost=True,
     random_z=True,
     align_ref=None,
+    only_left_angle=False,
 ) -> CalAngleData:
     """
     Transform 4-momentum data in files for the amplitude model automatically via DecayGroup.
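Editor's note: `only_left_angle` is threaded through every `cal_angle_from_momentum*` entry point above. When set, the `"ang"` entry of the second daughter (`j.outs[1]`) is deleted at each two-body vertex; in the parent rest frame the two daughters are back-to-back, so those angles carry no independent information. A hedged usage sketch, with `p4_dict` and `decay_group` as assumed inputs:

```python
# Sketch: requesting the slimmer angle layout.
from tf_pwa.cal_angle import cal_angle_from_momentum

data = cal_angle_from_momentum(
    p4_dict,               # {particle: 4-momentum array} or a LazyCall
    decay_group,
    only_left_angle=True,  # drop ret[chain][decay][decay.outs[1]]["ang"]
)
```

When `p4_dict` is a `LazyCall`, the result is again a `LazyCall`, now built around `HeavyCall(cal_angle_from_momentum)` so that `as_dataset` (in `tf_pwa/data.py` below) can recognize the transform as expensive and cache its batches.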
@@ -824,6 +848,7 @@ def cal_angle_from_momentum_single(
         r_boost=r_boost,
         random_z=random_z,
         align_ref=align_ref,
+        only_left_angle=only_left_angle,
     )
     data = {"particle": data_p, "decay": data_d}
     add_relative_momentum(data)
diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py
index 32d78bf3..4e8920af 100644
--- a/tf_pwa/config_loader/config_loader.py
+++ b/tf_pwa/config_loader/config_loader.py
@@ -218,9 +218,10 @@ def get_decay(self, full=True):

     @functools.lru_cache()
     def get_amplitude(self, vm=None, name=""):
-        use_tf_function = self.config.get("data", {}).get(
-            "use_tf_function", False
-        )
+        amp_config = self.config.get("data", {})
+        use_tf_function = amp_config.get("use_tf_function", False)
+        no_id_cached = amp_config.get("no_id_cached", False)
+        jit_compile = amp_config.get("jit_compile", False)
         decay_group = self.full_decay
         self.check_valid_jp(decay_group)
         if vm is None:
@@ -228,7 +229,12 @@ def get_amplitude(self, vm=None, name=""):
         if vm in self.amps:
             return self.amps[vm]
         amp = AmplitudeModel(
-            decay_group, vm=vm, name=name, use_tf_function=use_tf_function
+            decay_group,
+            vm=vm,
+            name=name,
+            use_tf_function=use_tf_function,
+            no_id_cached=no_id_cached,
+            jit_compile=jit_compile,
         )
         self.add_constraints(amp)
         self.amps[vm] = amp
@@ -561,6 +567,7 @@ def get_fcn(self, all_data=None, batch=65000, vm=None, name=""):
             bg = [None] * self._Ngroup
         model = self._get_model(vm=vm, name=name)
         fcns = []
+        # print(self.config["data"].get("using_mix_likelihood", False))
         if self.config["data"].get("using_mix_likelihood", False):
             print("  Using Mix Likelihood")

@@ -575,7 +582,9 @@ def get_fcn(self, all_data=None, batch=65000, vm=None, name=""):
             if all_data is None:
                 self.cached_fcn[vm] = fcn
             return fcn
-        for md, dt, mc, sb, ij in zip(model, data, phsp, bg, inmc):
+        for idx, (md, dt, mc, sb, ij) in enumerate(
+            zip(model, data, phsp, bg, inmc)
+        ):
             if self.config["data"].get("model", "auto") == "cfit":
                 fcns.append(
                     FCN(
@@ -644,6 +653,7 @@ def fit(
         maxiter=None,
         jac=True,
         print_init_nll=True,
+        callback=None,
     ):
         if data is None and phsp is None:
             data, phsp, bg, inmc = self.get_all_data()
@@ -677,6 +687,7 @@ def fit(
             improve=False,
             maxiter=maxiter,
             jac=jac,
+            callback=callback,
         )
         if self.fit_params.hess_inv is not None:
             self.inv_he = self.fit_params.hess_inv
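Editor's note: all of these switches are read from the `data:` block of the configuration, so no Python changes are needed at the call site. A sketch of the corresponding YAML, using the same keys as `tf_pwa/tests/config_lazycall.yml` further down (values illustrative):

```yaml
data:
  lazy_call: True          # keep loaded data as LazyCall objects
  use_tf_function: True    # wrap sum_amp in WrapFun / tf.function
  jit_compile: True        # XLA-compile the wrapped function
  no_id_cached: True       # skip the id(data)-keyed caching
  only_left_angle: True    # drop the redundant second-daughter angles
  cached_lazy_call: toy_data/cached/ # on-disk tf.data cache prefix
```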
diff --git a/tf_pwa/config_loader/data.py b/tf_pwa/config_loader/data.py
index 8491cd67..462f6b91 100644
--- a/tf_pwa/config_loader/data.py
+++ b/tf_pwa/config_loader/data.py
@@ -125,7 +125,8 @@ def get_data(self, idx) -> dict:
         weight_sign = self.get_weight_sign(idx)
         charge = self.dic.get(idx + "_charge", None)
         ret = self.load_data(files, weights, weight_sign, charge)
-        return self.process_scale(idx, ret)
+        ret = self.process_scale(idx, ret)
+        return ret

     def process_scale(self, idx, data):
         if idx in self.scale_list and self.dic.get("weight_scale", False):
@@ -136,6 +137,12 @@ def process_scale(self, idx, data):
             )
         return data

+    def set_lazy_call(self, data, idx):
+        if isinstance(data, LazyCall):
+            name = idx
+            cached_file = self.dic.get("cached_lazy_call", None)
+            data.set_cached_file(cached_file, name)
+
     def get_n_data(self):
         data = self.get_data("data")
         weight = data.get("weight", np.ones((data_shape(data),)))
@@ -156,6 +163,7 @@ def cal_angle(self, p4, charge=None):
         r_boost = self.dic.get("r_boost", True)
         random_z = self.dic.get("random_z", True)
         align_ref = self.dic.get("align_ref", None)
+        only_left_angle = self.dic.get("only_left_angle", False)
         data = cal_angle_from_momentum(
             p4,
             self.decay_struct,
@@ -163,6 +171,7 @@ def cal_angle(self, p4, charge=None):
             r_boost=r_boost,
             random_z=random_z,
             align_ref=align_ref,
+            only_left_angle=only_left_angle,
         )
         if charge is not None:
             data["charge_conjugation"] = charge
@@ -185,18 +194,17 @@ def load_data(
         p4 = self.load_p4(files)
         charges = None if charges is None else charges[: data_shape(p4)]
         data = self.cal_angle(p4, charges)
-        if weights is not None:
-            if isinstance(weights, float):
-                data["weight"] = np.array(
-                    [weights * weights_sign] * data_shape(data)
-                )
-            elif isinstance(weights, str):  # weight files
-                weight = self.load_weight_file(weights)
-                data["weight"] = weight[: data_shape(data)] * weights_sign
-            else:
-                raise TypeError(
-                    "weight format error: {}".format(type(weights))
-                )
+        if weights is None:
+            data["weight"] = np.array([1.0 * weights_sign] * data_shape(data))
+        elif isinstance(weights, float):
+            data["weight"] = np.array(
+                [weights * weights_sign] * data_shape(data)
+            )
+        elif isinstance(weights, str):  # weight files
+            weight = self.load_weight_file(weights)
+            data["weight"] = weight[: data_shape(data)] * weights_sign
+        else:
+            raise TypeError("weight format error: {}".format(type(weights)))
         if charge is None:
             data["charge_conjugation"] = tf.ones((data_shape(data),))
@@ -322,8 +330,11 @@ def savetxt(self, file_name, data):
         else:
             raise ValueError("not support data")
         p4 = data_to_numpy(p4)
-        p4 = np.stack(p4).transpose((1, 0, 2)).reshape((-1, 4))
-        np.savetxt(file_name, p4)
+        p4 = np.stack(p4).transpose((1, 0, 2))
+        if file_name.endswith("npy"):
+            np.save(file_name, p4)
+        else:
+            np.savetxt(file_name, p4.reshape((-1, 4)))


 @register_data_mode("multi")
@@ -342,6 +353,10 @@ def process_scale(self, idx, data):
             )
         return data

+    def set_lazy_call(self, data, idx):
+        for i, data_i in enumerate(data):
+            super().set_lazy_call(data_i, "s{}{}".format(i, idx))
+
     def get_n_data(self):
         data = self.get_data("data")
         weight = [
@@ -405,6 +420,7 @@ def get_data(self, idx) -> list:
                 data_shape(k)
             )
         ret = self.process_scale(idx, ret)
+        self.set_lazy_call(ret, idx)
         return ret

     def get_phsp_noeff(self):
diff --git a/tf_pwa/config_loader/multi_config.py b/tf_pwa/config_loader/multi_config.py
index 3070909e..6aa32260 100644
--- a/tf_pwa/config_loader/multi_config.py
+++ b/tf_pwa/config_loader/multi_config.py
@@ -101,19 +101,30 @@ def get_fcns(self, datas=None, vm=None, batch=65000):
             if not self.total_same:
                 fcns = [
                     i[1].get_fcn(
-                        name="s" + str(i[0]), all_data=j, vm=vm, batch=batch
+                        name="s" + str(i[0]),
+                        all_data=j,
+                        vm=vm,
+                        batch=batch,
                     )
                     for i, j in zip(enumerate(self.configs), datas)
                 ]
             else:
                 fcns = [
-                    j.get_fcn(all_data=data, vm=vm, batch=batch)
+                    j.get_fcn(
+                        all_data=data,
+                        vm=vm,
+                        batch=batch,
+                    )
                     for data, j in zip(datas, self.configs)
                 ]
         else:
             if not self.total_same:
                 fcns = [
-                    j.get_fcn(name="s" + str(i), vm=vm, batch=batch)
+                    j.get_fcn(
+                        name="s" + str(i),
+                        vm=vm,
+                        batch=batch,
+                    )
                     for i, j in enumerate(self.configs)
                 ]
             else:
@@ -148,14 +159,27 @@ def get_args_value(self, bounds_dict):
         return args_name, x0, args, bnds

     @time_print
-    def fit(self, datas=None, batch=65000, method="BFGS", maxiter=None):
+    def fit(
+        self,
+        datas=None,
+        batch=65000,
+        method="BFGS",
+        maxiter=None,
+        print_init_nll=False,
+        callback=None,
+    ):
         fcn = self.get_fcn(datas=datas)
         # fcn.gauss_constr.update({"Zc_Xm_width": (0.177, 0.03180001857)})
         print("\n########### initial parameters")
         print(json.dumps(fcn.get_params(), indent=2), flush=True)
-        print("initial NLL: ", fcn({}))
+        if print_init_nll:
+            print("initial NLL: ", fcn({}))
         self.fit_params = fit(
-            fcn=fcn, method=method, bounds_dict=self.bound_dic, maxiter=maxiter
+            fcn=fcn,
+            method=method,
+            bounds_dict=self.bound_dic,
+            maxiter=maxiter,
+            callback=callback,
         )
         if self.fit_params.hess_inv is not None:
             self.inv_he = self.fit_params.hess_inv
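Editor's note: both `ConfigLoader.fit` and `MultiConfig.fit` now accept a `callback`. As the `tf_pwa/fit.py` hunk further down shows, it is invoked once per optimizer iteration (for the `BFGS`/`CG`/`Nelder-Mead` family of methods) with the current parameter vector and the `FCN` object. A minimal sketch, assuming a loaded `config`:

```python
# Sketch: log the cached NLL at every optimizer iteration.
def log_step(x, fcn):
    print("parameters:", x, "nll:", fcn.cached_nll)

fit_result = config.fit(method="BFGS", callback=log_step)
```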
diff --git a/tf_pwa/config_loader/plot.py b/tf_pwa/config_loader/plot.py
index 0513789d..06bf2fed 100644
--- a/tf_pwa/config_loader/plot.py
+++ b/tf_pwa/config_loader/plot.py
@@ -10,6 +10,7 @@
 from tf_pwa.adaptive_bins import cal_chi2 as cal_chi2_o
 from tf_pwa.data import (
     batch_call,
+    batch_call_numpy,
     data_index,
     data_merge,
     data_replace,
@@ -100,7 +101,7 @@ def save(self):
             yaml.dump(self.linestyle_table, f)


-def _get_cfit_bg(self, data, phsp):
+def _get_cfit_bg(self, data, phsp, batch=65000):
     model = self._get_model()
     bg_function = [i.bg for i in model]
     w_bkg = [i.w_bkg for i in model]
@@ -108,25 +109,21 @@ def _get_cfit_bg(self, data, phsp):
     for data_i, phsp_i, w, bg_f in zip(data, phsp, w_bkg, bg_function):
         ndata = np.sum(data_i.get_weight())
         nbg = ndata * w
-        w_bg = bg_f(phsp_i) * phsp_i.get_weight()
+        w_bg = batch_call_numpy(bg_f, phsp_i, batch) * phsp_i.get_weight()
         phsp_weight.append(-w_bg / np.sum(w_bg) * nbg)
     ret = [
         data_replace(phsp_i, "weight", w)
         for phsp_i, w in zip(phsp, phsp_weight)
     ]
     return ret
-    # return [
-    #     type(phsp_i)({**phsp_i, "weight": w})
-    #     for phsp_i, w in zip(phsp, phsp_weight)
-    # ]


-def _get_cfit_eff_phsp(self, phsp):
+def _get_cfit_eff_phsp(self, phsp, batch=65000):
     model = self._get_model()
     eff_function = [i.eff for i in model]
     phsp_weight = []
     for phsp_i, eff_f in zip(phsp, eff_function):
-        w_eff = eff_f(phsp_i) * phsp_i.get_weight()
+        w_eff = batch_call_numpy(eff_f, phsp_i, batch) * phsp_i.get_weight()
         phsp_weight.append(w_eff)

     ret = [
@@ -272,14 +269,15 @@ def plot_partial_wave(
     phsp = self.get_phsp_plot()
     phsp_rec = self.get_phsp_plot("_rec")
     phsp_rec = phsp if phsp_rec is None else phsp_rec
+    batch = kwargs.get("batch", 65000)
     if bg is None:
         if self.config["data"].get("model", "auto") == "cfit":
-            bg = _get_cfit_bg(self, data, phsp)
+            bg = _get_cfit_bg(self, data, phsp, batch)
         else:
             bg = [bg] * len(data)
     if self.config["data"].get("model", "auto") == "cfit":
-        phsp = _get_cfit_eff_phsp(self, phsp)
-        phsp_rec = _get_cfit_eff_phsp(self, phsp_rec)
+        phsp = _get_cfit_eff_phsp(self, phsp, batch)
+        phsp_rec = _get_cfit_eff_phsp(self, phsp_rec, batch)
     amp = self.get_amplitude()
     self._Ngroup = len(data)
     ws_bkg = [
@@ -471,15 +469,18 @@ def _cal_partial_wave(
     sr = lambda w: np.sum(
         np.reshape(data_to_numpy(w), (-1, resolution_size_phsp)), axis=-1
     )
-    with amp.temp_params(params):
-        weights_i = [amp(i) for i in data_split(phsp, batch)]
-        weight_phsp = data_merge(*weights_i)
+    weight_phsp = batch_call_numpy(
+        amp, phsp, batch
+    )  # (i) for i in data_split(phsp, batch)]
+    # weight_phsp = data_merge(*weights_i)
     phsp_origin_w = phsp.get("weight", 1.0) * phsp.get("eff_value", 1.0)
     total_weight = sr(weight_phsp * phsp_origin_w)
     if ref_amp is not None:
-        weights_i_ref = [ref_amp(i) for i in data_split(phsp, batch)]
-        weight_phsp_ref = data_merge(*weights_i_ref)
+        # weights_i_ref = [ref_amp(i) for i in data_split(phsp, batch)]
+        weight_phsp_ref = batch_call_numpy(
+            ref_amp, phsp, batch
+        )  # data_merge(*weights_i_ref)
         total_weight_ref = sr(weight_phsp_ref * phsp_origin_w)
     data_weight = data.get("weight", None)
     if data_weight is None:
@@ -495,33 +496,27 @@ def _cal_partial_wave(
     norm_frac = n_sig / np.sum(total_weight)
     if ref_amp is not None:
         norm_frac_ref = n_sig / np.sum(total_weight_ref)
-    if res is None:
-        weights = amp.partial_weight(phsp)
-    else:
-        weights = []
-        used_res = amp.used_res
-        for i in res:
-            if not isinstance(i, list):
-                i = [i]
-            amp.set_used_res(i)
-            weights.append(amp(phsp))
-        # print(weights, amp.decay_group.chains_idx)
-        amp.set_used_res(used_res)
-
-    data_weights = data.get("weight", np.ones((data_shape(data),)))
-    data_dict["data_weights"] = cut_function(data) * data_weights
+    weights = batch_call_numpy(
+        lambda x: amp.partial_weight(x, combine=res), phsp, batch
+    )
+    data_weights = (
+        data_weight  # data.get("weight", np.ones((data_shape(data),)))
+    )
+    data_dict["data_weights"] = (
+        batch_call_numpy(cut_function, data, batch) * data_weights
+    )
     phsp_weights = total_weight * norm_frac
-    phsp_dict["MC_total_fit"] = (
-        cut_function(phsp_rec) * phsp_weights
-    )  # MC total weight
+    cut_phsp = batch_call_numpy(cut_function, phsp_rec, batch)
+    phsp_dict["MC_total_fit"] = cut_phsp * phsp_weights  # MC total weight
+
     if ref_amp is not None:
         phsp_dict["MC_total_fit_ref"] = (
-            cut_function(phsp_rec) * total_weight_ref * norm_frac_ref
+            cut_phsp * total_weight_ref * norm_frac_ref
         )
     if bg is not None:
         bg_weight = -w_bkg
         bg_dict["sideband_weights"] = (
-            cut_function(bg) * bg_weight
+            batch_call_numpy(cut_function, bg, batch) * bg_weight
         )  # sideband weight
     for i, name_i, label, _ in chain_property:
         weight_i = (
@@ -531,20 +526,22 @@ def _cal_partial_wave(
             * phsp.get("weight", 1.0)
             * phsp.get("eff_value", 1.0)
         )
-        phsp_dict["MC_{0}_{1}_fit".format(i, name_i)] = cut_function(
-            phsp_rec
-        ) * sr(
+        phsp_dict["MC_{0}_{1}_fit".format(i, name_i)] = cut_phsp * sr(
             weight_i
         )  # MC partial weight
     for name in plot_var_dic:
         idx = plot_var_dic[name]["idx"]
         trans = lambda x: np.reshape(plot_var_dic[name]["trans"](x), (-1,))
-        data_i = trans(data_index(data, idx))
+        data_i = batch_call_numpy(
+            lambda x: trans(data_index(x, idx)), data, batch
+        )
         if idx[-1] == "m":
             tmp_idx = list(idx)
             tmp_idx[-1] = "p"
-            p4 = data_index(data, tmp_idx)
+            p4 = batch_call_numpy(
+                lambda x: data_index(x, tmp_idx), data, batch
+            )
             p4 = np.transpose(p4)
             data_dict[name + "_E"] = p4[0]
             data_dict[name + "_PX"] = p4[1]
@@ -552,11 +549,15 @@ def _cal_partial_wave(
             data_dict[name + "_PZ"] = p4[3]
         data_dict[name] = data_i  # data variable

-        phsp_i = trans(data_index(phsp_rec, idx))
+        phsp_i = batch_call_numpy(
+            lambda x: trans(data_index(x, idx)), phsp_rec, batch
+        )
         phsp_dict[name + "_MC"] = phsp_i  # MC

         if bg is not None:
-            bg_i = trans(data_index(bg, idx))
+            bg_i = batch_call_numpy(
+                lambda x: trans(data_index(x, idx)), bg, batch
+            )
             bg_dict[name + "_sideband"] = bg_i  # sideband
     data_dict = data_to_numpy(data_dict)
     phsp_dict = data_to_numpy(phsp_dict)
@@ -737,7 +738,7 @@ def _plot_partial_wave(
             ax.set_title(display, fontsize="xx-large")
         else:
             ax.set_title(
-                "{}: -lnL= {:.5}".format(display, nll), fontsize="xx-large"
+                "{}: -lnL= {:.2f}".format(display, nll), fontsize="xx-large"
             )
         ax.set_xlabel(display + units)
         ywidth = np.mean(
@@ -886,6 +887,23 @@ def _2d_plot(
             plt.xlim(range1)
             plt.ylim(range2)
             plt.savefig(prefix + k + "_data")
             plt.clf()
+            print("Finish plotting 2D data " + prefix + k)  # data
+        if "data_hist" in plot_figs:
+            plt.hist2d(
+                data_1,
+                data_2,
+                bins=100,
+                weights=data_dict["data_weights"],
+                cmin=1e-12,
+            )
+            plt.xlabel(name1)
+            plt.ylabel(name2)
+            plt.title(display, fontsize="xx-large")
+            plt.legend()
+            plt.xlim(range1)
+            plt.ylim(range2)
+            plt.savefig(prefix + k + "_data_hist")
+            plt.clf()
             print("Finish plotting 2D data " + prefix + k)
         # sideband
         if "sideband" in plot_figs:
diff --git a/tf_pwa/data.py b/tf_pwa/data.py
index a36ccf2e..745f670e 100644
--- a/tf_pwa/data.py
+++ b/tf_pwa/data.py
@@ -69,6 +69,14 @@
 from collections import Iterable


+class HeavyCall:
+    def __init__(self, f):
+        self.f = f
+
+    def __call__(self, *args, **kwargs):
+        return self.f(*args, **kwargs)
+
+
 class LazyCall:
     def __init__(self, f, x, *args, **kwargs):
         self.f = f
@@ -76,28 +84,96 @@ def __init__(self, f, x, *args, **kwargs):
         self.x = x
         self.args = args
         self.kwargs = kwargs
         self.extra = {}
+        self.batch_size = None
+        self.cached_batch = {}
+        self.cached_file = None
+        self.name = ""
+
+    def batch(self, batch, axis=0):
+        return self.as_dataset(batch)
+
+    def __iter__(self):
+        assert self.batch_size is not None, "call as_dataset(batch) first"
+        if isinstance(self.f, HeavyCall):
+            for i, j in zip(
+                self.cached_batch[self.batch_size],
+                split_generator(self.extra, self.batch_size),
+            ):
+                yield {**i, **j}
+        elif isinstance(self.x, LazyCall):
+            for i, j in zip(
+                self.x, split_generator(self.extra, self.batch_size)
+            ):
+                yield {**i, **j}
+        else:
+            for i, j in zip(
+                split_generator(self.x, self.batch_size),
+                split_generator(self.extra, self.batch_size),
+            ):
+                yield {**i, **j}
+
+    def as_dataset(self, batch=65000):
+        self.batch_size = batch
+        if isinstance(self.x, LazyCall):
+            self.x.as_dataset(batch)
+
+        if not isinstance(self.f, HeavyCall):
+            return self

-    def batch(self, batch, axis):
-        for i, j in zip(
-            data_split(self.x, batch, axis=axis),
-            data_split(self.extra, batch, axis=axis),
-        ):
-            ret = LazyCall(self.f, i, *self.args, **self.kwargs)
-            for k, v in j.items():
-                ret[k] = v
-            yield ret
+        if batch in self.cached_batch:
+            return self
+
+        def f(x):
+            ret = self.f(x, *self.args, **self.kwargs)
+            return ret
+
+        if isinstance(self.x, LazyCall):
+            real_x = self.x.eval()
+        else:
+            real_x = self.x
+
+        data = tf.data.Dataset.from_tensor_slices(real_x)
+        # data = data.batch(batch).cache().map(f)
+        if self.cached_file is not None:
+            from tf_pwa.utils import create_dir
+
+            cached_file = self.cached_file + self.name
+            cached_file += "_" + str(batch)
+            create_dir(cached_file)
+            data = data.batch(batch).map(f)
+            if self.cached_file == "":
+                data = data.cache()
+            else:
+                data = data.cache(cached_file)
+        else:
+            data = data.batch(batch).cache().map(f)
+        data = data.prefetch(tf.data.AUTOTUNE)
+
+        self.cached_batch[batch] = data
+        return self
+
+    def set_cached_file(self, cached_file, name):
+        if isinstance(self.x, LazyCall):
+            self.x.set_cached_file(cached_file, name)
+        self.cached_file = cached_file
+        self.name = name

     def merge(self, *other, axis=0):
         all_x = [self.x]
         all_extra = [self.extra]
         for i in other:
             all_x.append(i.x)
-            all_extra = [i.extra]
+            all_extra.append(i.extra)
         new_extra = data_merge(*all_extra, axis=axis)
         ret = LazyCall(
             self.f, data_merge(*all_x, axis=axis), *self.args, **self.kwargs
         )
         ret.extra = new_extra
+        ret.cached_file = self.cached_file
+        ret.name = self.name
+        for i in other:
+            ret.name += "_" + i.name
         return ret

     def __setitem__(self, index, value):
@@ -119,8 +195,10 @@ def get_weight(self):
         return tf.ones(data_shape(self), dtype=get_config("dtype"))

     def copy(self):
-        ret = LazyCall(lambda x: x, self)
+        ret = LazyCall(self.f, self.x, *self.args, **self.kwargs)
         ret.extra = self.extra.copy()
+        ret.cached_file = self.cached_file
+        ret.name = self.name
         return ret

     def eval(self):
@@ -134,6 +212,12 @@ def eval(self):
             ret[k] = v
         return ret

+    def __len__(self):
+        x = self.x
+        if isinstance(self.x, LazyCall):
+            x = x.eval()
+        return data_shape(x)
+

 class EvalLazy:
     def __init__(self, f):
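Editor's note: the reworked `LazyCall` only materializes a `tf.data` pipeline when its function is a `HeavyCall`. `as_dataset` slices the raw input, maps the heavy transform per batch, and caches either in memory (`cache()`) or on disk under `cached_file + name + "_" + batch` once `set_cached_file` has been called. A sketch of the intended call pattern (paths illustrative, `p4_dict`/`decay_group` assumed):

```python
# Sketch: building and iterating the cached batch pipeline.
data = cal_angle_from_momentum(p4_dict, decay_group)  # LazyCall(HeavyCall(...), p4)
data.set_cached_file("toy_data/cached/", "data")      # enable the disk cache
data.as_dataset(batch=65000)  # build (or look up) the tf.data pipeline
for batch in data:            # yields dicts merged with data.extra
    pass                      # each batch is a CalAngleData-style dict
```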
@@ -483,11 +567,22 @@ def list_gen(dat):

 def batch_call(function, data, batch=10000):
     ret = []
-    for i in data_split(data, batch):
-        ret.append(function(i))
+    if isinstance(data, LazyCall):
+        batches = data.as_dataset(batch)
+    else:
+        batches = data_split(data, batch)
+    for i in batches:
+        tmp = function(i)
+        if isinstance(tmp, (int, float)):
+            tmp = tmp * np.ones((data_shape(i),))
+        ret.append(tmp)
     return data_merge(*ret)


+def batch_call_numpy(function, data, batch=10000):
+    return data_to_numpy(batch_call(function, data, batch))
+
+
 def data_index(data, key):
     """Indexing data for key or a list of keys."""
     if isinstance(data, LazyCall):
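Editor's note: `batch_call_numpy` is the helper the plotting rewrite above leans on: evaluate a function over batches (or over a `LazyCall`'s cached dataset), broadcast scalar results to the batch shape, merge, and convert to numpy. Usage mirrors the `plot.py` hunks, with `amp` and `phsp` as produced by a `ConfigLoader`:

```python
# Sketch: amplitude weights over phase space, batched, returned as numpy.
from tf_pwa.data import batch_call_numpy

weights = batch_call_numpy(amp, phsp, batch=65000)
```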
diff --git a/tf_pwa/data_trans/helicity_angle.py b/tf_pwa/data_trans/helicity_angle.py
index 1a411d1f..0866ade9 100644
--- a/tf_pwa/data_trans/helicity_angle.py
+++ b/tf_pwa/data_trans/helicity_angle.py
@@ -271,7 +271,7 @@ def lorentz_neg(pc):
 def generate_p(ms, msp, costheta, phi):
     """
     ms(0) -> ms(1) + msp(0), costheta(0), phi(0)
-    ms(1) -> ms(2) + msp(1), costheta(0), phi(0)
+    ms(1) -> ms(2) + msp(1), costheta(1), phi(1)
     ...
     ms(n) -> ms(n+1) + msp(n), costheta(n), phi(n)

diff --git a/tf_pwa/experimental/wrap_function.py b/tf_pwa/experimental/wrap_function.py
index 31eabd6b..5f334f32 100644
--- a/tf_pwa/experimental/wrap_function.py
+++ b/tf_pwa/experimental/wrap_function.py
@@ -4,7 +4,9 @@
 def _wrap_struct(dic, first_none=True):
     if isinstance(dic, dict):
-        return {k: _wrap_struct(v, first_none) for k, v in dic.items()}
+        return {
+            k: _wrap_struct(dic[k], first_none) for k in sorted(dic.keys())
+        }
     if isinstance(dic, list):
         return [_wrap_struct(v, first_none) for v in dic]
     if isinstance(dic, tuple):
@@ -19,8 +21,8 @@

 def _flatten(dic):
     if isinstance(dic, dict):
-        for k, v in dic.items():
-            yield from _flatten(v)
+        for k in sorted(dic.keys()):
+            yield from _flatten(dic[k])
     if isinstance(dic, (list, tuple)):
         for v in dic:
             yield from _flatten(v)
@@ -52,10 +54,11 @@ def _nest(dic, value, idx=None):

 class WrapFun:
-    def __init__(self, f):
+    def __init__(self, f, jit_compile=False):
         self.f = f
         self.cached_f = {}
         self.struct = {}
+        self.jit_compile = jit_compile

     def __call__(self, *args, **kwargs):
@@ -71,7 +74,9 @@ def _g(*x):
                 *new_args, **new_kwargs
             )  # *new_args, **new_kwargs)

-            self.cached_f[idx] = tf.function(_g).get_concrete_function(
+            _g2 = tf.function(_g, jit_compile=self.jit_compile)
+
+            self.cached_f[idx] = _g2.get_concrete_function(
                 *list(_flatten(self.struct[idx]))
             )
         new_x = [
diff --git a/tf_pwa/fit.py b/tf_pwa/fit.py
index 6a58487a..3870b474 100644
--- a/tf_pwa/fit.py
+++ b/tf_pwa/fit.py
@@ -194,6 +194,7 @@ def fit_scipy(
     improve=False,
     maxiter=None,
     jac=True,
+    callback=None,
 ):
     """

@@ -243,6 +244,10 @@
         for i, name in enumerate(args_name):
             print(args_name[i], gs[i], gs0[i])

+    callback_inner = lambda x, y: None
+    if callback is not None:
+        callback_inner = callback
+
     if method in ["BFGS", "CG", "Nelder-Mead", "test"]:

         def callback(x):
@@ -255,6 +260,7 @@ def callback(x):
             # with open("fit_curve.json", "w") as f:
             #     json.dump({"points": points, "nlls": nlls}, f, indent=2)
             # pass  # raise Exception("Reached the largest iterations: {}".format(maxiter))
+            callback_inner(x, fcn)
             print(fcn.cached_nll)

         # bd = Bounds(bnds)
diff --git a/tf_pwa/model/model.py b/tf_pwa/model/model.py
index 8f267e1d..b50fbaea 100644
--- a/tf_pwa/model/model.py
+++ b/tf_pwa/model/model.py
@@ -12,6 +12,7 @@
 from ..data import (
     EvalLazy,
     data_merge,
+    data_replace,
     data_shape,
     data_split,
     split_generator,
@@ -740,8 +741,8 @@ def nll_grad_hessian(
             mc_weight = tf.convert_to_tensor(
                 [mc_weight] * data_shape(mcdata), dtype="float64"
             )
-        data_i = {**data, "weight": weight}
-        mcdata_i = {**mcdata, "weight": mc_weight}
+        data_i = data_replace(data, "weight", weight)
+        mcdata_i = data_replace(mcdata, "weight", mc_weight)
         return self.model.nll_grad_hessian(data_i, mcdata_i, batch=batch)

     def set_params(self, var):
@@ -1067,6 +1068,17 @@ def nll_gradient(self, data, mcdata, weight=1.0, batch=None, bg=None):
         return nll, g


+def _convert_batch(data, batch, cached_file=None, name=""):
+    from tf_pwa.data import LazyCall
+
+    if isinstance(data, LazyCall):
+        if cached_file is not None:
+            return data.as_dataset(batch, cached_file + name)
+        else:
+            return data.as_dataset(batch)
+    return list(split_generator(data, batch))
+
+
 class FCN(object):
     """
     This class implements methods to calculate the NLL as well as its derivatives for a general function.
@@ -1095,17 +1107,17 @@ def __init__(
         self.cached_nll = None
         if inmc is None:
             data, weight = self.model.get_weight_data(data, bg=bg)
-            print("Using Model_old")
+            print("Using Model")
         else:
             data, weight = self.model.get_weight_data(data, bg=bg, inmc=inmc)
-            print("Using Model_new")
+            print("Using Model with inmc")
         n_mcdata = data_shape(mcdata)
         self.alpha = tf.reduce_sum(weight) / tf.reduce_sum(weight * weight)
         self.weight = weight
         self.data = data
-        self.batch_data = list(split_generator(data, batch))
+        self.batch_data = self._convert_batch(data, batch)
         self.mcdata = mcdata
-        self.batch_mcdata = list(split_generator(mcdata, batch))
+        self.batch_mcdata = self._convert_batch(mcdata, batch)
         self.batch = batch
         if mcdata.get("weight", None) is not None:
             mc_weight = tf.convert_to_tensor(mcdata["weight"], dtype="float64")
@@ -1114,10 +1126,13 @@ def __init__(
         self.mc_weight = tf.convert_to_tensor(
             [1 / n_mcdata] * n_mcdata, dtype="float64"
         )
-        self.batch_mc_weight = list(data_split(self.mc_weight, self.batch))
+        self.batch_mc_weight = self._convert_batch(self.mc_weight, self.batch)
         self.gauss_constr = GaussianConstr(self.vm, gauss_constr)
         self.cached_mc = {}

+    def _convert_batch(self, data, batch):
+        return _convert_batch(data, batch)
+
     def get_params(self, trainable_only=False):
         return self.vm.get_all_dic(trainable_only)
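Editor's note: with `_convert_batch`, an `FCN` built on `LazyCall` inputs keeps its batches behind the `tf.data` cache instead of materializing `list(split_generator(...))`. Nothing changes for the caller; the same call pattern used in `MultiConfig.fit` above still applies, assuming a loaded `config`:

```python
# Sketch: the call site is unchanged; only the internal batching differs.
fcn = config.get_fcn(batch=65000)
print("initial NLL:", fcn({}))  # same call pattern as in multi_config.fit
```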
["toy_data/PHSP.dat"] + random_z: False + r_boost: False + bg_weight: 0.1 + lazy_call: True + use_tf_function: True + no_id_cached: True + jit_compile: True + only_left_angle: True + cached_lazy_call: toy_data/cached/ + +decay: + A: + - [R_BC, D] + - [R_BD, C] + - [R_CD, B] + R_BC: [B, C] + R_BD: [B, D] + R_CD: [C, D] + +particle: + $top: + A: { J: 1, P: -1, spins: [-1, 1], mass: 4.6 } + $finals: + B: { J: 1, P: -1, mass: 2.00698 } + C: { J: 1, P: -1, mass: 2.01028 } + D: { J: 0, P: -1, mass: 0.13957 } + R_BC: { J: 1, Par: 1, m0: 4.16, g0: 0.1, params: { mass_range: [4.0, 4.2] } } + R_BD: { J: 1, Par: 1, m0: 2.43, g0: 0.3 } + R_CD: { J: 1, Par: 1, m0: 2.42, g0: 0.03 } + +constrains: + particle: null + decay: null + +plot: + mass: + R_BC: { display: "$M_{BC}$" } + R_BD: { display: "$M_{BD}$" } + R_CD: { display: "$M_{CD}$" } + angle: + R_BC/B: + cos(beta): + display: "cos \\theta" + alpha: + display: "\\phi" diff --git a/tf_pwa/tests/config_rec.yml b/tf_pwa/tests/config_rec.yml index dee65cb3..a5af5e78 100644 --- a/tf_pwa/tests/config_rec.yml +++ b/tf_pwa/tests/config_rec.yml @@ -23,11 +23,11 @@ decay: particle: $top: - A: { J: 1, P: -1, spins: [-1, 1] } + A: { J: 1, P: -1, spins: [-1, 1], mass: 4.6 } $finals: - B: { J: 1, P: -1 } - C: { J: 1, P: -1 } - D: { J: 0, P: -1 } + B: { J: 1, P: -1, mass: 2.00698 } + C: { J: 1, P: -1, mass: 2.01028 } + D: { J: 0, P: -1, mass: 0.13957 } R_BC: { J: 1, Par: 1, m0: 4.16, g0: 0.1, model: BWR2 } R_BD: { J: 1, Par: 1, m0: 2.43, g0: 0.3, model: BW } R_CD: { J: 1, Par: 1, m0: 2.42, g0: 0.03 } diff --git a/tf_pwa/tests/config_toy.yml b/tf_pwa/tests/config_toy.yml index 89adeefe..920aec13 100644 --- a/tf_pwa/tests/config_toy.yml +++ b/tf_pwa/tests/config_toy.yml @@ -18,11 +18,11 @@ decay: particle: $top: - A: { m0: 4.6, J: 1, P: -1, spins: [-1, 1] } + A: { J: 1, P: -1, spins: [-1, 1], mass: 4.6 } $finals: - B: { m0: 2.00698, J: 1, P: -1 } - C: { m0: 2.01028, J: 1, P: -1 } - D: { m0: 0.13957, J: 0, P: -1 } + B: { J: 1, P: -1, mass: 2.00698 } + C: { J: 1, P: -1, mass: 2.01028 } + D: { J: 0, P: -1, mass: 0.13957 } R_BC: { J: 1, Par: 1, m0: 4.16, g0: 0.1, params: { mass_range: [4.0, 4.2] } } R_BD: { J: 1, Par: 1, m0: 2.43, g0: 0.3 } R_CD: { J: 1, Par: 1, m0: 2.42, g0: 0.03 } diff --git a/tf_pwa/tests/config_toy2.yml b/tf_pwa/tests/config_toy2.yml index 048322ac..6e611c67 100644 --- a/tf_pwa/tests/config_toy2.yml +++ b/tf_pwa/tests/config_toy2.yml @@ -22,11 +22,11 @@ decay: particle: $top: - A: { J: 1, P: -1, spins: [-1, 1] } + A: { J: 1, P: -1, spins: [-1, 1], mass: 4.6 } $finals: - B: { J: 1, P: -1 } - C: { J: 1, P: -1 } - D: { J: 0, P: -1 } + B: { J: 1, P: -1, mass: 2.00698 } + C: { J: 1, P: -1, mass: 2.01028 } + D: { J: 0, P: -1, mass: 0.13957 } R_BC: { J: 1, Par: 1, m0: 4.16, g0: 0.1, model: BWR2 } R_BD: { J: 1, Par: 1, m0: 2.43, g0: 0.3, model: BW } R_CD: { J: 1, Par: 1, m0: 2.42, g0: 0.03 } diff --git a/tf_pwa/tests/test_full.py b/tf_pwa/tests/test_full.py index 3ff2d113..8cf34227 100644 --- a/tf_pwa/tests/test_full.py +++ b/tf_pwa/tests/test_full.py @@ -71,6 +71,13 @@ def toy_config(gen_toy): return config +@pytest.fixture +def toy_config_lazy(gen_toy): + config = ConfigLoader(f"{this_dir}/config_lazycall.yml") + config.set_params(f"{this_dir}/exp_params.json") + return config + + def test_build_angle_amplitude(toy_config): data = toy_config.get_data("data") dec = toy_config.get_amplitude().decay_group @@ -236,6 +243,13 @@ def test_fit(toy_config, fit_result): xy_err = pt.get_error_matrix([x, y]) +def test_lazycall(toy_config_lazy): + 
@@ -236,6 +243,13 @@ def test_fit(toy_config, fit_result):
     xy_err = pt.get_error_matrix([x, y])


+def test_lazycall(toy_config_lazy):
+    toy_config_lazy.fit(batch=100000)
+    toy_config_lazy.plot_partial_wave(
+        prefix="toy_data/figure_lazy", batch=100000
+    )
+
+
 def test_cal_chi2(toy_config, fit_result):
     toy_config.cal_chi2(bins=[[2, 2]] * 2, mass=["R_BD", "R_CD"])
diff --git a/tf_pwa/utils.py b/tf_pwa/utils.py
index 174a2bb4..63bbc530 100644
--- a/tf_pwa/utils.py
+++ b/tf_pwa/utils.py
@@ -32,6 +32,18 @@ def _load_yaml_file(name):
         return yaml.load(f, Loader=yaml.FullLoader)


+def create_dir(name):
+    import os
+
+    dirname = os.path.dirname(name)
+    if dirname == "":
+        return True
+    if not os.path.exists(dirname):
+        os.makedirs(dirname, exist_ok=True)
+        return False
+    return True
+
+
 def load_config_file(name):
     """
     Load config file such as **Resonances.yml**.
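Editor's note: `create_dir` only prepares the directory part of a path (returning `True` when it already exists or there is none, `False` when it had to be created), which is what `LazyCall.as_dataset` needs before handing a cache prefix to `tf.data`. A short sketch, with the path matching the cache naming used in `as_dataset` above:

```python
# Sketch: make sure the tf.data cache prefix directory exists.
from tf_pwa.utils import create_dir

create_dir("toy_data/cached/data_65000")  # creates toy_data/cached/ if missing
```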