From a2f744e274c66d36491ad738bb3182d32b4f9550 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Thu, 7 Sep 2023 22:02:21 +0800 Subject: [PATCH 1/5] misc: add constrains on helicity_full --- tf_pwa/amp/base.py | 24 +++++++++++--- tf_pwa/variable.py | 82 ++++++++++++++++++++++------------------------ 2 files changed, 59 insertions(+), 47 deletions(-) diff --git a/tf_pwa/amp/base.py b/tf_pwa/amp/base.py index 0c2c223a..7207f0ee 100644 --- a/tf_pwa/amp/base.py +++ b/tf_pwa/amp/base.py @@ -508,18 +508,31 @@ def init_params(self): a = self.outs[0].spins b = self.outs[1].spins self.H = self.add_var("H", is_complex=True, shape=(len(a), len(b))) + self.fix_unused_h() + + def fix_unused_h(self): + a = self.outs[0].spins + b = self.outs[1].spins + min_idx = None + fix_index = [] + for idx_i, i in enumerate(a): + for idx_j, j in enumerate(b): + if abs(i - j) > self.core.J: + fix_index.append((idx_i, idx_j)) + elif min_idx is None: + min_idx = idx_i, idx_j + self.H.set_fix_idx(fix_index, 0.0) + self.H.set_fix_idx([min_idx], 1.0) def get_helicity_amp(self, data, data_p, **kwargs): return tf.stack(self.H()) @regist_decay("helicity_full-bf") -class HelicityDecayNPbf(HelicityDecay): +class HelicityDecayNPbf(HelicityDecayNP): def init_params(self): self.d = 3.0 - a = self.outs[0].spins - b = self.outs[1].spins - self.H = self.add_var("H", is_complex=True, shape=(len(a), len(b))) + super().init_params() def get_helicity_amp(self, data, data_p, **kwargs): q0 = self.get_relative_momentum(data_p, False) @@ -541,7 +554,7 @@ def get_parity_term(j1, p1, j2, p2, j3, p3): @regist_decay("helicity_parity") -class HelicityDecayP(HelicityDecay): +class HelicityDecayP(HelicityDecayNP): """ .. math:: @@ -566,6 +579,7 @@ def init_params(self): "H", is_complex=True, shape=(n_b, (n_c + 1) // 2) ) self.part_H = 1 + self.fix_unused_h() def get_helicity_amp(self, data, data_p, **kwargs): n_b = len(self.outs[0].spins) diff --git a/tf_pwa/variable.py b/tf_pwa/variable.py index 30e5c1c3..498afa7a 100644 --- a/tf_pwa/variable.py +++ b/tf_pwa/variable.py @@ -1494,31 +1494,16 @@ def freed(self): else: raise Exception("Only shape==() var supports 'freed' method.") - def set_fix_idx(self, fix_idx=None, fix_vals=None, free_idx=None): - """ - :param fix_idx: Interger or list of integers. Which complex component in the innermost layer of the variable is fixed. E.g. If ``self.shape==[2,3,4]`` and ``fix_idx==[1,2]``, then Variable()[i][j][1] and Variable()[i][j][2] will be the fixed value. - :param fix_vals: Float or length-2 float list for complex variable. The fixed value. - :param free_idx: Interger or list of integers. Which complex component in the innermost layer of the variable is set free. E.g. If ``self.shape==[2,3,4]`` and ``fix_idx==[0]``, then Variable()[i][j][0] will be set free. - """ - if not self.shape: - raise Exception( - "Only shape!=() var supports 'set_fix_idx' method to fix or free variables." - ) - if free_idx is None: - free_idx = [] - else: - free_idx = free_idx % self.shape[-1] + def _set_fix_idx(self, fix_idx=None, fix_vals=None, unfix=False): if fix_idx is None: fix_idx = [] else: - fix_idx = fix_idx % self.shape[-1] - - if not hasattr(fix_idx, "__len__"): - fix_idx = [fix_idx] - fix_idx_str = ["_" + str(i) for i in fix_idx] - if not hasattr(free_idx, "__len__"): - free_idx = [free_idx] - free_idx_str = ["_" + str(i) for i in free_idx] + fix_idx = np.array(fix_idx) + if len(fix_idx.shape) <= 1: + fix_idx = np.reshape(fix_idx, (-1, 1)) + shape_mod = [self.shape[-i - 1] for i in range(fix_idx.shape[-1])] + fix_idx = fix_idx % shape_mod + fix_idx_str = ["_" + "_".join(str(j) for j in i) for i in fix_idx] if self.cp_effect: print( @@ -1541,19 +1526,22 @@ def func(name, idx): for ss in fix_idx_str: if name.endswith(ss): # print("set_fix_idx set name+r ", name) - self.vm.set_fix(name + "r", value=fix_vals[0]) - self.vm.set_fix(name + "i", value=fix_vals[1]) + self.vm.set_fix( + name + "r", value=fix_vals[0], unfix=unfix + ) + self.vm.set_fix( + name + "i", value=fix_vals[1], unfix=unfix + ) if len(fix_vals) > 2: - self.vm.set_fix(name + "deltar", value=fix_vals[2]) - self.vm.set_fix(name + "deltai", value=fix_vals[3]) - for ss in free_idx_str: - if name.endswith(ss): - self.vm.set_fix(name + "r", unfix=True) - self.vm.set_fix(name + "i", unfix=True) - self.vm.set_fix(name + "deltar", unfix=True) - self.vm.set_fix(name + "deltai", unfix=True) + self.vm.set_fix( + name + "deltar", value=fix_vals[2], unfix=unfix + ) + self.vm.set_fix( + name + "deltai", value=fix_vals[3], unfix=unfix + ) _shape_func(func, self.shape, self.name) + elif self.cplx: if fix_vals is None: fix_vals = [None, None] @@ -1563,12 +1551,12 @@ def func(name, idx): def func(name, idx): for ss in fix_idx_str: if name.endswith(ss): - self.vm.set_fix(name + "r", value=fix_vals[0]) - self.vm.set_fix(name + "i", value=fix_vals[1]) - for ss in free_idx_str: - if name.endswith(ss): - self.vm.set_fix(name + "r", unfix=True) - self.vm.set_fix(name + "i", unfix=True) + self.vm.set_fix( + name + "r", value=fix_vals[0], unfix=unfix + ) + self.vm.set_fix( + name + "i", value=fix_vals[1], unfix=unfix + ) _shape_func(func, self.shape, self.name) else: @@ -1576,13 +1564,23 @@ def func(name, idx): def func(name, idx): for ss in fix_idx_str: if name.endswith(ss): - self.vm.set_fix(name, value=fix_vals) - for ss in free_idx_str: - if name.endswith(ss): - self.vm.set_fix(name, unfix=True) + self.vm.set_fix(name, value=fix_vals, unfix=unfix) _shape_func(func, self.shape, self.name) + def set_fix_idx(self, fix_idx=None, fix_vals=None, free_idx=None): + """ + :param fix_idx: Interger or list of integers. Which complex component in the innermost layer of the variable is fixed. E.g. If ``self.shape==[2,3,4]`` and ``fix_idx==[1,2]``, then Variable()[i][j][1] and Variable()[i][j][2] will be the fixed value. + :param fix_vals: Float or length-2 float list for complex variable. The fixed value. + :param free_idx: Interger or list of integers. Which complex component in the innermost layer of the variable is set free. E.g. If ``self.shape==[2,3,4]`` and ``fix_idx==[0]``, then Variable()[i][j][0] will be set free. + """ + if not self.shape: + raise Exception( + "Only shape!=() var supports 'set_fix_idx' method to fix or free variables." + ) + self._set_fix_idx(fix_idx, fix_vals) + self._set_fix_idx(free_idx, unfix=True) + def set_bound(self, bound, func=None, overwrite=False): """ Set boundary for this Variable. Note only non-shape real Variable supports this method. From 38a44c1c25c786bb72f879bafe6646dfdeb984df Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Fri, 8 Sep 2023 23:48:36 +0800 Subject: [PATCH 2/5] feat: get_fractor_angle_amp --- tf_pwa/amp/base.py | 34 ++++++++--- tf_pwa/amp/core.py | 146 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 7 deletions(-) diff --git a/tf_pwa/amp/base.py b/tf_pwa/amp/base.py index 7207f0ee..643e4672 100644 --- a/tf_pwa/amp/base.py +++ b/tf_pwa/amp/base.py @@ -510,21 +510,41 @@ def init_params(self): self.H = self.add_var("H", is_complex=True, shape=(len(a), len(b))) self.fix_unused_h() - def fix_unused_h(self): + def get_zero_index(self): a = self.outs[0].spins b = self.outs[1].spins - min_idx = None fix_index = [] - for idx_i, i in enumerate(a): - for idx_j, j in enumerate(b): + free_index = [] + for idx_i, i in zip(range(self.H.shape[-2]), a): + for idx_j, j in zip(range(self.H.shape[-1]), b): if abs(i - j) > self.core.J: fix_index.append((idx_i, idx_j)) - elif min_idx is None: - min_idx = idx_i, idx_j + else: + free_index.append((idx_i, idx_j)) + return fix_index, free_index + + def fix_unused_h(self): + fix_index, free_idx = self.get_zero_index() self.H.set_fix_idx(fix_index, 0.0) - self.H.set_fix_idx([min_idx], 1.0) + self.H.set_fix_idx([free_idx[0]], 1.0) + + def get_H_zero_mask(self): + fix_index, free_idx = self.get_zero_index() + + def get_factor(self): + _, free_index = self.get_zero_index() + H = self.H() + return tf.gather_nd(H, free_index) def get_helicity_amp(self, data, data_p, **kwargs): + if self.mask_factor: + H = tf.stack(self.H()) + free_idx = self.get_zero_index() + return tf.scatter_nd( + indices=free_idx, + updates=tf.ones(len(free_idx), dtype=H.dtype), + shape=H.shape, + ) return tf.stack(self.H()) diff --git a/tf_pwa/amp/core.py b/tf_pwa/amp/core.py index 7607228b..1c002059 100644 --- a/tf_pwa/amp/core.py +++ b/tf_pwa/amp/core.py @@ -422,6 +422,9 @@ def get_width(self): return self.width() return self.width + def get_factor(self): + return None + def get_sympy_var(self): return sym.var("m m0 g0 m1 m2") @@ -699,6 +702,9 @@ def init_params(self): def get_factor_variable(self): return [(self.g_ls,)] + def get_factor(self): + return self.get_g_ls() + def _get_particle_mass(self, p, data, from_data=False): if from_data and p in data: return data[p]["m"] @@ -875,6 +881,36 @@ def get_angle_helicity_amp(self, data, data_p, **kwargs): ) return ret + def get_factor_angle_helicity_amp(self, data, data_p, **kwargs): + m_dep = self.get_angle_ls_amp(data, data_p, **kwargs) # (n,l) + cg_trans = tf.cast(self.get_cg_matrix(), m_dep.dtype) + n_ls = len(self.get_ls_list()) + m_dep = tf.reshape(m_dep, (-1, n_ls, 1, 1)) + cg_trans = tf.reshape( + cg_trans, (n_ls, len(self.outs[0].spins), len(self.outs[1].spins)) + ) + # H = tf.reduce_sum(m_dep * cg_trans, axis=1) + H = m_dep * cg_trans # (n, n_ls, h1, h2) + # print(n_ls, cg_trans, self, m_dep.shape) # )data_p) + if self.allow_cc: + all_data = kwargs.get("all_data", {}) + charge = all_data.get("charge_conjugation", None) + if charge is not None: + H = tf.where( + charge[..., None, None] > 0, H, H[..., ::-1, ::-1] + ) + ret = tf.reshape( + H, + ( + -1, + H.shape[-3], + 1, + len(self.outs[0].spins), + len(self.outs[1].spins), + ), + ) + return ret + def get_g_ls(self): gls = self.g_ls() if self.ls_index is None: @@ -1042,6 +1078,22 @@ def get_angle_amp(self, data, data_p, **kwargs): ret = tf.reduce_sum(ret, axis=j + 2) return ret + def get_factor_angle_amp(self, data, data_p, **kwargs): + a = self.core + b = self.outs[0] + c = self.outs[1] + ang = data[b]["ang"] + D_conj = get_D_matrix_lambda(ang, a.J, a.spins, b.spins, c.spins) + H = self.get_factor_angle_helicity_amp(data, data_p, **kwargs) + H = tf.cast(H, dtype=D_conj.dtype) + D_conj = tf.reshape(D_conj, (-1, 1, *D_conj.shape[1:])) + ret = H * tf.stop_gradient(D_conj) + # print(self, H, D_conj) + # exit() + if self.aligned: + raise NotImplemented + return ret + def get_m_dep(self, data, data_p, **kwargs): return self.get_ls_amp(data, data_p, **kwargs) @@ -1120,6 +1172,18 @@ def get_factor_variable(self): a.append(tmp) return [tuple([self.total] + a)] + def get_factor(self): + decay_factor = [i.get_factor() for i in self] + particle_factor = [i.get_factor() for i in self.inner] + all_factor = particle_factor + decay_factor + all_factor = [i for i in all_factor if i is not None] + all_factor = all_factor + ret = self.get_amp_total() + for i in all_factor: + ret = tf.expand_dims(ret, axis=-1) * tf.cast(i, ret.dtype) + ret = tf.reshape(ret, (-1,)) + return ret + def get_amp_total(self, charge=1): if self.mask_factor: return tf.ones_like(tf.stack(self.total(charge))) @@ -1230,6 +1294,51 @@ def get_angle_amp(self, data_c, data_p, all_data=None, base_map=None): # ret = einsum(idx_s, *amp_d) return ret + def get_factor_angle_amp( + self, data_c, data_p, all_data=None, base_map=None + ): + base_map = self.get_base_map(base_map) + iter_idx = ["..."] + amp_d = [] + indices = [] + next_map = "zyxwvutsr" + used_idx = "" + final_indices = self.amp_index(base_map) + for i in self: + tmp_idx = i.amp_index(base_map) + tmp_idx = [next_map[0], *tmp_idx] + indices.append(tmp_idx) + used_idx += next_map[0] + amp_d.append( + i.get_factor_angle_amp(data_c[i], data_p, all_data=all_data) + ) + next_map = next_map[1:] + final_indices = "".join(iter_idx + list(used_idx) + final_indices) + + if self.aligned: + for i in self: + for j in i.outs: + if j.J != 0 and "aligned_angle" in data_c[i][j]: + ang = data_c[i][j]["aligned_angle"] + dt = get_D_matrix_lambda(ang, j.J, j.spins, j.spins) + amp_d.append(tf.stop_gradient(dt)) + idx = [base_map[j], base_map[j].upper()] + indices.append(idx) + final_indices = final_indices.replace(*idx) + idxs = [] + for i in indices: + tmp = "".join(iter_idx + i) + idxs.append(tmp) + idx = ",".join(idxs) + idx_s = "{}->{}".format(idx, final_indices) + # ret = amp * tf.reshape(rs, [-1] + [1] * len(self.amp_shape())) + print(idx_s) # , amp_d) + ret = tf.einsum(idx_s, *amp_d) + # print(self, ret[0]) + # exit() + # ret = einsum(idx_s, *amp_d) + return ret + def get_m_dep(self, data_c, data_p, all_data=None, base_map=None): base_map = self.get_base_map(base_map) iter_idx = ["..."] @@ -1367,6 +1476,12 @@ def get_factor_variable(self): ret += i.get_factor_variable() return ret + def get_factor(self): + ret = [] + for i in self: + ret += i.get_factor() + return ret + def get_amp(self, data): """ calculate the amplitude as complex number @@ -1466,6 +1581,37 @@ def get_angle_amp(self, data): # ret = tf.reduce_sum(ret, axis=0) return amp + def get_factor_angle_amp(self, data): + data_particle = data["particle"] + data_decay = data["decay"] + + used_chains = tuple([self.chains[i] for i in self.chains_idx]) + chain_maps = self.get_chains_map(used_chains) + base_map = self.get_base_map() + ret = [] + for decay_chain in used_chains: + for chains in chain_maps: + if str(decay_chain) in [str(i) for i in chains]: + maps = chains[decay_chain] + break + chain_topo = decay_chain.standard_topology() + found = False + for i in data_decay.keys(): + if i == chain_topo: + data_decay_i = data_decay[i] + found = True + break + if not found: + raise KeyError("not found {}".format(chain_topo)) + data_c = rename_data_dict(data_decay_i, maps) + data_p = rename_data_dict(data_particle, maps) + amp = decay_chain.get_factor_angle_amp( + data_c, data_p, base_map=base_map, all_data=data + ) + ret.append(amp) + # ret = tf.reduce_sum(ret, axis=0) + return amp + @functools.lru_cache() def get_swap_factor(self, key): factor = 1.0 From 372848851529ebd599401602d53e74cfbb1875ec Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sat, 9 Sep 2023 13:22:10 +0800 Subject: [PATCH 3/5] feat: amp_model factor --- tf_pwa/amp/amp.py | 27 +++++++++++++++++++++++++++ tf_pwa/amp/base.py | 16 ++++++++++++++-- tf_pwa/amp/core.py | 21 +++++++++++++-------- tf_pwa/tests/config_hel.yml | 26 ++++++++++++++++++++++++++ tf_pwa/tests/test_full.py | 7 +++++++ 5 files changed, 87 insertions(+), 10 deletions(-) create mode 100644 tf_pwa/tests/config_hel.yml diff --git a/tf_pwa/amp/amp.py b/tf_pwa/amp/amp.py index 77e95ebf..a3a9f52b 100644 --- a/tf_pwa/amp/amp.py +++ b/tf_pwa/amp/amp.py @@ -1,6 +1,7 @@ import contextlib import warnings +import numpy as np import tensorflow as tf from tf_pwa.amp.core import Variable, variable_scope @@ -280,6 +281,32 @@ def pdf(self, data): return tf.reduce_sum(amp2s, list(range(1, len(amp2s.shape)))) +@register_amp_model("base_factor") +class FactorAmplitudeModel(BaseAmplitudeModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def pdf(self, data): + m_dep = self.decay_group.get_m_dep(data) + angle_amp = self.decay_group.get_factor_angle_amp(data) + ret = [] + for a, b in zip(m_dep, angle_amp): + tmp = b + for i in a: + total_size = np.prod(tmp.shape[1:]) + if len(i.shape) == 1: + i = tf.expand_dims(i, axis=-1) + tmp = tf.reshape( + tmp, (-1, i.shape[-1], total_size // i.shape[-1]) + ) + tmp = tmp * tf.expand_dims(i, axis=-1) + tmp = tf.reduce_sum(tmp, axis=-2) + ret.append(tmp) + amp = tf.reduce_sum(ret, axis=0) + amp2s = tf.math.real(amp * tf.math.conj(amp)) + return tf.reduce_sum(amp2s, list(range(1, len(amp2s.shape)))) + + @register_amp_model("p4_directly") class P4DirectlyAmplitudeModel(BaseAmplitudeModel): def cal_angle(self, p4): diff --git a/tf_pwa/amp/base.py b/tf_pwa/amp/base.py index 643e4672..ee777313 100644 --- a/tf_pwa/amp/base.py +++ b/tf_pwa/amp/base.py @@ -536,10 +536,10 @@ def get_factor(self): H = self.H() return tf.gather_nd(H, free_index) - def get_helicity_amp(self, data, data_p, **kwargs): + def get_helicity_amp(self, data=None, data_p=None, **kwargs): if self.mask_factor: H = tf.stack(self.H()) - free_idx = self.get_zero_index() + _, free_idx = self.get_zero_index() return tf.scatter_nd( indices=free_idx, updates=tf.ones(len(free_idx), dtype=H.dtype), @@ -547,6 +547,18 @@ def get_helicity_amp(self, data, data_p, **kwargs): ) return tf.stack(self.H()) + def get_ls_amp(self, data, data_p, **kwargs): + return tf.reshape(self.get_factor(), (1, -1)) + + def get_factor_H(self, data=None, data_p=None, **kwargs): + _, free_idx = self.get_zero_index() + H = self.get_helicity_amp() + value = tf.gather_nd(H, free_idx) + new_idx = [(i, *j) for i, j in enumerate(free_idx)] + return tf.scatter_nd( + indices=new_idx, updates=value, shape=(len(free_idx), *H.shape) + ) + @regist_decay("helicity_full-bf") class HelicityDecayNPbf(HelicityDecayNP): diff --git a/tf_pwa/amp/core.py b/tf_pwa/amp/core.py index 1c002059..aba71879 100644 --- a/tf_pwa/amp/core.py +++ b/tf_pwa/amp/core.py @@ -881,7 +881,7 @@ def get_angle_helicity_amp(self, data, data_p, **kwargs): ) return ret - def get_factor_angle_helicity_amp(self, data, data_p, **kwargs): + def get_factor_H(self, data, data_p, **kwargs): # -> (n, n_ls, h1, h2) m_dep = self.get_angle_ls_amp(data, data_p, **kwargs) # (n,l) cg_trans = tf.cast(self.get_cg_matrix(), m_dep.dtype) n_ls = len(self.get_ls_list()) @@ -891,7 +891,10 @@ def get_factor_angle_helicity_amp(self, data, data_p, **kwargs): ) # H = tf.reduce_sum(m_dep * cg_trans, axis=1) H = m_dep * cg_trans # (n, n_ls, h1, h2) - # print(n_ls, cg_trans, self, m_dep.shape) # )data_p) + return H + + def get_factor_angle_helicity_amp(self, data, data_p, **kwargs): + H = self.get_factor_H(data, data_p, **kwargs) if self.allow_cc: all_data = kwargs.get("all_data", {}) charge = all_data.get("charge_conjugation", None) @@ -1097,6 +1100,9 @@ def get_factor_angle_amp(self, data, data_p, **kwargs): def get_m_dep(self, data, data_p, **kwargs): return self.get_ls_amp(data, data_p, **kwargs) + def get_factor_m_dep(self, data, data_p, **kwargs): + return self.get_ls_amp(data, data_p, **kwargs) + def get_ls_list(self): """get possible ls for decay, with l_list filter possible l""" ls_list = super(HelicityDecay, self).get_ls_list() @@ -1172,16 +1178,15 @@ def get_factor_variable(self): a.append(tmp) return [tuple([self.total] + a)] - def get_factor(self): + def get_factor(self): # (total, decay1, decay2, ...) decay_factor = [i.get_factor() for i in self] particle_factor = [i.get_factor() for i in self.inner] - all_factor = particle_factor + decay_factor + all_factor = decay_factor + particle_factor all_factor = [i for i in all_factor if i is not None] all_factor = all_factor ret = self.get_amp_total() for i in all_factor: ret = tf.expand_dims(ret, axis=-1) * tf.cast(i, ret.dtype) - ret = tf.reshape(ret, (-1,)) return ret def get_amp_total(self, charge=1): @@ -1332,7 +1337,7 @@ def get_factor_angle_amp( idx = ",".join(idxs) idx_s = "{}->{}".format(idx, final_indices) # ret = amp * tf.reshape(rs, [-1] + [1] * len(self.amp_shape())) - print(idx_s) # , amp_d) + # print(idx_s) # , amp_d) ret = tf.einsum(idx_s, *amp_d) # print(self, ret[0]) # exit() @@ -1479,7 +1484,7 @@ def get_factor_variable(self): def get_factor(self): ret = [] for i in self: - ret += i.get_factor() + ret.append(i.get_factor()) return ret def get_amp(self, data): @@ -1610,7 +1615,7 @@ def get_factor_angle_amp(self, data): ) ret.append(amp) # ret = tf.reduce_sum(ret, axis=0) - return amp + return ret @functools.lru_cache() def get_swap_factor(self, key): diff --git a/tf_pwa/tests/config_hel.yml b/tf_pwa/tests/config_hel.yml new file mode 100644 index 00000000..19d20424 --- /dev/null +++ b/tf_pwa/tests/config_hel.yml @@ -0,0 +1,26 @@ +data: + dat_order: [B, C, D] + amp_model: base_factor + +decay: + A: [BC, D, model: helicity_full] + BC: [B, C] + +particle: + $top: + A: { J: 1, P: -1, mass: 5.3 } + $finals: + B: { J: 0, P: -1, mass: 0.1 } + C: { J: 0, P: -1, mass: 0.1 } + D: { J: 0, P: -1, mass: 0.1 } + BC: [BC1, BC2] + BC1: + J: 1 + P: -1 + mass: 1.0 + width: 0.2 + BC2: + J: 1 + P: -1 + mass: 2.0 + width: 0.2 diff --git a/tf_pwa/tests/test_full.py b/tf_pwa/tests/test_full.py index 09623aa8..94184f77 100644 --- a/tf_pwa/tests/test_full.py +++ b/tf_pwa/tests/test_full.py @@ -393,3 +393,10 @@ def test_plot_2dpull(toy_config): def test_lazy_file(toy_config_npy): fcn = toy_config_npy.get_fcn() fcn.nll_grad() + + +def test_factor_hel(): + config = ConfigLoader(f"{this_dir}/config_hel.yml") + phsp = config.generate_phsp(10) + amp = config.get_amplitude() + amp(phsp) From 189be05686d204ef3cb25d1817cdb4f948b02f5c Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sat, 9 Sep 2023 23:24:02 +0800 Subject: [PATCH 4/5] feat: preprocessor cached_angle --- tf_pwa/amp/amp.py | 5 ++++- tf_pwa/amp/base.py | 20 ++++++++++++++++---- tf_pwa/amp/preprocess.py | 37 +++++++++++++++++++++++++++++++++++++ tf_pwa/tests/config_hel.yml | 17 ++++++++++++++--- tf_pwa/tests/test_full.py | 1 + 5 files changed, 72 insertions(+), 8 deletions(-) diff --git a/tf_pwa/amp/amp.py b/tf_pwa/amp/amp.py index a3a9f52b..c174f553 100644 --- a/tf_pwa/amp/amp.py +++ b/tf_pwa/amp/amp.py @@ -288,7 +288,10 @@ def __init__(self, *args, **kwargs): def pdf(self, data): m_dep = self.decay_group.get_m_dep(data) - angle_amp = self.decay_group.get_factor_angle_amp(data) + if "cached_angle" in data: + angle_amp = data["cached_angle"] + else: + angle_amp = self.decay_group.get_factor_angle_amp(data) ret = [] for a, b in zip(m_dep, angle_amp): tmp = b diff --git a/tf_pwa/amp/base.py b/tf_pwa/amp/base.py index ee777313..d3ef82d1 100644 --- a/tf_pwa/amp/base.py +++ b/tf_pwa/amp/base.py @@ -536,7 +536,7 @@ def get_factor(self): H = self.H() return tf.gather_nd(H, free_index) - def get_helicity_amp(self, data=None, data_p=None, **kwargs): + def get_H(self): if self.mask_factor: H = tf.stack(self.H()) _, free_idx = self.get_zero_index() @@ -547,6 +547,9 @@ def get_helicity_amp(self, data=None, data_p=None, **kwargs): ) return tf.stack(self.H()) + def get_helicity_amp(self, data=None, data_p=None, **kwargs): + return self.get_H() + def get_ls_amp(self, data, data_p, **kwargs): return tf.reshape(self.get_factor(), (1, -1)) @@ -566,7 +569,7 @@ def init_params(self): self.d = 3.0 super().init_params() - def get_helicity_amp(self, data, data_p, **kwargs): + def get_H_barrier_factor(self, data, data_p, **kwargs): q0 = self.get_relative_momentum(data_p, False) data["|q0|"] = q0 if "|q|" in data: @@ -575,10 +578,19 @@ def get_helicity_amp(self, data, data_p, **kwargs): q = self.get_relative_momentum(data_p, True) data["|q|"] = q bf = barrier_factor([min(self.get_l_list())], q, q0, self.d) - H = tf.stack(self.H()) + return bf + + def get_helicity_amp(self, data, data_p, **kwargs): + H = self.get_H() + bf = self.get_H_barrier_factor(data, data_p, **kwargs) bf = tf.cast(tf.reshape(bf, (-1, 1, 1)), H.dtype) return H * bf + def get_ls_amp(self, data, data_p, **kwargs): + bf = self.get_H_barrier_factor(data, data_p, **kwargs) + f = tf.reshape(self.get_factor(), (1, -1)) + return f * tf.expand_dims(tf.cast(bf, f.dtype), axis=-1) + def get_parity_term(j1, p1, j2, p2, j3, p3): p = p1 * p2 * p3 * (-1) ** (j1 - j2 - j3) @@ -616,7 +628,7 @@ def init_params(self): def get_helicity_amp(self, data, data_p, **kwargs): n_b = len(self.outs[0].spins) n_c = len(self.outs[1].spins) - H_part = tf.stack(self.H()) + H_part = self.get_H() if self.part_H == 0: H = tf.concat( [ diff --git a/tf_pwa/amp/preprocess.py b/tf_pwa/amp/preprocess.py index dcb95529..59e36d91 100644 --- a/tf_pwa/amp/preprocess.py +++ b/tf_pwa/amp/preprocess.py @@ -146,6 +146,43 @@ def build_cached(self, x): return x +@register_preprocessor("cached_angle") +class CachedAnglePreProcessor(BasePreProcessor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.amp = self.root_config.get_amplitude() + self.decay_group = self.amp.decay_group + self.no_angle = self.kwargs.get("no_angle", False) + self.no_p4 = self.kwargs.get("no_p4", False) + + def build_cached(self, x): + x2 = super().__call__(x) + for k, v in x["extra"].items(): + x2[k] = v # {**x2, **x["extra"]} + c_amp = self.decay_group.get_factor_angle_amp(x2) + x2["cached_angle"] = list_to_tuple(c_amp) + # print(x) + return x2 + + def strip_data(self, x): + strip_var = [] + if self.no_angle: + strip_var += ["ang", "aligned_angle"] + if self.no_p4: + strip_var += ["p"] + if strip_var: + x = data_strip(x, strip_var) + return x + + def __call__(self, x): + extra = x["extra"] + x = self.build_cached(x) + x = self.strip_data(x) + for k in extra: + del x[k] + return x + + @register_preprocessor("p4_directly") class CachedAmpPreProcessor(BasePreProcessor): def __init__(self, *args, **kwargs): diff --git a/tf_pwa/tests/config_hel.yml b/tf_pwa/tests/config_hel.yml index 19d20424..74fc47df 100644 --- a/tf_pwa/tests/config_hel.yml +++ b/tf_pwa/tests/config_hel.yml @@ -1,10 +1,16 @@ data: dat_order: [B, C, D] + preprocessor: cached_angle amp_model: base_factor decay: - A: [BC, D, model: helicity_full] + A: + - [BC, D, model: helicity_full] + - [BD, C, model: helicity_full] + - [CD, B, model: helicity_full] BC: [B, C] + BD: [B, D] + CD: [C, D] particle: $top: @@ -13,13 +19,18 @@ particle: B: { J: 0, P: -1, mass: 0.1 } C: { J: 0, P: -1, mass: 0.1 } D: { J: 0, P: -1, mass: 0.1 } - BC: [BC1, BC2] + BC: [BC1] BC1: J: 1 P: -1 mass: 1.0 width: 0.2 - BC2: + BD: + J: 1 + P: -1 + mass: 2.0 + width: 0.2 + CD: J: 1 P: -1 mass: 2.0 diff --git a/tf_pwa/tests/test_full.py b/tf_pwa/tests/test_full.py index 94184f77..7a7a96c9 100644 --- a/tf_pwa/tests/test_full.py +++ b/tf_pwa/tests/test_full.py @@ -400,3 +400,4 @@ def test_factor_hel(): phsp = config.generate_phsp(10) amp = config.get_amplitude() amp(phsp) + amp.decay_group.get_factor() From 987c0551fe58a2e54de38143973d1f9f44f41f06 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Mon, 11 Sep 2023 00:24:13 +0800 Subject: [PATCH 5/5] misc: remove unused code --- tf_pwa/amp/amp.py | 6 +++++- tf_pwa/amp/preprocess.py | 18 ------------------ 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/tf_pwa/amp/amp.py b/tf_pwa/amp/amp.py index c174f553..ac289015 100644 --- a/tf_pwa/amp/amp.py +++ b/tf_pwa/amp/amp.py @@ -286,7 +286,7 @@ class FactorAmplitudeModel(BaseAmplitudeModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def pdf(self, data): + def get_amp_list(self, data): m_dep = self.decay_group.get_m_dep(data) if "cached_angle" in data: angle_amp = data["cached_angle"] @@ -305,6 +305,10 @@ def pdf(self, data): tmp = tmp * tf.expand_dims(i, axis=-1) tmp = tf.reduce_sum(tmp, axis=-2) ret.append(tmp) + return ret + + def pdf(self, data): + ret = self.get_amp_list(data) amp = tf.reduce_sum(ret, axis=0) amp2s = tf.math.real(amp * tf.math.conj(amp)) return tf.reduce_sum(amp2s, list(range(1, len(amp2s.shape)))) diff --git a/tf_pwa/amp/preprocess.py b/tf_pwa/amp/preprocess.py index 59e36d91..9b34af60 100644 --- a/tf_pwa/amp/preprocess.py +++ b/tf_pwa/amp/preprocess.py @@ -164,24 +164,6 @@ def build_cached(self, x): # print(x) return x2 - def strip_data(self, x): - strip_var = [] - if self.no_angle: - strip_var += ["ang", "aligned_angle"] - if self.no_p4: - strip_var += ["p"] - if strip_var: - x = data_strip(x, strip_var) - return x - - def __call__(self, x): - extra = x["extra"] - x = self.build_cached(x) - x = self.strip_data(x) - for k in extra: - del x[k] - return x - @register_preprocessor("p4_directly") class CachedAmpPreProcessor(BasePreProcessor):