From 6c10d11c3b74720ef6896c8fe42a5192b140e6ed Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sun, 20 Aug 2023 22:23:48 +0800 Subject: [PATCH 1/4] feat: cal angle option for p4 directly --- tf_pwa/amp/amp.py | 14 +++++++++++++- tf_pwa/config_loader/config_loader.py | 1 + 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tf_pwa/amp/amp.py b/tf_pwa/amp/amp.py index fb188b24..9af4fd73 100644 --- a/tf_pwa/amp/amp.py +++ b/tf_pwa/amp/amp.py @@ -285,7 +285,19 @@ class P4DirectlyAmplitudeModel(BaseAmplitudeModel): def cal_angle(self, p4): from tf_pwa.cal_angle import cal_angle_from_momentum - ret = cal_angle_from_momentum(p4, self.decay_group) + extra_kwargs = self.extra_kwargs["all_config"] + kwargs = {} + for k in [ + "center_mass", + "r_boost", + "random_z", + "align_ref", + "only_left_angle", + ]: + if k in extra_kwargs: + kwargs[k] = extra_kwargs[k] + + ret = cal_angle_from_momentum(p4, self.decay_group, **kwargs) return ret def pdf(self, data): diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index 6889cd84..391dcd9b 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -241,6 +241,7 @@ def get_amplitude(self, vm=None, name=""): jit_compile=jit_compile, model=amp_model, cached_shape_idx=cached_shape_idx, + all_config=amp_config, ) self.add_constraints(amp) self.amps[vm] = amp From 652fdcc3f7fe30d13dd7c79144aeef4b54d86ca7 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sun, 20 Aug 2023 22:41:21 +0800 Subject: [PATCH 2/4] feat: build_simple_input for decay --- tf_pwa/amp/core.py | 11 ++++++++++- tf_pwa/tests/test_amp.py | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tf_pwa/amp/core.py b/tf_pwa/amp/core.py index 397c56f1..a65270ed 100644 --- a/tf_pwa/amp/core.py +++ b/tf_pwa/amp/core.py @@ -700,7 +700,7 @@ def get_factor_variable(self): return [(self.g_ls,)] def _get_particle_mass(self, p, data, from_data=False): - if from_data: + if from_data and p in data: return data[p]["m"] if p.mass is None: p.mass = tf.reduce_mean(data[p]["m"]) @@ -789,6 +789,15 @@ def _get_cg_matrix(self, ls): # CG factor inside H ) return ret + def build_simple_input(self): + data_p = {self.core: {"m": self.core.get_mass()}} + data = {} + zero = np.array(0.0) + for i in self.outs: + data_p[i] = {"m": i.get_mass()} + data[i] = {"ang": {"alpha": zero, "beta": zero, "gamma": zero}} + return {"data": data, "data_p": data_p} + def get_helicity_amp(self, data, data_p, **kwargs): m_dep = self.get_ls_amp(data, data_p, **kwargs) cg_trans = tf.cast(self.get_cg_matrix(), m_dep.dtype) diff --git a/tf_pwa/tests/test_amp.py b/tf_pwa/tests/test_amp.py index 1d8d3b41..47468114 100644 --- a/tf_pwa/tests/test_amp.py +++ b/tf_pwa/tests/test_amp.py @@ -334,6 +334,7 @@ def test_decay_ls_amp(): data_p = {a: {"m": ma}, c: {"m": mb}, d: {"m": mb}} dec1.get_ls_amp(data, data_p) dec1.get_ls_amp({**data, "|q|2": mb}, data_p) + dec1.get_ls_amp(**dec1.build_simple_input()) def test_polarization(): From 5491cf86ba5def96b6770116638ac68ccad78347 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Mon, 21 Aug 2023 11:41:46 +0800 Subject: [PATCH 3/4] feat: build_ls2hel_eq --- tf_pwa/amp/core.py | 47 +++++++++++++++++++++++++++++++++------ tf_pwa/tests/test_amp.py | 3 ++- tf_pwa/tests/test_full.py | 27 ++++++++++++++++++++++ 3 files changed, 69 insertions(+), 8 deletions(-) diff --git a/tf_pwa/amp/core.py b/tf_pwa/amp/core.py index a65270ed..a0d65ba7 100644 --- a/tf_pwa/amp/core.py +++ b/tf_pwa/amp/core.py @@ -743,12 +743,12 @@ def get_relative_momentum2(self, data, from_data=False): ret = get_relative_p2(m0, m1, m2) return ret - def get_cg_matrix(self): + def get_cg_matrix(self, out_sym=False): ls = self.get_ls_list() - return self._get_cg_matrix(ls) + return self._get_cg_matrix(ls, out_sym=out_sym) @functools.lru_cache() - def _get_cg_matrix(self, ls): # CG factor inside H + def _get_cg_matrix(self, ls, out_sym=False): # CG factor inside H """ [(l,s),(lambda_b,lambda_c)] @@ -765,6 +765,18 @@ def _get_cg_matrix(self, ls): # CG factor inside H self.outs[1].spins ) # _spin_int(2 * jb + 1), _spin_int(2 * jc + 1) ret = np.zeros(shape=(m, *n)) + sqrt = np.sqrt + my_cg_coef = cg_coef + if out_sym: + from sympy.physics.quantum.cg import CG + + sint = lambda x: sym.simplify(_spin_int(x * 2)) / 2 + sqrt = lambda x: sym.sqrt(sint(x)) + + def my_cg_coef(a, b, c, d, e, f): + return CG(*list(map(sint, [a, c, b, d, e, f]))).doit() + + ret = ret.tolist() for i, ls_i in enumerate(ls): l, s = ls_i for i1, lambda_b in enumerate( @@ -774,11 +786,12 @@ def _get_cg_matrix(self, ls): # CG factor inside H self.outs[1].spins ): # _spin_range(-jc, jc)): ret[i][i1][i2] = ( - np.sqrt((2 * l + 1) / (2 * ja + 1)) - * cg_coef( + sqrt(2 * l + 1) + / sqrt(2 * ja + 1) + * my_cg_coef( jb, jc, lambda_b, -lambda_c, s, lambda_b - lambda_c ) - * cg_coef( + * my_cg_coef( l, s, 0, @@ -789,7 +802,27 @@ def _get_cg_matrix(self, ls): # CG factor inside H ) return ret - def build_simple_input(self): + def build_ls2hel_eq(self): + cg_matrix = self.get_cg_matrix(out_sym=True) + gls = [] + for l, s in self.get_ls_list(): + gls.append(sym.Symbol("g_{}_{}".format(l, s))) + hel = [] + eqs = [] + for ib, lb in enumerate(self.outs[0].spins): + for ic, lc in enumerate(self.outs[1].spins): + tmp = sym.Symbol("H_{}_{}".format(lb, lc)) + rhs = 0 + for idx, gi in enumerate(gls): + rhs = rhs + cg_matrix[idx][ib][ic] * gi + if rhs == 0: + continue + eq = sym.Eq(tmp, sym.simplify(rhs)) + hel.append(tmp) + eqs.append(eq) + return [gls, hel, eqs] + + def build_simple_data(self): data_p = {self.core: {"m": self.core.get_mass()}} data = {} zero = np.array(0.0) diff --git a/tf_pwa/tests/test_amp.py b/tf_pwa/tests/test_amp.py index 47468114..b3e977bd 100644 --- a/tf_pwa/tests/test_amp.py +++ b/tf_pwa/tests/test_amp.py @@ -334,7 +334,8 @@ def test_decay_ls_amp(): data_p = {a: {"m": ma}, c: {"m": mb}, d: {"m": mb}} dec1.get_ls_amp(data, data_p) dec1.get_ls_amp({**data, "|q|2": mb}, data_p) - dec1.get_ls_amp(**dec1.build_simple_input()) + dec1.get_ls_amp(**dec1.build_simple_data()) + dec1.build_ls2hel_eq() def test_polarization(): diff --git a/tf_pwa/tests/test_full.py b/tf_pwa/tests/test_full.py index fdddf9e3..e29f2f40 100644 --- a/tf_pwa/tests/test_full.py +++ b/tf_pwa/tests/test_full.py @@ -18,6 +18,33 @@ this_dir = os.path.dirname(os.path.abspath(__file__)) +DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + + +@pytest.fixture(scope="session", autouse=True) +def timer_session_scope(): + start = time.time() + print( + "\nstart: {}".format(time.strftime(DATE_FORMAT, time.localtime(start))) + ) + + yield + + finished = time.time() + print( + "finished: {}".format( + time.strftime(DATE_FORMAT, time.localtime(finished)) + ) + ) + print("Total time cost: {:.3f}s".format(finished - start)) + + +@pytest.fixture(autouse=True) +def timer_function_scope(): + start = time.time() + yield + print(" Time cost: {:.3f}s".format(time.time() - start)) + def generate_phspMC(Nmc): """Generate PhaseSpace MC of size Nmc and save it as txt file""" From e63e3233bb5c6de91162bd0468883127498dfd44 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Mon, 21 Aug 2023 13:08:23 +0800 Subject: [PATCH 4/4] ci: add check for fit results --- tf_pwa/tests/config_toy2.yml | 4 ++-- tf_pwa/tests/test_full.py | 42 +++++++++--------------------------- 2 files changed, 12 insertions(+), 34 deletions(-) diff --git a/tf_pwa/tests/config_toy2.yml b/tf_pwa/tests/config_toy2.yml index 6e611c67..9be6ab72 100644 --- a/tf_pwa/tests/config_toy2.yml +++ b/tf_pwa/tests/config_toy2.yml @@ -27,8 +27,8 @@ particle: B: { J: 1, P: -1, mass: 2.00698 } C: { J: 1, P: -1, mass: 2.01028 } D: { J: 0, P: -1, mass: 0.13957 } - R_BC: { J: 1, Par: 1, m0: 4.16, g0: 0.1, model: BWR2 } - R_BD: { J: 1, Par: 1, m0: 2.43, g0: 0.3, model: BW } + R_BC: { J: 1, Par: 1, m0: 4.16, g0: 0.1 } + R_BD: { J: 1, Par: 1, m0: 2.43, g0: 0.3 } R_CD: { J: 1, Par: 1, m0: 2.42, g0: 0.03 } constrains: diff --git a/tf_pwa/tests/test_full.py b/tf_pwa/tests/test_full.py index e29f2f40..df36a65d 100644 --- a/tf_pwa/tests/test_full.py +++ b/tf_pwa/tests/test_full.py @@ -18,33 +18,6 @@ this_dir = os.path.dirname(os.path.abspath(__file__)) -DATE_FORMAT = "%Y-%m-%d %H:%M:%S" - - -@pytest.fixture(scope="session", autouse=True) -def timer_session_scope(): - start = time.time() - print( - "\nstart: {}".format(time.strftime(DATE_FORMAT, time.localtime(start))) - ) - - yield - - finished = time.time() - print( - "finished: {}".format( - time.strftime(DATE_FORMAT, time.localtime(finished)) - ) - ) - print("Total time cost: {:.3f}s".format(finished - start)) - - -@pytest.fixture(autouse=True) -def timer_function_scope(): - start = time.time() - yield - print(" Time cost: {:.3f}s".format(time.time() - start)) - def generate_phspMC(Nmc): """Generate PhaseSpace MC of size Nmc and save it as txt file""" @@ -117,7 +90,7 @@ def toy_config2(gen_toy, fit_result): config = MultiConfig( [f"{this_dir}/config_toy.yml", f"{this_dir}/config_toy2.yml"] ) - config.set_params(f"{this_dir}/exp2_params.json") + config.set_params(f"{this_dir}/exp_params.json") return config @@ -131,6 +104,7 @@ def toy_config3(gen_toy): @pytest.fixture def fit_result(toy_config): ret = toy_config.fit() + assert np.allclose(ret.min_nll, -204.9468493307786) return ret @@ -231,7 +205,8 @@ def test_fit_lazy_call(gen_toy): config_dic["data"]["lazy_call"] = True config = ConfigLoader(config_dic) config.set_params(f"{this_dir}/exp_params.json") - config.fit(print_init_nll=False) + results = config.fit(print_init_nll=False) + assert np.allclose(results.min_nll, -204.9468493307786) fcn = config.get_fcn() fcn.nll_grad() config.plot_partial_wave(prefix="toy_data/figure/s2") @@ -351,7 +326,8 @@ def test_bacth_sum(toy_config, fit_result): def test_lazycall(toy_config_lazy): - toy_config_lazy.fit(batch=100000) + results = toy_config_lazy.fit(batch=100000) + assert np.allclose(results.min_nll, -204.9468493307786) toy_config_lazy.plot_partial_wave( prefix="toy_data/figure_lazy", batch=100000 ) @@ -366,13 +342,15 @@ def test_cal_signal_yields(toy_config, fit_result): def test_fit_combine(toy_config2): - toy_config2.fit() + results = toy_config2.fit() + print(results) + assert np.allclose(results.min_nll, -204.9468493307786 * 2) toy_config2.get_params_error() print(toy_config2.get_params()) def test_mix_likelihood(toy_config3): - toy_config3.fit(maxiter=1) + results = toy_config3.fit(maxiter=1) def test_cp_particles():