diff --git a/tf_pwa/amp/base.py b/tf_pwa/amp/base.py index 994ee03..03e5846 100644 --- a/tf_pwa/amp/base.py +++ b/tf_pwa/amp/base.py @@ -466,6 +466,35 @@ def get_amp(self, data, _data_c=None, **kwargs): return tf.exp(r) +@regist_particle("poly") +class ParticlePoly(Particle): + """ + .. math:: + R(m) = \\sum c_i (m-m_0)^{n-i} + + lineshape when :math:`c_0=1, c_1=c_2=0` + + .. plot:: + + >>> import matplotlib.pyplot as plt + >>> plt.clf() + >>> from tf_pwa.utils import plot_particle_model + >>> axis = plot_particle_model("poly", params={"n_order": 2}, plot_params={"R_BC_c_1r": 0., "R_BC_c_2r": 0., "R_BC_c_1i": 0., "R_BC_c_2i": 0.}) + + """ + + def init_params(self): + self.n_order = getattr(self, "n_order", 3) + self.pi = self.add_var("c", shape=(self.n_order + 1,), is_complex=True) + self.pi.set_fix_idx(fix_idx=0, fix_vals=(1.0, 0.0)) + + def get_amp(self, data, _data_c=None, **kwargs): + mass = data["m"] - self.get_mass() + pi = list(self.pi()) + mass = tf.complex(mass, tf.zeros_like(mass)) + return tf.math.polyval(pi, mass) + + @regist_decay("particle-decay") class ParticleDecay(HelicityDecay): def get_ls_amp(self, data, data_p, **kwargs):