From bafb4a5890bb3c8bf50035542a4615a532f06b20 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Mon, 10 Jun 2024 00:25:49 +0800 Subject: [PATCH 1/2] feat: new particle model MLP fit shape example in checks --- checks/fit_shape/README.md | 26 +++++++ checks/fit_shape/config.yml | 22 ++++++ checks/fit_shape/config_mi.yml | 27 ++++++++ checks/fit_shape/final_params_leaky_relu.json | 38 ++++++++++ checks/fit_shape/final_params_relu.json | 38 ++++++++++ checks/fit_shape/fit_shape.py | 69 +++++++++++++++++++ checks/fit_shape/init_params.json | 18 +++++ tf_pwa/amp/base.py | 48 +++++++++++++ tf_pwa/config_loader/particle_function.py | 11 +++ .../tests/test_particle_function.py | 2 + 10 files changed, 299 insertions(+) create mode 100644 checks/fit_shape/README.md create mode 100644 checks/fit_shape/config.yml create mode 100644 checks/fit_shape/config_mi.yml create mode 100644 checks/fit_shape/final_params_leaky_relu.json create mode 100644 checks/fit_shape/final_params_relu.json create mode 100644 checks/fit_shape/fit_shape.py create mode 100644 checks/fit_shape/init_params.json diff --git a/checks/fit_shape/README.md b/checks/fit_shape/README.md new file mode 100644 index 00000000..068599ea --- /dev/null +++ b/checks/fit_shape/README.md @@ -0,0 +1,26 @@ +# Script to directly fit the shape with MLP + +# input: + +- config.yml: target + +- init_params.json: parameters for target model + +- config_mi.yml: fit model + +# fit script + +- fit_shape.py + +# output: + +- fit_results.png: plot of real and image + +- final_params.json: results of fit parameters + +# example + +- final*params*\*.json: one possible fit results with activation name in the + tail. + +The best model have loss about 31.85. diff --git a/checks/fit_shape/config.yml b/checks/fit_shape/config.yml new file mode 100644 index 00000000..77d06753 --- /dev/null +++ b/checks/fit_shape/config.yml @@ -0,0 +1,22 @@ +data: + dat_order: [B, C, D] + +decay: + A: [BC, D] + BC: [B, C] + +particle: + $top: A + $finals: [B, C, D] + A: { J: 0, P: -1, mass: 3.0 } + B: { J: 0, P: -1, mass: 0.1 } + C: { J: 0, P: -1, mass: 0.1 } + D: { J: 0, P: -1, mass: 0.1 } + BC: [BC1, BC2] + BC1: { J: 0, P: +1, mass: 1.0, width: 0.5 } + BC2: { J: 0, P: +1, mass: 2.0, width: 0.5 } + +constrains: + decay: + fix_chain_idx: 0 + fix_chain_val: 1 diff --git a/checks/fit_shape/config_mi.yml b/checks/fit_shape/config_mi.yml new file mode 100644 index 00000000..8d4d3cdb --- /dev/null +++ b/checks/fit_shape/config_mi.yml @@ -0,0 +1,27 @@ +data: + dat_order: [B, C, D] + +decay: + A: [BC, D] + BC: [B, C] + +particle: + $top: A + $finals: [B, C, D] + A: { J: 0, P: -1, mass: 3.0 } + B: { J: 0, P: -1, mass: 0.1 } + C: { J: 0, P: -1, mass: 0.1 } + D: { J: 0, P: -1, mass: 0.1 } + BC: [MI] + MI: + J: 0 + P: +1 + mass: 1.0 + interp_N: 10 + model: MLP + activation: leaky_relu + +constrains: + decay: + fix_chain_idx: 0 + fix_chain_val: 1 diff --git a/checks/fit_shape/final_params_leaky_relu.json b/checks/fit_shape/final_params_leaky_relu.json new file mode 100644 index 00000000..8d84df62 --- /dev/null +++ b/checks/fit_shape/final_params_leaky_relu.json @@ -0,0 +1,38 @@ +{ + "MI_b_0": -0.8796771549166532, + "MI_b_1": -0.5555867175456078, + "MI_b_2": 0.25869784992909006, + "MI_b_3": -1.0627308851937134, + "MI_b_4": 3.6615951678704897, + "MI_b_5": -0.07954771940190528, + "MI_b_6": -0.256124862341781, + "MI_b_7": 14.23127083846775, + "MI_b_8": -1.286108632752361, + "MI_b_9": 0.07076624632726991, + "MI_w_0r": -5.104381116132341, + "MI_w_0i": 7.871238527765366, + "MI_w_1r": 3.23365301131465, + "MI_w_1i": 2.5985262831448064, + "MI_w_2r": -5.405269642657437, + "MI_w_2i": 5.314416098804127, + "MI_w_3r": 6.633822670781981, + "MI_w_3i": 6.899102757998181, + "MI_w_4r": -2.241794106880599, + "MI_w_4i": 4.787308576161889, + "MI_w_5r": 11.011051058950653, + "MI_w_5i": -0.7205598197812919, + "MI_w_6r": 7.060488845223941, + "MI_w_6i": 7.225803210685805, + "MI_w_7r": 0.4388703291081927, + "MI_w_7i": -1.2006659389428078, + "MI_w_8r": 3.8879823702434892, + "MI_w_8i": 2.574115900343803, + "MI_w_9r": 10.533495767231539, + "MI_w_9i": -2.4450624359388207, + "A->MI.DMI->B.C_total_0r": 1.0, + "A->MI.DMI->B.C_total_0i": 0.0, + "A->MI.D_g_ls_0r": 1.0, + "A->MI.D_g_ls_0i": 0.0, + "MI->B.C_g_ls_0r": 1.0, + "MI->B.C_g_ls_0i": 0.0 +} diff --git a/checks/fit_shape/final_params_relu.json b/checks/fit_shape/final_params_relu.json new file mode 100644 index 00000000..1af0655f --- /dev/null +++ b/checks/fit_shape/final_params_relu.json @@ -0,0 +1,38 @@ +{ + "MI_b_0": -1.0622250950337662, + "MI_b_1": -0.8789283255768763, + "MI_b_2": 3.0599096017088203, + "MI_b_3": -1.285576319474175, + "MI_b_4": -0.255100410428076, + "MI_b_5": -0.5535368004801267, + "MI_b_6": 7.558983658698586, + "MI_b_7": 0.07110322552772569, + "MI_b_8": 0.2590192741152075, + "MI_b_9": -0.07886432386304608, + "MI_w_0r": 5.3129601453575965, + "MI_w_0i": 6.89330400075596, + "MI_w_1r": 4.076189414869148, + "MI_w_1i": 4.72169072366533, + "MI_w_2r": 2.2345575394656794, + "MI_w_2i": 1.5674140496474243, + "MI_w_3r": 3.1181489763446524, + "MI_w_3i": 2.5713070284633615, + "MI_w_4r": 5.664131964810804, + "MI_w_4i": 7.217921257894127, + "MI_w_5r": -2.5927805257104555, + "MI_w_5i": 5.731281787333332, + "MI_w_6r": 0.6947219631515521, + "MI_w_6i": -1.3812766994545103, + "MI_w_7r": -8.411867876040107, + "MI_w_7i": 6.974832219369253, + "MI_w_8r": 4.313161954304356, + "MI_w_8i": 2.171071429023437, + "MI_w_9r": 8.808426082864889, + "MI_w_9i": 5.556405458272451, + "A->MI.DMI->B.C_total_0r": 1.0, + "A->MI.DMI->B.C_total_0i": 0.0, + "A->MI.D_g_ls_0r": 1.0, + "A->MI.D_g_ls_0i": 0.0, + "MI->B.C_g_ls_0r": 1.0, + "MI->B.C_g_ls_0i": 0.0 +} diff --git a/checks/fit_shape/fit_shape.py b/checks/fit_shape/fit_shape.py new file mode 100644 index 00000000..f5d2a872 --- /dev/null +++ b/checks/fit_shape/fit_shape.py @@ -0,0 +1,69 @@ +import matplotlib.pyplot as plt +import numpy as np +import tensorflow as tf + +from tf_pwa.config_loader import ConfigLoader +from tf_pwa.utils import time_print + +config = ConfigLoader("config.yml") + +config.set_params("init_params.json") + +config_mi = ConfigLoader("config_mi.yml") +config_mi.get_params() + +# config_mi.set_params("final_params.json") + +config_mi.vm.set_fix("MI_w_0r", unfix=True) +config_mi.vm.set_fix("MI_w_0i", unfix=True) + + +f1 = config.get_particle_function("BC1") +f2 = config.get_particle_function("BC2") + +f = config_mi.get_particle_function("MI") + +m = f.mass_linspace(10000) + + +fast_f = f.cached_call(m) +target_f = f1(m) + f2(m) + + +plot_count = 1 + + +def f_loss(): + ret = tf.reduce_sum(tf.abs(fast_f() - target_f) ** 2) + global plot_count + if plot_count % 10 == 1: + print(ret) + plot_count += 1 + return ret + + +best_params = {} +best_loss = np.inf +best_fit_result = None +for i in range(1): + fit_result = time_print(config_mi.vm.minimize)(f_loss) + if fit_result.fun < best_loss: + best_loss = fit_result.fun + best_params = config_mi.get_params() + best_fit_result = fit_result + # reset random parameters + config_mi2 = ConfigLoader("config_mi.yml") + config_mi.set_params(config_mi2.get_params()) + +config_mi.set_params(best_params) +config_mi.save_params("final_params.json") + +print(best_fit_result) + +plt.plot(m, tf.math.imag(f(m)).numpy(), label="image fit") +plt.plot(m, tf.math.imag(f1(m) + f2(m)).numpy(), label="imag target") +plt.plot(m, tf.math.real(f(m)).numpy(), label="real fit") +plt.plot(m, tf.math.real(f1(m) + f2(m)).numpy(), label="real target") + +plt.legend() +plt.savefig("fit_results.png") diff --git a/checks/fit_shape/init_params.json b/checks/fit_shape/init_params.json new file mode 100644 index 00000000..529bdf34 --- /dev/null +++ b/checks/fit_shape/init_params.json @@ -0,0 +1,18 @@ +{ + "BC1_mass": 1.0, + "BC1_width": 0.5, + "BC2_mass": 2.0, + "BC2_width": 0.5, + "A->BC1.DBC1->B.C_total_0r": 1.0, + "A->BC1.DBC1->B.C_total_0i": 0.0, + "A->BC1.D_g_ls_0r": 1.0, + "A->BC1.D_g_ls_0i": 0.0, + "BC1->B.C_g_ls_0r": 1.0, + "BC1->B.C_g_ls_0i": 0.0, + "A->BC2.DBC2->B.C_total_0r": 1.0, + "A->BC2.DBC2->B.C_total_0i": 1.5, + "A->BC2.D_g_ls_0r": 1.0, + "A->BC2.D_g_ls_0i": 0.0, + "BC2->B.C_g_ls_0r": 1.0, + "BC2->B.C_g_ls_0i": 0.0 +} diff --git a/tf_pwa/amp/base.py b/tf_pwa/amp/base.py index 03e5846c..450d2751 100644 --- a/tf_pwa/amp/base.py +++ b/tf_pwa/amp/base.py @@ -495,6 +495,54 @@ def get_amp(self, data, _data_c=None, **kwargs): return tf.math.polyval(pi, mass) +@regist_particle("MLP") +class ParticleMLP(Particle): + """ + Multilayer Perceptron like model. + + .. math:: + R(m) = \\sum_{k} w_k activation(m-m_0+b_k) + + lineshape when `interp_N: 11`, `activation: relu`, :math:`b_k=(k-5)/10`, :math:`w_k = exp(k i\\pi/2)` + + .. plot:: + + >>> import matplotlib.pyplot as plt + >>> import numpy as np + >>> plt.clf() + >>> from tf_pwa.utils import plot_particle_model + >>> plot_params = {f"R_BC_b_{i}": (i-5)/10 for i in range(11)} + >>> plot_params.update({f"R_BC_w_{i}r": 1 for i in range(11)}) + >>> plot_params.update({f"R_BC_w_{i}i": i * np.pi/2 for i in range(11)}) + >>> axis = plot_particle_model("MLP", params={"interp_N": 11, "activation": "relu"}, plot_params=plot_params) + + """ + + activation_function = { + "relu2": lambda x: tf.nn.relu(x) ** 2, + "relu3": lambda x: tf.nn.relu(x) ** 3, + } + + def init_params(self): + self.interp_N = getattr(self, "interp_N", 3) + self.activation = getattr(self, "activation", "leaky_relu") + self.activation_f = ParticleMLP.activation_function.get( + self.activation, getattr(tf.nn, self.activation) + ) + self.bi = self.add_var("b", shape=(self.interp_N,)) + self.wi = self.add_var("w", shape=(self.interp_N,), is_complex=True) + self.wi.set_fix_idx(fix_idx=0, fix_vals=(1.0, 0.0)) + + def get_amp(self, data, _data_c=None, **kwargs): + mass = data["m"] - self.get_mass() + bi = tf.stack(self.bi()) + wi = tf.stack(self.wi()) + x = tf.expand_dims(mass, axis=-1) + bi + x = self.activation_f(x) + ret = tf.reduce_sum(wi * tf.complex(x, tf.zeros_like(x)), axis=-1) + return ret + + @regist_decay("particle-decay") class ParticleDecay(HelicityDecay): def get_ls_amp(self, data, data_p, **kwargs): diff --git a/tf_pwa/config_loader/particle_function.py b/tf_pwa/config_loader/particle_function.py index 8c405f6f..7b634e1d 100644 --- a/tf_pwa/config_loader/particle_function.py +++ b/tf_pwa/config_loader/particle_function.py @@ -28,6 +28,17 @@ def __call__(self, m, random=False): ret = a[self.idx] return self.norm_factor * ret + def cached_call(self, m, **kwargs): + p = self.ha.generate_p_mass(self.name, m, **kwargs) + data = self.config.data.cal_angle(p) + + def f(): + a = build_amp.build_params_vector(self.decay_group, data) + ret = a[self.idx] + return self.norm_factor * ret + + return f + def mass_range(self): return self.ha.get_mass_range(self.name) diff --git a/tf_pwa/config_loader/tests/test_particle_function.py b/tf_pwa/config_loader/tests/test_particle_function.py index 0bc24b20..4c20f16e 100644 --- a/tf_pwa/config_loader/tests/test_particle_function.py +++ b/tf_pwa/config_loader/tests/test_particle_function.py @@ -11,6 +11,8 @@ def test_particle(toy_config): assert f(m).shape == (5, 6) g = toy_config.get_particle_function("R_BC", d_norm=True) g.phsp_fractor(m) + cached_g = g.cached_fun(m) + assert np.allclose(g(m).numpy(), cached_g().numpy()) g.density(m) assert g(m).shape == (5, 6) m_min, m_max = g.mass_range() From db9397260af67e682bc48b55b3b2e4402a3e6f67 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Mon, 10 Jun 2024 22:17:07 +0800 Subject: [PATCH 2/2] ci: fixed ci error --- tf_pwa/config_loader/tests/test_particle_function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf_pwa/config_loader/tests/test_particle_function.py b/tf_pwa/config_loader/tests/test_particle_function.py index 4c20f16e..6efa5eb1 100644 --- a/tf_pwa/config_loader/tests/test_particle_function.py +++ b/tf_pwa/config_loader/tests/test_particle_function.py @@ -11,7 +11,7 @@ def test_particle(toy_config): assert f(m).shape == (5, 6) g = toy_config.get_particle_function("R_BC", d_norm=True) g.phsp_fractor(m) - cached_g = g.cached_fun(m) + cached_g = g.cached_call(m) assert np.allclose(g(m).numpy(), cached_g().numpy()) g.density(m) assert g(m).shape == (5, 6)