Skip to content

Commit

Permalink
Merge branch 'dev' into cov_ten
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangyi15 committed Feb 6, 2024
2 parents 8b4a2d4 + a2be967 commit b12d1b3
Show file tree
Hide file tree
Showing 8 changed files with 565 additions and 98 deletions.
2 changes: 1 addition & 1 deletion tf_pwa/amp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class ParticleGS(Particle):
f(m) = \Gamma_0 \frac{m_0 ^2 }{q_0^3} \left[q^2 [h(m)-h(m_0)] + (m_0^2 - m^2) q_0^2 \frac{d h}{d m}|_{m0} \right]
.. math::
h(m) = \frac{2}{\pi} \frac{q}{m} \ln \left(\frac{m+q}{2m_{\pi}} \right)
h(m) = \frac{2}{\pi} \frac{q}{m} \ln \left(\frac{m+2q}{2m_{\pi}} \right)
.. math::
\frac{d h}{d m}|_{m0} = h(m_0) [(8q_0^2)^{-1} - (2m_0^2)^{-1}] + (2\pi m_0^2)^{-1}
Expand Down
70 changes: 59 additions & 11 deletions tf_pwa/amp/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ def get_bin_index(self, m):

@register_particle("linear_npy")
class InterpLinearNpy(InterpolationParticle):
"""
Linear interpolation model from a `.npy` file with array of [mi, re(ai), im(ai)].
Required `file: path_of_file.npy`, for the path of `.npy` file.
The example is `exp(5 I m)`.
.. plot::
>>> import tempfile
>>> import numpy as np
>>> from tf_pwa.utils import plot_particle_model
>>> a = tempfile.mktemp(".npy")
>>> m = np.linspace(0.2, 0.9)
>>> mi = m[::5]
>>> np.save(a, np.stack([mi, np.cos(mi*5), np.sin(mi*5)], axis=-1))
>>> axs = plot_particle_model("linear_npy", {"file": a})
>>> _ = axs[3].plot(np.cos(m*5), np.sin(m*5), "--")
"""

def __init__(self, *args, **kwargs):
self.input_file = kwargs.get("file")
self.data = np.load(self.input_file)
Expand All @@ -103,29 +124,56 @@ def init_params(self):

def get_point_values(self):
v_r = np.concatenate([[0.0], self.data[:, 1], [0.0]])
v_i = np.concatenate([[0.0], self.data[:, 1], [0.0]])
v_i = np.concatenate([[0.0], self.data[:, 2], [0.0]])
return self.data[:, 0], v_r, v_i

def interp(self, m):
x, p_r, p_i = self.get_point_values()
bin_idx = tf.raw_ops.Bucketize(input=m, boundaries=x)
bin_idx = (bin_idx + len(self.bound)) % len(self.bound)
ret_r_l = tf.gather(p_r[1:], bin_idx)
ret_i_l = tf.gather(p_r[1:], bin_idx)
ret_r_r = tf.gather(p_r[:-1], bin_idx)
ret_i_r = tf.gather(p_r[:-1], bin_idx)
delta = np.concatenate(
[[1.0], self.data[1:, 1] - self.data[:-1, 1], [1.0]]
)
bin_idx = (bin_idx) % (len(self.bound) + 1)
ret_r_r = tf.gather(p_r[1:], bin_idx)
ret_i_r = tf.gather(p_i[1:], bin_idx)
ret_r_l = tf.gather(p_r[:-1], bin_idx)
ret_i_l = tf.gather(p_i[:-1], bin_idx)
delta = np.concatenate([[1e20], x[1:] - x[:-1], [1e20]])
x_left = np.concatenate([[x[0] - 1], x])
delta = tf.gather(delta, bin_idx)
x_left = tf.gather(x_left, bin_idx)
step = (m - x_left) / delta
a = step * (ret_r_l - ret_r_r)
b = step * (ret_i_l - ret_i_r)
a = step * (ret_r_r - ret_r_l) + ret_r_l
b = step * (ret_i_r - ret_i_l) + ret_i_l
return tf.complex(a, b)


@register_particle("linear_txt")
class InterpLinearTxt(InterpLinearNpy):
"""Linear interpolation model from a `.txt` file with array of [mi, re(ai), im(ai)].
Required `file: path_of_file.txt`, for the path of `.txt` file.
The example is `exp(5 I m)`.
.. plot::
>>> import tempfile
>>> import numpy as np
>>> from tf_pwa.utils import plot_particle_model
>>> a = tempfile.mktemp(".txt")
>>> m = np.linspace(0.2, 0.9)
>>> mi = m[::5]
>>> np.savetxt(a, np.stack([mi, np.cos(mi*5), np.sin(mi*5)], axis=-1))
>>> axs = plot_particle_model("linear_txt", {"file": a})
>>> _ = axs[3].plot(np.cos(m*5), np.sin(m*5), "--")
"""

def __init__(self, *args, **kwargs):
self.input_file = kwargs.get("file")
self.data = np.loadtxt(self.input_file)
points = self.data[:, 0]
kwargs["points"] = points
super(InterpLinearNpy, self).__init__(*args, **kwargs)


@register_particle("interp")
class Interp(InterpolationParticle):
"""linear interpolation for real number"""
Expand Down
7 changes: 3 additions & 4 deletions tf_pwa/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,15 @@ def get_metric(self):
"""
return tf.cast(tf.constant([1.0, -1.0, -1.0, -1.0]), self.dtype)

def Dot(self, others):
s = self * others * LorentzVector.get_metric(self)
def Dot(self, other):
s = self * other * LorentzVector.get_metric(self)
return tf.reduce_sum(s, axis=-1)

def M2(self):
"""
The invariant mass squared
"""
s = self * self * LorentzVector.get_metric(self)
return tf.reduce_sum(s, axis=-1)
return LorentzVector.Dot(self, self)

def M(self):
"""
Expand Down
85 changes: 75 additions & 10 deletions tf_pwa/config_loader/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,48 @@ def plot_partial_wave(
)


@ConfigLoader.register_function()
def plot_partial_wave_interf(self, res1, res2, **kwargs):

labels = ["data"]
if self.config["data"].get("model", "auto") == "cfit":
labels.append("background")
elif self.config["data"].get("bg", None) is not None:
labels.append("background")

if kwargs.get("ref_amp", None) is not None:
labels.append("reference fit")
labels.append("total fit")

if kwargs.get("force_legend_labels", None) is not None:
labels = kwargs["force_legend_labels"]
del kwargs["force_legend_labels"]

labels += [str(res1), str(res2), "sum", "interference"]

if not isinstance(res1, list):
res1 = [res1]
if not isinstance(res2, list):
res2 = [res2]

amp = self.get_amplitude()

def weights_function(data, **kwargs):
with amp.temp_used_res(res1):
a = amp(data)
with amp.temp_used_res(res2):
b = amp(data)
with amp.temp_used_res(res1 + res2):
ab = amp(data)
return [a, b, ab, ab - a - b]

self.plot_partial_wave(
partial_waves_function=weights_function,
force_legend_labels=labels,
**kwargs,
)


@ConfigLoader.register_function()
def _get_plot_partial_wave_input(
self,
Expand All @@ -333,6 +375,7 @@ def _get_plot_partial_wave_input(
save_root=False,
chains_id_method=None,
cut_function=lambda x: 1,
partial_waves_function=None,
**kwargs
):
"""
Expand Down Expand Up @@ -393,7 +436,13 @@ def _get_plot_partial_wave_input(
if chains_id_method is not None:
self.chains_id_method = chains_id_method

chain_property = create_chain_property(self, res)
if partial_waves_function is None:
chain_property = create_chain_property(self, res)
else:
chain_property = [
[i, "pw_{}".format(i), "partial waves {}".format(i), None]
for i in range(100)
]
plot_var_dic = create_plot_var_dic(self.plot_params)

if self._Ngroup == 1:
Expand All @@ -411,6 +460,7 @@ def _get_plot_partial_wave_input(
res=res,
phsp_rec=phsp_rec[0],
cut_function=cut_function,
partial_waves_function=partial_waves_function,
**kwargs,
)
all_data = data_dict, phsp_dict, bg_dict
Expand All @@ -435,6 +485,7 @@ def _get_plot_partial_wave_input(
save_root=save_root,
phsp_rec=phsp_rec[i],
cut_function=cut_function,
partial_waves_function=partial_waves_function,
**kwargs,
)
all_data = data_dict, phsp_dict, bg_dict
Expand Down Expand Up @@ -463,6 +514,7 @@ def _get_plot_partial_wave_input(
res=res,
phsp_rec=phsp_rec[i],
cut_function=cut_function,
partial_waves_function=partial_waves_function,
**kwargs,
)
# self._plot_partial_wave(data_dict, phsp_dict, bg_dict, path+'d{}_'.format(i), plot_var_dic, chain_property, **kwargs)
Expand Down Expand Up @@ -528,6 +580,7 @@ def _cal_partial_wave(
ref_amp=None,
phsp_rec=None,
cut_function=lambda x: 1,
partial_waves_function=None,
**kwargs
):
data_dict = {}
Expand Down Expand Up @@ -566,9 +619,14 @@ def _cal_partial_wave(
norm_frac = n_sig / np.sum(total_weight)
if ref_amp is not None:
norm_frac_ref = n_sig / np.sum(total_weight_ref)
weights = batch_call_numpy(
lambda x: amp.partial_weight(x, combine=res), phsp, batch
)
if partial_waves_function is None:
weights = batch_call_numpy(
lambda x: amp.partial_weight(x, combine=res), phsp, batch
)
else:
weights = batch_call_numpy(
lambda x: partial_waves_function(x, combine=res), phsp, batch
)
data_weights = data.get("weight", np.ones((data_shape(data),)))
data_dict["data_weights"] = (
batch_call_numpy(cut_function, data, batch) * data_weights
Expand All @@ -583,20 +641,24 @@ def _cal_partial_wave(
)
if bg is not None:
bg_weight = -w_bkg
# sideband weight
bg_dict["sideband_weights"] = (
batch_call_numpy(cut_function, bg, batch) * bg_weight
) # sideband weight
)
for i, name_i, label, _ in chain_property:
if i >= len(weights):
break
weight_i = (
weights[i]
* norm_frac
* bin_scale
* phsp.get("weight", 1.0)
* phsp.get("eff_value", 1.0)
)
# MC partial weight
phsp_dict["MC_{0}_{1}_fit".format(i, name_i)] = cut_phsp * sr(
weight_i
) # MC partial weight
)
for name in plot_var_dic:
idx = plot_var_dic[name]["idx"]
trans = lambda x: np.reshape(plot_var_dic[name]["trans"](x), (-1,))
Expand Down Expand Up @@ -673,6 +735,7 @@ def _plot_partial_wave(
add_chi2=False,
dpi=300,
force_legend_labels=None,
labels=None,
**kwargs
):
# cmap = plt.get_cmap("jet")
Expand Down Expand Up @@ -759,7 +822,7 @@ def _plot_partial_wave(
le = bg_hist.draw_bar(
ax, label="back ground", alpha=0.5, color="grey"
)
has_negative = has_negative and np.any(bg_hist.count < 0)
has_negative = has_negative or np.any(bg_hist.count < 0)
fitted_hist = fitted_hist + bg_hist
if ref_amp is not None:
fitted_hist_ref = fitted_hist_ref + bg_hist
Expand All @@ -769,11 +832,11 @@ def _plot_partial_wave(
le2 = fitted_hist_ref.draw(
ax, label="reference fit", color="red", linewidth=2
)
has_negative = has_negative and np.any(fitted_hist_ref.count < 0)
has_negative = has_negative or np.any(fitted_hist_ref.count < 0)
legends.append(le2[0])
legends_label.append("reference fit")
le2 = fitted_hist.draw(ax, label="total fit", color="black")
has_negative = has_negative and np.any(fitted_hist.count < 0)
has_negative = has_negative or np.any(fitted_hist.count < 0)
legends.append(le2[0])
legends_label.append("total fit")

Expand Down Expand Up @@ -819,7 +882,7 @@ def _plot_partial_wave(
linewidth=1,
)

has_negative = has_negative and np.any(hist_i.count < 0)
has_negative = has_negative or np.any(hist_i.count < 0)
legends.append(le3[0])
legends_label.append(label)
if yscale == "log":
Expand All @@ -834,6 +897,8 @@ def _plot_partial_wave(
if force_legend_labels:
legends_label = force_legend_labels
if has_legend:
if labels is not None:
legends_label = labels
if legend_outside:
leg = ax.legend(
legends,
Expand Down
Loading

0 comments on commit b12d1b3

Please sign in to comment.