diff --git a/tf_pwa/amp/amp.py b/tf_pwa/amp/amp.py index 77e95ebf..ac289015 100644 --- a/tf_pwa/amp/amp.py +++ b/tf_pwa/amp/amp.py @@ -1,6 +1,7 @@ import contextlib import warnings +import numpy as np import tensorflow as tf from tf_pwa.amp.core import Variable, variable_scope @@ -280,6 +281,39 @@ def pdf(self, data): return tf.reduce_sum(amp2s, list(range(1, len(amp2s.shape)))) +@register_amp_model("base_factor") +class FactorAmplitudeModel(BaseAmplitudeModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_amp_list(self, data): + m_dep = self.decay_group.get_m_dep(data) + if "cached_angle" in data: + angle_amp = data["cached_angle"] + else: + angle_amp = self.decay_group.get_factor_angle_amp(data) + ret = [] + for a, b in zip(m_dep, angle_amp): + tmp = b + for i in a: + total_size = np.prod(tmp.shape[1:]) + if len(i.shape) == 1: + i = tf.expand_dims(i, axis=-1) + tmp = tf.reshape( + tmp, (-1, i.shape[-1], total_size // i.shape[-1]) + ) + tmp = tmp * tf.expand_dims(i, axis=-1) + tmp = tf.reduce_sum(tmp, axis=-2) + ret.append(tmp) + return ret + + def pdf(self, data): + ret = self.get_amp_list(data) + amp = tf.reduce_sum(ret, axis=0) + amp2s = tf.math.real(amp * tf.math.conj(amp)) + return tf.reduce_sum(amp2s, list(range(1, len(amp2s.shape)))) + + @register_amp_model("p4_directly") class P4DirectlyAmplitudeModel(BaseAmplitudeModel): def cal_angle(self, p4): diff --git a/tf_pwa/amp/base.py b/tf_pwa/amp/base.py index 0c2c223a..d3ef82d1 100644 --- a/tf_pwa/amp/base.py +++ b/tf_pwa/amp/base.py @@ -508,20 +508,68 @@ def init_params(self): a = self.outs[0].spins b = self.outs[1].spins self.H = self.add_var("H", is_complex=True, shape=(len(a), len(b))) + self.fix_unused_h() - def get_helicity_amp(self, data, data_p, **kwargs): + def get_zero_index(self): + a = self.outs[0].spins + b = self.outs[1].spins + fix_index = [] + free_index = [] + for idx_i, i in zip(range(self.H.shape[-2]), a): + for idx_j, j in zip(range(self.H.shape[-1]), b): + if abs(i - j) > self.core.J: + fix_index.append((idx_i, idx_j)) + else: + free_index.append((idx_i, idx_j)) + return fix_index, free_index + + def fix_unused_h(self): + fix_index, free_idx = self.get_zero_index() + self.H.set_fix_idx(fix_index, 0.0) + self.H.set_fix_idx([free_idx[0]], 1.0) + + def get_H_zero_mask(self): + fix_index, free_idx = self.get_zero_index() + + def get_factor(self): + _, free_index = self.get_zero_index() + H = self.H() + return tf.gather_nd(H, free_index) + + def get_H(self): + if self.mask_factor: + H = tf.stack(self.H()) + _, free_idx = self.get_zero_index() + return tf.scatter_nd( + indices=free_idx, + updates=tf.ones(len(free_idx), dtype=H.dtype), + shape=H.shape, + ) return tf.stack(self.H()) + def get_helicity_amp(self, data=None, data_p=None, **kwargs): + return self.get_H() + + def get_ls_amp(self, data, data_p, **kwargs): + return tf.reshape(self.get_factor(), (1, -1)) + + def get_factor_H(self, data=None, data_p=None, **kwargs): + _, free_idx = self.get_zero_index() + H = self.get_helicity_amp() + value = tf.gather_nd(H, free_idx) + new_idx = [(i, *j) for i, j in enumerate(free_idx)] + return tf.scatter_nd( + indices=new_idx, updates=value, shape=(len(free_idx), *H.shape) + ) + @regist_decay("helicity_full-bf") -class HelicityDecayNPbf(HelicityDecay): +class HelicityDecayNPbf(HelicityDecayNP): def init_params(self): self.d = 3.0 - a = self.outs[0].spins - b = self.outs[1].spins - self.H = self.add_var("H", is_complex=True, shape=(len(a), len(b))) + super().init_params() - def get_helicity_amp(self, data, data_p, **kwargs): + def get_H_barrier_factor(self, data, data_p, **kwargs): q0 = self.get_relative_momentum(data_p, False) data["|q0|"] = q0 if "|q|" in data: @@ -530,10 +578,19 @@ def get_helicity_amp(self, data, data_p, **kwargs): q = self.get_relative_momentum(data_p, True) data["|q|"] = q bf = barrier_factor([min(self.get_l_list())], q, q0, self.d) - H = tf.stack(self.H()) + return bf + + def get_helicity_amp(self, data, data_p, **kwargs): + H = self.get_H() + bf = self.get_H_barrier_factor(data, data_p, **kwargs) bf = tf.cast(tf.reshape(bf, (-1, 1, 1)), H.dtype) return H * bf + def get_ls_amp(self, data, data_p, **kwargs): + bf = self.get_H_barrier_factor(data, data_p, **kwargs) + f = tf.reshape(self.get_factor(), (1, -1)) + return f * tf.expand_dims(tf.cast(bf, f.dtype), axis=-1) + def get_parity_term(j1, p1, j2, p2, j3, p3): p = p1 * p2 * p3 * (-1) ** (j1 - j2 - j3) @@ -541,7 +598,7 @@ def get_parity_term(j1, p1, j2, p2, j3, p3): @regist_decay("helicity_parity") -class HelicityDecayP(HelicityDecay): +class HelicityDecayP(HelicityDecayNP): """ .. math:: @@ -566,11 +623,12 @@ def init_params(self): "H", is_complex=True, shape=(n_b, (n_c + 1) // 2) ) self.part_H = 1 + self.fix_unused_h() def get_helicity_amp(self, data, data_p, **kwargs): n_b = len(self.outs[0].spins) n_c = len(self.outs[1].spins) - H_part = tf.stack(self.H()) + H_part = self.get_H() if self.part_H == 0: H = tf.concat( [ diff --git a/tf_pwa/amp/core.py b/tf_pwa/amp/core.py index 7607228b..aba71879 100644 --- a/tf_pwa/amp/core.py +++ b/tf_pwa/amp/core.py @@ -422,6 +422,9 @@ def get_width(self): return self.width() return self.width + def get_factor(self): + return None + def get_sympy_var(self): return sym.var("m m0 g0 m1 m2") @@ -699,6 +702,9 @@ def init_params(self): def get_factor_variable(self): return [(self.g_ls,)] + def get_factor(self): + return self.get_g_ls() + def _get_particle_mass(self, p, data, from_data=False): if from_data and p in data: return data[p]["m"] @@ -875,6 +881,39 @@ def get_angle_helicity_amp(self, data, data_p, **kwargs): ) return ret + def get_factor_H(self, data, data_p, **kwargs): # -> (n, n_ls, h1, h2) + m_dep = self.get_angle_ls_amp(data, data_p, **kwargs) # (n,l) + cg_trans = tf.cast(self.get_cg_matrix(), m_dep.dtype) + n_ls = len(self.get_ls_list()) + m_dep = tf.reshape(m_dep, (-1, n_ls, 1, 1)) + cg_trans = tf.reshape( + cg_trans, (n_ls, len(self.outs[0].spins), len(self.outs[1].spins)) + ) + # H = tf.reduce_sum(m_dep * cg_trans, axis=1) + H = m_dep * cg_trans # (n, n_ls, h1, h2) + return H + + def get_factor_angle_helicity_amp(self, data, data_p, **kwargs): + H = self.get_factor_H(data, data_p, **kwargs) + if self.allow_cc: + all_data = kwargs.get("all_data", {}) + charge = all_data.get("charge_conjugation", None) + if charge is not None: + H = tf.where( + charge[..., None, None] > 0, H, H[..., ::-1, ::-1] + ) + ret = tf.reshape( + H, + ( + -1, + H.shape[-3], + 1, + len(self.outs[0].spins), + len(self.outs[1].spins), + ), + ) + return ret + def get_g_ls(self): gls = self.g_ls() if self.ls_index is None: @@ -1042,9 +1081,28 @@ def get_angle_amp(self, data, data_p, **kwargs): ret = tf.reduce_sum(ret, axis=j + 2) return ret + def get_factor_angle_amp(self, data, data_p, **kwargs): + a = self.core + b = self.outs[0] + c = self.outs[1] + ang = data[b]["ang"] + D_conj = get_D_matrix_lambda(ang, a.J, a.spins, b.spins, c.spins) + H = self.get_factor_angle_helicity_amp(data, data_p, **kwargs) + H = tf.cast(H, dtype=D_conj.dtype) + D_conj = tf.reshape(D_conj, (-1, 1, *D_conj.shape[1:])) + ret = H * tf.stop_gradient(D_conj) + # print(self, H, D_conj) + # exit() + if self.aligned: + raise NotImplemented + return ret + def get_m_dep(self, data, data_p, **kwargs): return self.get_ls_amp(data, data_p, **kwargs) + def get_factor_m_dep(self, data, data_p, **kwargs): + return self.get_ls_amp(data, data_p, **kwargs) + def get_ls_list(self): """get possible ls for decay, with l_list filter possible l""" ls_list = super(HelicityDecay, self).get_ls_list() @@ -1120,6 +1178,17 @@ def get_factor_variable(self): a.append(tmp) return [tuple([self.total] + a)] + def get_factor(self): # (total, decay1, decay2, ...) + decay_factor = [i.get_factor() for i in self] + particle_factor = [i.get_factor() for i in self.inner] + all_factor = decay_factor + particle_factor + all_factor = [i for i in all_factor if i is not None] + all_factor = all_factor + ret = self.get_amp_total() + for i in all_factor: + ret = tf.expand_dims(ret, axis=-1) * tf.cast(i, ret.dtype) + return ret + def get_amp_total(self, charge=1): if self.mask_factor: return tf.ones_like(tf.stack(self.total(charge))) @@ -1230,6 +1299,51 @@ def get_angle_amp(self, data_c, data_p, all_data=None, base_map=None): # ret = einsum(idx_s, *amp_d) return ret + def get_factor_angle_amp( + self, data_c, data_p, all_data=None, base_map=None + ): + base_map = self.get_base_map(base_map) + iter_idx = ["..."] + amp_d = [] + indices = [] + next_map = "zyxwvutsr" + used_idx = "" + final_indices = self.amp_index(base_map) + for i in self: + tmp_idx = i.amp_index(base_map) + tmp_idx = [next_map[0], *tmp_idx] + indices.append(tmp_idx) + used_idx += next_map[0] + amp_d.append( + i.get_factor_angle_amp(data_c[i], data_p, all_data=all_data) + ) + next_map = next_map[1:] + final_indices = "".join(iter_idx + list(used_idx) + final_indices) + + if self.aligned: + for i in self: + for j in i.outs: + if j.J != 0 and "aligned_angle" in data_c[i][j]: + ang = data_c[i][j]["aligned_angle"] + dt = get_D_matrix_lambda(ang, j.J, j.spins, j.spins) + amp_d.append(tf.stop_gradient(dt)) + idx = [base_map[j], base_map[j].upper()] + indices.append(idx) + final_indices = final_indices.replace(*idx) + idxs = [] + for i in indices: + tmp = "".join(iter_idx + i) + idxs.append(tmp) + idx = ",".join(idxs) + idx_s = "{}->{}".format(idx, final_indices) + # ret = amp * tf.reshape(rs, [-1] + [1] * len(self.amp_shape())) + # print(idx_s) # , amp_d) + ret = tf.einsum(idx_s, *amp_d) + # print(self, ret[0]) + # exit() + # ret = einsum(idx_s, *amp_d) + return ret + def get_m_dep(self, data_c, data_p, all_data=None, base_map=None): base_map = self.get_base_map(base_map) iter_idx = ["..."] @@ -1367,6 +1481,12 @@ def get_factor_variable(self): ret += i.get_factor_variable() return ret + def get_factor(self): + ret = [] + for i in self: + ret.append(i.get_factor()) + return ret + def get_amp(self, data): """ calculate the amplitude as complex number @@ -1466,6 +1586,37 @@ def get_angle_amp(self, data): # ret = tf.reduce_sum(ret, axis=0) return amp + def get_factor_angle_amp(self, data): + data_particle = data["particle"] + data_decay = data["decay"] + + used_chains = tuple([self.chains[i] for i in self.chains_idx]) + chain_maps = self.get_chains_map(used_chains) + base_map = self.get_base_map() + ret = [] + for decay_chain in used_chains: + for chains in chain_maps: + if str(decay_chain) in [str(i) for i in chains]: + maps = chains[decay_chain] + break + chain_topo = decay_chain.standard_topology() + found = False + for i in data_decay.keys(): + if i == chain_topo: + data_decay_i = data_decay[i] + found = True + break + if not found: + raise KeyError("not found {}".format(chain_topo)) + data_c = rename_data_dict(data_decay_i, maps) + data_p = rename_data_dict(data_particle, maps) + amp = decay_chain.get_factor_angle_amp( + data_c, data_p, base_map=base_map, all_data=data + ) + ret.append(amp) + # ret = tf.reduce_sum(ret, axis=0) + return ret + @functools.lru_cache() def get_swap_factor(self, key): factor = 1.0 diff --git a/tf_pwa/amp/preprocess.py b/tf_pwa/amp/preprocess.py index dcb95529..9b34af60 100644 --- a/tf_pwa/amp/preprocess.py +++ b/tf_pwa/amp/preprocess.py @@ -146,6 +146,25 @@ def build_cached(self, x): return x +@register_preprocessor("cached_angle") +class CachedAnglePreProcessor(BasePreProcessor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.amp = self.root_config.get_amplitude() + self.decay_group = self.amp.decay_group + self.no_angle = self.kwargs.get("no_angle", False) + self.no_p4 = self.kwargs.get("no_p4", False) + + def build_cached(self, x): + x2 = super().__call__(x) + for k, v in x["extra"].items(): + x2[k] = v # {**x2, **x["extra"]} + c_amp = self.decay_group.get_factor_angle_amp(x2) + x2["cached_angle"] = list_to_tuple(c_amp) + # print(x) + return x2 + + @register_preprocessor("p4_directly") class CachedAmpPreProcessor(BasePreProcessor): def __init__(self, *args, **kwargs): diff --git a/tf_pwa/tests/config_hel.yml b/tf_pwa/tests/config_hel.yml new file mode 100644 index 00000000..74fc47df --- /dev/null +++ b/tf_pwa/tests/config_hel.yml @@ -0,0 +1,37 @@ +data: + dat_order: [B, C, D] + preprocessor: cached_angle + amp_model: base_factor + +decay: + A: + - [BC, D, model: helicity_full] + - [BD, C, model: helicity_full] + - [CD, B, model: helicity_full] + BC: [B, C] + BD: [B, D] + CD: [C, D] + +particle: + $top: + A: { J: 1, P: -1, mass: 5.3 } + $finals: + B: { J: 0, P: -1, mass: 0.1 } + C: { J: 0, P: -1, mass: 0.1 } + D: { J: 0, P: -1, mass: 0.1 } + BC: [BC1] + BC1: + J: 1 + P: -1 + mass: 1.0 + width: 0.2 + BD: + J: 1 + P: -1 + mass: 2.0 + width: 0.2 + CD: + J: 1 + P: -1 + mass: 2.0 + width: 0.2 diff --git a/tf_pwa/tests/test_full.py b/tf_pwa/tests/test_full.py index 09623aa8..7a7a96c9 100644 --- a/tf_pwa/tests/test_full.py +++ b/tf_pwa/tests/test_full.py @@ -393,3 +393,11 @@ def test_plot_2dpull(toy_config): def test_lazy_file(toy_config_npy): fcn = toy_config_npy.get_fcn() fcn.nll_grad() + + +def test_factor_hel(): + config = ConfigLoader(f"{this_dir}/config_hel.yml") + phsp = config.generate_phsp(10) + amp = config.get_amplitude() + amp(phsp) + amp.decay_group.get_factor() diff --git a/tf_pwa/variable.py b/tf_pwa/variable.py index 30e5c1c3..498afa7a 100644 --- a/tf_pwa/variable.py +++ b/tf_pwa/variable.py @@ -1494,31 +1494,16 @@ def freed(self): else: raise Exception("Only shape==() var supports 'freed' method.") - def set_fix_idx(self, fix_idx=None, fix_vals=None, free_idx=None): - """ - :param fix_idx: Interger or list of integers. Which complex component in the innermost layer of the variable is fixed. E.g. If ``self.shape==[2,3,4]`` and ``fix_idx==[1,2]``, then Variable()[i][j][1] and Variable()[i][j][2] will be the fixed value. - :param fix_vals: Float or length-2 float list for complex variable. The fixed value. - :param free_idx: Interger or list of integers. Which complex component in the innermost layer of the variable is set free. E.g. If ``self.shape==[2,3,4]`` and ``fix_idx==[0]``, then Variable()[i][j][0] will be set free. - """ - if not self.shape: - raise Exception( - "Only shape!=() var supports 'set_fix_idx' method to fix or free variables." - ) - if free_idx is None: - free_idx = [] - else: - free_idx = free_idx % self.shape[-1] + def _set_fix_idx(self, fix_idx=None, fix_vals=None, unfix=False): if fix_idx is None: fix_idx = [] else: - fix_idx = fix_idx % self.shape[-1] - - if not hasattr(fix_idx, "__len__"): - fix_idx = [fix_idx] - fix_idx_str = ["_" + str(i) for i in fix_idx] - if not hasattr(free_idx, "__len__"): - free_idx = [free_idx] - free_idx_str = ["_" + str(i) for i in free_idx] + fix_idx = np.array(fix_idx) + if len(fix_idx.shape) <= 1: + fix_idx = np.reshape(fix_idx, (-1, 1)) + shape_mod = [self.shape[-i - 1] for i in range(fix_idx.shape[-1])] + fix_idx = fix_idx % shape_mod + fix_idx_str = ["_" + "_".join(str(j) for j in i) for i in fix_idx] if self.cp_effect: print( @@ -1541,19 +1526,22 @@ def func(name, idx): for ss in fix_idx_str: if name.endswith(ss): # print("set_fix_idx set name+r ", name) - self.vm.set_fix(name + "r", value=fix_vals[0]) - self.vm.set_fix(name + "i", value=fix_vals[1]) + self.vm.set_fix( + name + "r", value=fix_vals[0], unfix=unfix + ) + self.vm.set_fix( + name + "i", value=fix_vals[1], unfix=unfix + ) if len(fix_vals) > 2: - self.vm.set_fix(name + "deltar", value=fix_vals[2]) - self.vm.set_fix(name + "deltai", value=fix_vals[3]) - for ss in free_idx_str: - if name.endswith(ss): - self.vm.set_fix(name + "r", unfix=True) - self.vm.set_fix(name + "i", unfix=True) - self.vm.set_fix(name + "deltar", unfix=True) - self.vm.set_fix(name + "deltai", unfix=True) + self.vm.set_fix( + name + "deltar", value=fix_vals[2], unfix=unfix + ) + self.vm.set_fix( + name + "deltai", value=fix_vals[3], unfix=unfix + ) _shape_func(func, self.shape, self.name) + elif self.cplx: if fix_vals is None: fix_vals = [None, None] @@ -1563,12 +1551,12 @@ def func(name, idx): def func(name, idx): for ss in fix_idx_str: if name.endswith(ss): - self.vm.set_fix(name + "r", value=fix_vals[0]) - self.vm.set_fix(name + "i", value=fix_vals[1]) - for ss in free_idx_str: - if name.endswith(ss): - self.vm.set_fix(name + "r", unfix=True) - self.vm.set_fix(name + "i", unfix=True) + self.vm.set_fix( + name + "r", value=fix_vals[0], unfix=unfix + ) + self.vm.set_fix( + name + "i", value=fix_vals[1], unfix=unfix + ) _shape_func(func, self.shape, self.name) else: @@ -1576,13 +1564,23 @@ def func(name, idx): def func(name, idx): for ss in fix_idx_str: if name.endswith(ss): - self.vm.set_fix(name, value=fix_vals) - for ss in free_idx_str: - if name.endswith(ss): - self.vm.set_fix(name, unfix=True) + self.vm.set_fix(name, value=fix_vals, unfix=unfix) _shape_func(func, self.shape, self.name) + def set_fix_idx(self, fix_idx=None, fix_vals=None, free_idx=None): + """ + :param fix_idx: Interger or list of integers. Which complex component in the innermost layer of the variable is fixed. E.g. If ``self.shape==[2,3,4]`` and ``fix_idx==[1,2]``, then Variable()[i][j][1] and Variable()[i][j][2] will be the fixed value. + :param fix_vals: Float or length-2 float list for complex variable. The fixed value. + :param free_idx: Interger or list of integers. Which complex component in the innermost layer of the variable is set free. E.g. If ``self.shape==[2,3,4]`` and ``fix_idx==[0]``, then Variable()[i][j][0] will be set free. + """ + if not self.shape: + raise Exception( + "Only shape!=() var supports 'set_fix_idx' method to fix or free variables." + ) + self._set_fix_idx(fix_idx, fix_vals) + self._set_fix_idx(free_idx, unfix=True) + def set_bound(self, bound, func=None, overwrite=False): """ Set boundary for this Variable. Note only non-shape real Variable supports this method.