From 189be05686d204ef3cb25d1817cdb4f948b02f5c Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sat, 9 Sep 2023 23:24:02 +0800 Subject: [PATCH] feat: preprocessor cached_angle --- tf_pwa/amp/amp.py | 5 ++++- tf_pwa/amp/base.py | 20 ++++++++++++++++---- tf_pwa/amp/preprocess.py | 37 +++++++++++++++++++++++++++++++++++++ tf_pwa/tests/config_hel.yml | 17 ++++++++++++++--- tf_pwa/tests/test_full.py | 1 + 5 files changed, 72 insertions(+), 8 deletions(-) diff --git a/tf_pwa/amp/amp.py b/tf_pwa/amp/amp.py index a3a9f52b..c174f553 100644 --- a/tf_pwa/amp/amp.py +++ b/tf_pwa/amp/amp.py @@ -288,7 +288,10 @@ def __init__(self, *args, **kwargs): def pdf(self, data): m_dep = self.decay_group.get_m_dep(data) - angle_amp = self.decay_group.get_factor_angle_amp(data) + if "cached_angle" in data: + angle_amp = data["cached_angle"] + else: + angle_amp = self.decay_group.get_factor_angle_amp(data) ret = [] for a, b in zip(m_dep, angle_amp): tmp = b diff --git a/tf_pwa/amp/base.py b/tf_pwa/amp/base.py index ee777313..d3ef82d1 100644 --- a/tf_pwa/amp/base.py +++ b/tf_pwa/amp/base.py @@ -536,7 +536,7 @@ def get_factor(self): H = self.H() return tf.gather_nd(H, free_index) - def get_helicity_amp(self, data=None, data_p=None, **kwargs): + def get_H(self): if self.mask_factor: H = tf.stack(self.H()) _, free_idx = self.get_zero_index() @@ -547,6 +547,9 @@ def get_helicity_amp(self, data=None, data_p=None, **kwargs): ) return tf.stack(self.H()) + def get_helicity_amp(self, data=None, data_p=None, **kwargs): + return self.get_H() + def get_ls_amp(self, data, data_p, **kwargs): return tf.reshape(self.get_factor(), (1, -1)) @@ -566,7 +569,7 @@ def init_params(self): self.d = 3.0 super().init_params() - def get_helicity_amp(self, data, data_p, **kwargs): + def get_H_barrier_factor(self, data, data_p, **kwargs): q0 = self.get_relative_momentum(data_p, False) data["|q0|"] = q0 if "|q|" in data: @@ -575,10 +578,19 @@ def get_helicity_amp(self, data, data_p, **kwargs): q = self.get_relative_momentum(data_p, True) data["|q|"] = q bf = barrier_factor([min(self.get_l_list())], q, q0, self.d) - H = tf.stack(self.H()) + return bf + + def get_helicity_amp(self, data, data_p, **kwargs): + H = self.get_H() + bf = self.get_H_barrier_factor(data, data_p, **kwargs) bf = tf.cast(tf.reshape(bf, (-1, 1, 1)), H.dtype) return H * bf + def get_ls_amp(self, data, data_p, **kwargs): + bf = self.get_H_barrier_factor(data, data_p, **kwargs) + f = tf.reshape(self.get_factor(), (1, -1)) + return f * tf.expand_dims(tf.cast(bf, f.dtype), axis=-1) + def get_parity_term(j1, p1, j2, p2, j3, p3): p = p1 * p2 * p3 * (-1) ** (j1 - j2 - j3) @@ -616,7 +628,7 @@ def init_params(self): def get_helicity_amp(self, data, data_p, **kwargs): n_b = len(self.outs[0].spins) n_c = len(self.outs[1].spins) - H_part = tf.stack(self.H()) + H_part = self.get_H() if self.part_H == 0: H = tf.concat( [ diff --git a/tf_pwa/amp/preprocess.py b/tf_pwa/amp/preprocess.py index dcb95529..59e36d91 100644 --- a/tf_pwa/amp/preprocess.py +++ b/tf_pwa/amp/preprocess.py @@ -146,6 +146,43 @@ def build_cached(self, x): return x +@register_preprocessor("cached_angle") +class CachedAnglePreProcessor(BasePreProcessor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.amp = self.root_config.get_amplitude() + self.decay_group = self.amp.decay_group + self.no_angle = self.kwargs.get("no_angle", False) + self.no_p4 = self.kwargs.get("no_p4", False) + + def build_cached(self, x): + x2 = super().__call__(x) + for k, v in x["extra"].items(): + x2[k] = v # {**x2, **x["extra"]} + c_amp = self.decay_group.get_factor_angle_amp(x2) + x2["cached_angle"] = list_to_tuple(c_amp) + # print(x) + return x2 + + def strip_data(self, x): + strip_var = [] + if self.no_angle: + strip_var += ["ang", "aligned_angle"] + if self.no_p4: + strip_var += ["p"] + if strip_var: + x = data_strip(x, strip_var) + return x + + def __call__(self, x): + extra = x["extra"] + x = self.build_cached(x) + x = self.strip_data(x) + for k in extra: + del x[k] + return x + + @register_preprocessor("p4_directly") class CachedAmpPreProcessor(BasePreProcessor): def __init__(self, *args, **kwargs): diff --git a/tf_pwa/tests/config_hel.yml b/tf_pwa/tests/config_hel.yml index 19d20424..74fc47df 100644 --- a/tf_pwa/tests/config_hel.yml +++ b/tf_pwa/tests/config_hel.yml @@ -1,10 +1,16 @@ data: dat_order: [B, C, D] + preprocessor: cached_angle amp_model: base_factor decay: - A: [BC, D, model: helicity_full] + A: + - [BC, D, model: helicity_full] + - [BD, C, model: helicity_full] + - [CD, B, model: helicity_full] BC: [B, C] + BD: [B, D] + CD: [C, D] particle: $top: @@ -13,13 +19,18 @@ particle: B: { J: 0, P: -1, mass: 0.1 } C: { J: 0, P: -1, mass: 0.1 } D: { J: 0, P: -1, mass: 0.1 } - BC: [BC1, BC2] + BC: [BC1] BC1: J: 1 P: -1 mass: 1.0 width: 0.2 - BC2: + BD: + J: 1 + P: -1 + mass: 2.0 + width: 0.2 + CD: J: 1 P: -1 mass: 2.0 diff --git a/tf_pwa/tests/test_full.py b/tf_pwa/tests/test_full.py index 94184f77..7a7a96c9 100644 --- a/tf_pwa/tests/test_full.py +++ b/tf_pwa/tests/test_full.py @@ -400,3 +400,4 @@ def test_factor_hel(): phsp = config.generate_phsp(10) amp = config.get_amplitude() amp(phsp) + amp.decay_group.get_factor()