Skip to content

Commit

Permalink
Merge pull request #85 from jiangyi15/id_particle
Browse files Browse the repository at this point in the history
swap index order of amplitude for identical particles.
  • Loading branch information
jiangyi15 authored Jun 25, 2023
2 parents 9850bf3 + 080149d commit 4220537
Show file tree
Hide file tree
Showing 14 changed files with 273 additions and 44 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ jobs:
run: |
conda install --file requirements-min.txt -y
python -m pip install -e . --no-deps
conda install pylint -y
conda install pre-commit -c conda-forge -y
pre-commit install
pre-commit run -a
Expand All @@ -117,6 +118,7 @@ jobs:
run: |
conda install --file tensorflow_2_6_requirements.txt -c conda-forge -y
python -m pip install -e . --no-deps
conda install pylint -y
conda install pre-commit -c conda-forge -y
pre-commit install
pre-commit run -a
Expand Down
7 changes: 5 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,13 @@ repos:
hooks:
- id: isort

- repo: https://github.com/pre-commit/mirrors-pylint
rev: v2.5.0 # Use the sha / tag you want to point at
- repo: local
hooks:
- id: pylint
name: pylint
entry: pylint
language: system
types: [python]
args: ["--rcfile=.pylintrc", "--score=no"]

- repo: https://github.com/myint/rstcheck
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ test =
dev =
%(doc)s
%(test)s
pylint
pre-commit
all =
%(dev)s
Expand Down
49 changes: 48 additions & 1 deletion tf_pwa/amp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,15 +1409,48 @@ def get_angle_amp(self, data):
@functools.lru_cache()
def get_swap_factor(self, key):
factor = 1.0
used = []
for i, j in zip(self.identical_particles, key[1]):
p = self.get_particle(i[0])
if int(p.J * 2) % 2 == 0:
continue
for m, n in zip(i, j):
if (m, n) in used or (n, m) in used:
continue
used.append((m, n))
if m != n:
factor *= -1.0
return factor

@functools.lru_cache()
def get_id_swap_transpose(self, key, n):
_, change = key
# print(key)
old_order = [str(i) for i in self.outs]
trans = []
for i, j in zip(self.identical_particles, change):
for k, l in zip(i, j):
trans.append((k, l))
trans = tuple(trans)
return self.get_swap_transpose(trans, n)

@functools.lru_cache()
def get_swap_transpose(self, trans, n):
trans = dict(trans)
# print(trans)
tmp = {v: k for k, v in trans.items()}
tmp.update(trans)
trans = tmp
# print(trans)
old_order = [str(i) for i in self.outs]
new_order = []
for i in old_order:
new_order.append(trans.get(i, i))
index_map = {k: i for i, k in enumerate(new_order)}
trans_order = [index_map[str(i)] for i in self.outs]
diff = n - len(trans_order)
return [i for i in range(diff)] + [i + diff for i in trans_order]

def get_amp2(self, data):
amp = self.get_amp(data)
id_swap = data.get("id_swap", {})
Expand All @@ -1426,6 +1459,9 @@ def get_amp2(self, data):
factor = self.get_swap_factor(k)
amp_swap = factor * self.get_amp(new_data)
# print(k, amp, amp_swap)
swap_index = self.get_id_swap_transpose(k, len(amp_swap.shape))
# print(swap_index)
amp_swap = tf.transpose(amp_swap, swap_index)
amp = amp + amp_swap
return amp

Expand All @@ -1441,15 +1477,26 @@ def get_amp3(self, data):
name_map = {str(i): i for i in self.outs}
frac = 1.0
same_particle = []
change = []
for a, b in cg:
for i, j in zip(a, b):
if i == j:
same_particle.append(i)
frac = frac * getattr(name_map[i], "C", -1)
else:
change.append((i, j))
transpose = self.get_swap_transpose(
tuple(change), len(amp_swap.shape)
)
p_reverse = [Ellipsis] + [
slice(None, None, -1) for i in range(len(amp_swap.shape) - 1)
]
amp = amp + amp_swap.__getitem__(p_reverse) * frac

amp = (
amp
+ tf.transpose(amp_swap, transpose).__getitem__(p_reverse)
* frac
)
return amp

def sum_amp(self, data, cached=True):
Expand Down
37 changes: 37 additions & 0 deletions tf_pwa/amp/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,43 @@ def get_bin_index(self, m):
return bin_idx


@register_particle("linear_npy")
class InterpLinearNpy(InterpolationParticle):
def __init__(self, *args, **kwargs):
self.input_file = kwargs.get("file")
self.data = np.load(self.input_file)
points = self.data[:, 0]
kwargs["points"] = points
super().__init__(*args, **kwargs)

def init_params(self):
pass

def get_point_values(self):
v_r = np.concatenate([[0.0], self.data[:, 1], [0.0]])
v_i = np.concatenate([[0.0], self.data[:, 1], [0.0]])
return self.data[:, 0], v_r, v_i

def interp(self, m):
x, p_r, p_i = self.get_point_values()
bin_idx = tf.raw_ops.Bucketize(input=m, boundaries=x)
bin_idx = (bin_idx + len(self.bound)) % len(self.bound)
ret_r_l = tf.gather(p_r[1:], bin_idx)
ret_i_l = tf.gather(p_r[1:], bin_idx)
ret_r_r = tf.gather(p_r[:-1], bin_idx)
ret_i_r = tf.gather(p_r[:-1], bin_idx)
delta = np.concatenate(
[[1.0], self.data[1:, 1] - self.data[:-1, 1], [1.0]]
)
x_left = np.concatenate([[x[0] - 1], x])
delta = tf.gather(delta, bin_idx)
x_left = tf.gather(x_left, bin_idx)
step = (m - x_left) / delta
a = step * (ret_r_l - ret_r_r)
b = step * (ret_i_l - ret_i_r)
return tf.complex(a, b)


@register_particle("interp")
class Interp(InterpolationParticle):
"""linear interpolation for real number"""
Expand Down
Loading

0 comments on commit 4220537

Please sign in to comment.