Skip to content

Commit

Permalink
Merge pull request #98 from jiangyi15/sym_hel
Browse files Browse the repository at this point in the history
Symbolic formula for LS to Helicity
  • Loading branch information
jiangyi15 authored Aug 21, 2023
2 parents f5db1ef + e63e323 commit ec64215
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 15 deletions.
14 changes: 13 additions & 1 deletion tf_pwa/amp/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,19 @@ class P4DirectlyAmplitudeModel(BaseAmplitudeModel):
def cal_angle(self, p4):
from tf_pwa.cal_angle import cal_angle_from_momentum

ret = cal_angle_from_momentum(p4, self.decay_group)
extra_kwargs = self.extra_kwargs["all_config"]
kwargs = {}
for k in [
"center_mass",
"r_boost",
"random_z",
"align_ref",
"only_left_angle",
]:
if k in extra_kwargs:
kwargs[k] = extra_kwargs[k]

ret = cal_angle_from_momentum(p4, self.decay_group, **kwargs)
return ret

def pdf(self, data):
Expand Down
56 changes: 49 additions & 7 deletions tf_pwa/amp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ def get_factor_variable(self):
return [(self.g_ls,)]

def _get_particle_mass(self, p, data, from_data=False):
if from_data:
if from_data and p in data:
return data[p]["m"]
if p.mass is None:
p.mass = tf.reduce_mean(data[p]["m"])
Expand Down Expand Up @@ -743,12 +743,12 @@ def get_relative_momentum2(self, data, from_data=False):
ret = get_relative_p2(m0, m1, m2)
return ret

def get_cg_matrix(self):
def get_cg_matrix(self, out_sym=False):
ls = self.get_ls_list()
return self._get_cg_matrix(ls)
return self._get_cg_matrix(ls, out_sym=out_sym)

@functools.lru_cache()
def _get_cg_matrix(self, ls): # CG factor inside H
def _get_cg_matrix(self, ls, out_sym=False): # CG factor inside H
"""
[(l,s),(lambda_b,lambda_c)]
Expand All @@ -765,6 +765,18 @@ def _get_cg_matrix(self, ls): # CG factor inside H
self.outs[1].spins
) # _spin_int(2 * jb + 1), _spin_int(2 * jc + 1)
ret = np.zeros(shape=(m, *n))
sqrt = np.sqrt
my_cg_coef = cg_coef
if out_sym:
from sympy.physics.quantum.cg import CG

sint = lambda x: sym.simplify(_spin_int(x * 2)) / 2
sqrt = lambda x: sym.sqrt(sint(x))

def my_cg_coef(a, b, c, d, e, f):
return CG(*list(map(sint, [a, c, b, d, e, f]))).doit()

ret = ret.tolist()
for i, ls_i in enumerate(ls):
l, s = ls_i
for i1, lambda_b in enumerate(
Expand All @@ -774,11 +786,12 @@ def _get_cg_matrix(self, ls): # CG factor inside H
self.outs[1].spins
): # _spin_range(-jc, jc)):
ret[i][i1][i2] = (
np.sqrt((2 * l + 1) / (2 * ja + 1))
* cg_coef(
sqrt(2 * l + 1)
/ sqrt(2 * ja + 1)
* my_cg_coef(
jb, jc, lambda_b, -lambda_c, s, lambda_b - lambda_c
)
* cg_coef(
* my_cg_coef(
l,
s,
0,
Expand All @@ -789,6 +802,35 @@ def _get_cg_matrix(self, ls): # CG factor inside H
)
return ret

def build_ls2hel_eq(self):
cg_matrix = self.get_cg_matrix(out_sym=True)
gls = []
for l, s in self.get_ls_list():
gls.append(sym.Symbol("g_{}_{}".format(l, s)))
hel = []
eqs = []
for ib, lb in enumerate(self.outs[0].spins):
for ic, lc in enumerate(self.outs[1].spins):
tmp = sym.Symbol("H_{}_{}".format(lb, lc))
rhs = 0
for idx, gi in enumerate(gls):
rhs = rhs + cg_matrix[idx][ib][ic] * gi
if rhs == 0:
continue
eq = sym.Eq(tmp, sym.simplify(rhs))
hel.append(tmp)
eqs.append(eq)
return [gls, hel, eqs]

def build_simple_data(self):
data_p = {self.core: {"m": self.core.get_mass()}}
data = {}
zero = np.array(0.0)
for i in self.outs:
data_p[i] = {"m": i.get_mass()}
data[i] = {"ang": {"alpha": zero, "beta": zero, "gamma": zero}}
return {"data": data, "data_p": data_p}

def get_helicity_amp(self, data, data_p, **kwargs):
m_dep = self.get_ls_amp(data, data_p, **kwargs)
cg_trans = tf.cast(self.get_cg_matrix(), m_dep.dtype)
Expand Down
1 change: 1 addition & 0 deletions tf_pwa/config_loader/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def get_amplitude(self, vm=None, name=""):
jit_compile=jit_compile,
model=amp_model,
cached_shape_idx=cached_shape_idx,
all_config=amp_config,
)
self.add_constraints(amp)
self.amps[vm] = amp
Expand Down
4 changes: 2 additions & 2 deletions tf_pwa/tests/config_toy2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ particle:
B: { J: 1, P: -1, mass: 2.00698 }
C: { J: 1, P: -1, mass: 2.01028 }
D: { J: 0, P: -1, mass: 0.13957 }
R_BC: { J: 1, Par: 1, m0: 4.16, g0: 0.1, model: BWR2 }
R_BD: { J: 1, Par: 1, m0: 2.43, g0: 0.3, model: BW }
R_BC: { J: 1, Par: 1, m0: 4.16, g0: 0.1 }
R_BD: { J: 1, Par: 1, m0: 2.43, g0: 0.3 }
R_CD: { J: 1, Par: 1, m0: 2.42, g0: 0.03 }

constrains:
Expand Down
2 changes: 2 additions & 0 deletions tf_pwa/tests/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ def test_decay_ls_amp():
data_p = {a: {"m": ma}, c: {"m": mb}, d: {"m": mb}}
dec1.get_ls_amp(data, data_p)
dec1.get_ls_amp({**data, "|q|2": mb}, data_p)
dec1.get_ls_amp(**dec1.build_simple_data())
dec1.build_ls2hel_eq()


def test_polarization():
Expand Down
15 changes: 10 additions & 5 deletions tf_pwa/tests/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def toy_config2(gen_toy, fit_result):
config = MultiConfig(
[f"{this_dir}/config_toy.yml", f"{this_dir}/config_toy2.yml"]
)
config.set_params(f"{this_dir}/exp2_params.json")
config.set_params(f"{this_dir}/exp_params.json")
return config


Expand All @@ -104,6 +104,7 @@ def toy_config3(gen_toy):
@pytest.fixture
def fit_result(toy_config):
ret = toy_config.fit()
assert np.allclose(ret.min_nll, -204.9468493307786)
return ret


Expand Down Expand Up @@ -204,7 +205,8 @@ def test_fit_lazy_call(gen_toy):
config_dic["data"]["lazy_call"] = True
config = ConfigLoader(config_dic)
config.set_params(f"{this_dir}/exp_params.json")
config.fit(print_init_nll=False)
results = config.fit(print_init_nll=False)
assert np.allclose(results.min_nll, -204.9468493307786)
fcn = config.get_fcn()
fcn.nll_grad()
config.plot_partial_wave(prefix="toy_data/figure/s2")
Expand Down Expand Up @@ -324,7 +326,8 @@ def test_bacth_sum(toy_config, fit_result):


def test_lazycall(toy_config_lazy):
toy_config_lazy.fit(batch=100000)
results = toy_config_lazy.fit(batch=100000)
assert np.allclose(results.min_nll, -204.9468493307786)
toy_config_lazy.plot_partial_wave(
prefix="toy_data/figure_lazy", batch=100000
)
Expand All @@ -339,13 +342,15 @@ def test_cal_signal_yields(toy_config, fit_result):


def test_fit_combine(toy_config2):
toy_config2.fit()
results = toy_config2.fit()
print(results)
assert np.allclose(results.min_nll, -204.9468493307786 * 2)
toy_config2.get_params_error()
print(toy_config2.get_params())


def test_mix_likelihood(toy_config3):
toy_config3.fit(maxiter=1)
results = toy_config3.fit(maxiter=1)


def test_cp_particles():
Expand Down

0 comments on commit ec64215

Please sign in to comment.