Skip to content

Commit

Permalink
Merge pull request #110 from jiangyi15/hel_frac
Browse files Browse the repository at this point in the history
Support factor for model using helicity directly
  • Loading branch information
jiangyi15 authored Sep 11, 2023
2 parents fbb3925 + 987c055 commit c2ea2d5
Show file tree
Hide file tree
Showing 7 changed files with 356 additions and 51 deletions.
34 changes: 34 additions & 0 deletions tf_pwa/amp/amp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import warnings

import numpy as np
import tensorflow as tf

from tf_pwa.amp.core import Variable, variable_scope
Expand Down Expand Up @@ -280,6 +281,39 @@ def pdf(self, data):
return tf.reduce_sum(amp2s, list(range(1, len(amp2s.shape))))


@register_amp_model("base_factor")
class FactorAmplitudeModel(BaseAmplitudeModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_amp_list(self, data):
m_dep = self.decay_group.get_m_dep(data)
if "cached_angle" in data:
angle_amp = data["cached_angle"]
else:
angle_amp = self.decay_group.get_factor_angle_amp(data)
ret = []
for a, b in zip(m_dep, angle_amp):
tmp = b
for i in a:
total_size = np.prod(tmp.shape[1:])
if len(i.shape) == 1:
i = tf.expand_dims(i, axis=-1)
tmp = tf.reshape(
tmp, (-1, i.shape[-1], total_size // i.shape[-1])
)
tmp = tmp * tf.expand_dims(i, axis=-1)
tmp = tf.reduce_sum(tmp, axis=-2)
ret.append(tmp)
return ret

def pdf(self, data):
ret = self.get_amp_list(data)
amp = tf.reduce_sum(ret, axis=0)
amp2s = tf.math.real(amp * tf.math.conj(amp))
return tf.reduce_sum(amp2s, list(range(1, len(amp2s.shape))))


@register_amp_model("p4_directly")
class P4DirectlyAmplitudeModel(BaseAmplitudeModel):
def cal_angle(self, p4):
Expand Down
76 changes: 67 additions & 9 deletions tf_pwa/amp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,20 +508,68 @@ def init_params(self):
a = self.outs[0].spins
b = self.outs[1].spins
self.H = self.add_var("H", is_complex=True, shape=(len(a), len(b)))
self.fix_unused_h()

def get_helicity_amp(self, data, data_p, **kwargs):
def get_zero_index(self):
a = self.outs[0].spins
b = self.outs[1].spins
fix_index = []
free_index = []
for idx_i, i in zip(range(self.H.shape[-2]), a):
for idx_j, j in zip(range(self.H.shape[-1]), b):
if abs(i - j) > self.core.J:
fix_index.append((idx_i, idx_j))
else:
free_index.append((idx_i, idx_j))
return fix_index, free_index

def fix_unused_h(self):
fix_index, free_idx = self.get_zero_index()
self.H.set_fix_idx(fix_index, 0.0)
self.H.set_fix_idx([free_idx[0]], 1.0)

def get_H_zero_mask(self):
fix_index, free_idx = self.get_zero_index()

def get_factor(self):
_, free_index = self.get_zero_index()
H = self.H()
return tf.gather_nd(H, free_index)

def get_H(self):
if self.mask_factor:
H = tf.stack(self.H())
_, free_idx = self.get_zero_index()
return tf.scatter_nd(
indices=free_idx,
updates=tf.ones(len(free_idx), dtype=H.dtype),
shape=H.shape,
)
return tf.stack(self.H())

def get_helicity_amp(self, data=None, data_p=None, **kwargs):
return self.get_H()

def get_ls_amp(self, data, data_p, **kwargs):
return tf.reshape(self.get_factor(), (1, -1))

def get_factor_H(self, data=None, data_p=None, **kwargs):
_, free_idx = self.get_zero_index()
H = self.get_helicity_amp()
value = tf.gather_nd(H, free_idx)
new_idx = [(i, *j) for i, j in enumerate(free_idx)]
return tf.scatter_nd(
indices=new_idx, updates=value, shape=(len(free_idx), *H.shape)
)


@regist_decay("helicity_full-bf")
class HelicityDecayNPbf(HelicityDecay):
class HelicityDecayNPbf(HelicityDecayNP):
def init_params(self):
self.d = 3.0
a = self.outs[0].spins
b = self.outs[1].spins
self.H = self.add_var("H", is_complex=True, shape=(len(a), len(b)))
super().init_params()

def get_helicity_amp(self, data, data_p, **kwargs):
def get_H_barrier_factor(self, data, data_p, **kwargs):
q0 = self.get_relative_momentum(data_p, False)
data["|q0|"] = q0
if "|q|" in data:
Expand All @@ -530,18 +578,27 @@ def get_helicity_amp(self, data, data_p, **kwargs):
q = self.get_relative_momentum(data_p, True)
data["|q|"] = q
bf = barrier_factor([min(self.get_l_list())], q, q0, self.d)
H = tf.stack(self.H())
return bf

def get_helicity_amp(self, data, data_p, **kwargs):
H = self.get_H()
bf = self.get_H_barrier_factor(data, data_p, **kwargs)
bf = tf.cast(tf.reshape(bf, (-1, 1, 1)), H.dtype)
return H * bf

def get_ls_amp(self, data, data_p, **kwargs):
bf = self.get_H_barrier_factor(data, data_p, **kwargs)
f = tf.reshape(self.get_factor(), (1, -1))
return f * tf.expand_dims(tf.cast(bf, f.dtype), axis=-1)


def get_parity_term(j1, p1, j2, p2, j3, p3):
p = p1 * p2 * p3 * (-1) ** (j1 - j2 - j3)
return p


@regist_decay("helicity_parity")
class HelicityDecayP(HelicityDecay):
class HelicityDecayP(HelicityDecayNP):
"""
.. math::
Expand All @@ -566,11 +623,12 @@ def init_params(self):
"H", is_complex=True, shape=(n_b, (n_c + 1) // 2)
)
self.part_H = 1
self.fix_unused_h()

def get_helicity_amp(self, data, data_p, **kwargs):
n_b = len(self.outs[0].spins)
n_c = len(self.outs[1].spins)
H_part = tf.stack(self.H())
H_part = self.get_H()
if self.part_H == 0:
H = tf.concat(
[
Expand Down
151 changes: 151 additions & 0 deletions tf_pwa/amp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,9 @@ def get_width(self):
return self.width()
return self.width

def get_factor(self):
return None

def get_sympy_var(self):
return sym.var("m m0 g0 m1 m2")

Expand Down Expand Up @@ -699,6 +702,9 @@ def init_params(self):
def get_factor_variable(self):
return [(self.g_ls,)]

def get_factor(self):
return self.get_g_ls()

def _get_particle_mass(self, p, data, from_data=False):
if from_data and p in data:
return data[p]["m"]
Expand Down Expand Up @@ -875,6 +881,39 @@ def get_angle_helicity_amp(self, data, data_p, **kwargs):
)
return ret

def get_factor_H(self, data, data_p, **kwargs): # -> (n, n_ls, h1, h2)
m_dep = self.get_angle_ls_amp(data, data_p, **kwargs) # (n,l)
cg_trans = tf.cast(self.get_cg_matrix(), m_dep.dtype)
n_ls = len(self.get_ls_list())
m_dep = tf.reshape(m_dep, (-1, n_ls, 1, 1))
cg_trans = tf.reshape(
cg_trans, (n_ls, len(self.outs[0].spins), len(self.outs[1].spins))
)
# H = tf.reduce_sum(m_dep * cg_trans, axis=1)
H = m_dep * cg_trans # (n, n_ls, h1, h2)
return H

def get_factor_angle_helicity_amp(self, data, data_p, **kwargs):
H = self.get_factor_H(data, data_p, **kwargs)
if self.allow_cc:
all_data = kwargs.get("all_data", {})
charge = all_data.get("charge_conjugation", None)
if charge is not None:
H = tf.where(
charge[..., None, None] > 0, H, H[..., ::-1, ::-1]
)
ret = tf.reshape(
H,
(
-1,
H.shape[-3],
1,
len(self.outs[0].spins),
len(self.outs[1].spins),
),
)
return ret

def get_g_ls(self):
gls = self.g_ls()
if self.ls_index is None:
Expand Down Expand Up @@ -1042,9 +1081,28 @@ def get_angle_amp(self, data, data_p, **kwargs):
ret = tf.reduce_sum(ret, axis=j + 2)
return ret

def get_factor_angle_amp(self, data, data_p, **kwargs):
a = self.core
b = self.outs[0]
c = self.outs[1]
ang = data[b]["ang"]
D_conj = get_D_matrix_lambda(ang, a.J, a.spins, b.spins, c.spins)
H = self.get_factor_angle_helicity_amp(data, data_p, **kwargs)
H = tf.cast(H, dtype=D_conj.dtype)
D_conj = tf.reshape(D_conj, (-1, 1, *D_conj.shape[1:]))
ret = H * tf.stop_gradient(D_conj)
# print(self, H, D_conj)
# exit()
if self.aligned:
raise NotImplemented
return ret

def get_m_dep(self, data, data_p, **kwargs):
return self.get_ls_amp(data, data_p, **kwargs)

def get_factor_m_dep(self, data, data_p, **kwargs):
return self.get_ls_amp(data, data_p, **kwargs)

def get_ls_list(self):
"""get possible ls for decay, with l_list filter possible l"""
ls_list = super(HelicityDecay, self).get_ls_list()
Expand Down Expand Up @@ -1120,6 +1178,17 @@ def get_factor_variable(self):
a.append(tmp)
return [tuple([self.total] + a)]

def get_factor(self): # (total, decay1, decay2, ...)
decay_factor = [i.get_factor() for i in self]
particle_factor = [i.get_factor() for i in self.inner]
all_factor = decay_factor + particle_factor
all_factor = [i for i in all_factor if i is not None]
all_factor = all_factor
ret = self.get_amp_total()
for i in all_factor:
ret = tf.expand_dims(ret, axis=-1) * tf.cast(i, ret.dtype)
return ret

def get_amp_total(self, charge=1):
if self.mask_factor:
return tf.ones_like(tf.stack(self.total(charge)))
Expand Down Expand Up @@ -1230,6 +1299,51 @@ def get_angle_amp(self, data_c, data_p, all_data=None, base_map=None):
# ret = einsum(idx_s, *amp_d)
return ret

def get_factor_angle_amp(
self, data_c, data_p, all_data=None, base_map=None
):
base_map = self.get_base_map(base_map)
iter_idx = ["..."]
amp_d = []
indices = []
next_map = "zyxwvutsr"
used_idx = ""
final_indices = self.amp_index(base_map)
for i in self:
tmp_idx = i.amp_index(base_map)
tmp_idx = [next_map[0], *tmp_idx]
indices.append(tmp_idx)
used_idx += next_map[0]
amp_d.append(
i.get_factor_angle_amp(data_c[i], data_p, all_data=all_data)
)
next_map = next_map[1:]
final_indices = "".join(iter_idx + list(used_idx) + final_indices)

if self.aligned:
for i in self:
for j in i.outs:
if j.J != 0 and "aligned_angle" in data_c[i][j]:
ang = data_c[i][j]["aligned_angle"]
dt = get_D_matrix_lambda(ang, j.J, j.spins, j.spins)
amp_d.append(tf.stop_gradient(dt))
idx = [base_map[j], base_map[j].upper()]
indices.append(idx)
final_indices = final_indices.replace(*idx)
idxs = []
for i in indices:
tmp = "".join(iter_idx + i)
idxs.append(tmp)
idx = ",".join(idxs)
idx_s = "{}->{}".format(idx, final_indices)
# ret = amp * tf.reshape(rs, [-1] + [1] * len(self.amp_shape()))
# print(idx_s) # , amp_d)
ret = tf.einsum(idx_s, *amp_d)
# print(self, ret[0])
# exit()
# ret = einsum(idx_s, *amp_d)
return ret

def get_m_dep(self, data_c, data_p, all_data=None, base_map=None):
base_map = self.get_base_map(base_map)
iter_idx = ["..."]
Expand Down Expand Up @@ -1367,6 +1481,12 @@ def get_factor_variable(self):
ret += i.get_factor_variable()
return ret

def get_factor(self):
ret = []
for i in self:
ret.append(i.get_factor())
return ret

def get_amp(self, data):
"""
calculate the amplitude as complex number
Expand Down Expand Up @@ -1466,6 +1586,37 @@ def get_angle_amp(self, data):
# ret = tf.reduce_sum(ret, axis=0)
return amp

def get_factor_angle_amp(self, data):
data_particle = data["particle"]
data_decay = data["decay"]

used_chains = tuple([self.chains[i] for i in self.chains_idx])
chain_maps = self.get_chains_map(used_chains)
base_map = self.get_base_map()
ret = []
for decay_chain in used_chains:
for chains in chain_maps:
if str(decay_chain) in [str(i) for i in chains]:
maps = chains[decay_chain]
break
chain_topo = decay_chain.standard_topology()
found = False
for i in data_decay.keys():
if i == chain_topo:
data_decay_i = data_decay[i]
found = True
break
if not found:
raise KeyError("not found {}".format(chain_topo))
data_c = rename_data_dict(data_decay_i, maps)
data_p = rename_data_dict(data_particle, maps)
amp = decay_chain.get_factor_angle_amp(
data_c, data_p, base_map=base_map, all_data=data
)
ret.append(amp)
# ret = tf.reduce_sum(ret, axis=0)
return ret

@functools.lru_cache()
def get_swap_factor(self, key):
factor = 1.0
Expand Down
Loading

0 comments on commit c2ea2d5

Please sign in to comment.