Skip to content

Commit

Permalink
feat: preprocessor cached_angle
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangyi15 committed Sep 9, 2023
1 parent 3728488 commit 189be05
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 8 deletions.
5 changes: 4 additions & 1 deletion tf_pwa/amp/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,10 @@ def __init__(self, *args, **kwargs):

def pdf(self, data):
m_dep = self.decay_group.get_m_dep(data)
angle_amp = self.decay_group.get_factor_angle_amp(data)
if "cached_angle" in data:
angle_amp = data["cached_angle"]
else:
angle_amp = self.decay_group.get_factor_angle_amp(data)

Check warning on line 294 in tf_pwa/amp/amp.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/amp/amp.py#L294

Added line #L294 was not covered by tests
ret = []
for a, b in zip(m_dep, angle_amp):
tmp = b
Expand Down
20 changes: 16 additions & 4 deletions tf_pwa/amp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def get_factor(self):
H = self.H()
return tf.gather_nd(H, free_index)

def get_helicity_amp(self, data=None, data_p=None, **kwargs):
def get_H(self):
if self.mask_factor:
H = tf.stack(self.H())
_, free_idx = self.get_zero_index()
Expand All @@ -547,6 +547,9 @@ def get_helicity_amp(self, data=None, data_p=None, **kwargs):
)
return tf.stack(self.H())

def get_helicity_amp(self, data=None, data_p=None, **kwargs):
return self.get_H()

def get_ls_amp(self, data, data_p, **kwargs):
return tf.reshape(self.get_factor(), (1, -1))

Expand All @@ -566,7 +569,7 @@ def init_params(self):
self.d = 3.0
super().init_params()

def get_helicity_amp(self, data, data_p, **kwargs):
def get_H_barrier_factor(self, data, data_p, **kwargs):
q0 = self.get_relative_momentum(data_p, False)
data["|q0|"] = q0
if "|q|" in data:
Expand All @@ -575,10 +578,19 @@ def get_helicity_amp(self, data, data_p, **kwargs):
q = self.get_relative_momentum(data_p, True)
data["|q|"] = q
bf = barrier_factor([min(self.get_l_list())], q, q0, self.d)
H = tf.stack(self.H())
return bf

def get_helicity_amp(self, data, data_p, **kwargs):
H = self.get_H()
bf = self.get_H_barrier_factor(data, data_p, **kwargs)
bf = tf.cast(tf.reshape(bf, (-1, 1, 1)), H.dtype)
return H * bf

def get_ls_amp(self, data, data_p, **kwargs):
bf = self.get_H_barrier_factor(data, data_p, **kwargs)
f = tf.reshape(self.get_factor(), (1, -1))
return f * tf.expand_dims(tf.cast(bf, f.dtype), axis=-1)

Check warning on line 592 in tf_pwa/amp/base.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/amp/base.py#L590-L592

Added lines #L590 - L592 were not covered by tests


def get_parity_term(j1, p1, j2, p2, j3, p3):
p = p1 * p2 * p3 * (-1) ** (j1 - j2 - j3)
Expand Down Expand Up @@ -616,7 +628,7 @@ def init_params(self):
def get_helicity_amp(self, data, data_p, **kwargs):
n_b = len(self.outs[0].spins)
n_c = len(self.outs[1].spins)
H_part = tf.stack(self.H())
H_part = self.get_H()
if self.part_H == 0:
H = tf.concat(
[
Expand Down
37 changes: 37 additions & 0 deletions tf_pwa/amp/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,43 @@ def build_cached(self, x):
return x


@register_preprocessor("cached_angle")
class CachedAnglePreProcessor(BasePreProcessor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.amp = self.root_config.get_amplitude()
self.decay_group = self.amp.decay_group
self.no_angle = self.kwargs.get("no_angle", False)
self.no_p4 = self.kwargs.get("no_p4", False)

def build_cached(self, x):
x2 = super().__call__(x)
for k, v in x["extra"].items():
x2[k] = v # {**x2, **x["extra"]}
c_amp = self.decay_group.get_factor_angle_amp(x2)
x2["cached_angle"] = list_to_tuple(c_amp)
# print(x)
return x2

def strip_data(self, x):
strip_var = []
if self.no_angle:
strip_var += ["ang", "aligned_angle"]

Check warning on line 170 in tf_pwa/amp/preprocess.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/amp/preprocess.py#L170

Added line #L170 was not covered by tests
if self.no_p4:
strip_var += ["p"]

Check warning on line 172 in tf_pwa/amp/preprocess.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/amp/preprocess.py#L172

Added line #L172 was not covered by tests
if strip_var:
x = data_strip(x, strip_var)

Check warning on line 174 in tf_pwa/amp/preprocess.py

View check run for this annotation

Codecov / codecov/patch

tf_pwa/amp/preprocess.py#L174

Added line #L174 was not covered by tests
return x

def __call__(self, x):
extra = x["extra"]
x = self.build_cached(x)
x = self.strip_data(x)
for k in extra:
del x[k]
return x


@register_preprocessor("p4_directly")
class CachedAmpPreProcessor(BasePreProcessor):
def __init__(self, *args, **kwargs):
Expand Down
17 changes: 14 additions & 3 deletions tf_pwa/tests/config_hel.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
data:
dat_order: [B, C, D]
preprocessor: cached_angle
amp_model: base_factor

decay:
A: [BC, D, model: helicity_full]
A:
- [BC, D, model: helicity_full]
- [BD, C, model: helicity_full]
- [CD, B, model: helicity_full]
BC: [B, C]
BD: [B, D]
CD: [C, D]

particle:
$top:
Expand All @@ -13,13 +19,18 @@ particle:
B: { J: 0, P: -1, mass: 0.1 }
C: { J: 0, P: -1, mass: 0.1 }
D: { J: 0, P: -1, mass: 0.1 }
BC: [BC1, BC2]
BC: [BC1]
BC1:
J: 1
P: -1
mass: 1.0
width: 0.2
BC2:
BD:
J: 1
P: -1
mass: 2.0
width: 0.2
CD:
J: 1
P: -1
mass: 2.0
Expand Down
1 change: 1 addition & 0 deletions tf_pwa/tests/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,3 +400,4 @@ def test_factor_hel():
phsp = config.generate_phsp(10)
amp = config.get_amplitude()
amp(phsp)
amp.decay_group.get_factor()

0 comments on commit 189be05

Please sign in to comment.