Skip to content

Commit

Permalink
Merge pull request #87 from jiangyi15/extra_var
Browse files Browse the repository at this point in the history
Resolution improvement and extra_var in data
  • Loading branch information
jiangyi15 authored Jul 29, 2023
2 parents 4220537 + 9e066ac commit f3e7662
Show file tree
Hide file tree
Showing 19 changed files with 264 additions and 158 deletions.
3 changes: 2 additions & 1 deletion checks/resolution/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,5 @@ $$ -\ln L = - \sum \ln P(m_j') + N \ln \int P(m_i') d \Phi \approx - \sum_{j} \l

# 6. plot_resolution.py

Draw the histogram of fit results.
Draw the histogram of fit results. Use "phsp" to calculate the weights and draw
the variables in "phsp".
6 changes: 2 additions & 4 deletions checks/resolution/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ data:
resolution_size: 100
data: ["data/data.dat"]
data_weight: ["data/data_w.dat"]
data_origin: ["data/toy_rec.dat"]
data_rec: ["data/toy_rec.dat"]
phsp: ["data/phsp_truth.dat"]
phsp_rec: ["data/phsp_rec.dat"]
phsp_noeff: ["data/phsp_plot.dat"]
phsp_plot: ["data/phsp_plot_rec.dat"]
phsp_plot_re: ["data/phsp_plot_re.npy"]
phsp_plot_re_weight: ["data/phsp_plot_re_w.dat"]

decay:
A:
Expand Down
1 change: 0 additions & 1 deletion checks/resolution/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,3 @@ def run(name):
if __name__ == "__main__":
run("toy")
run("phsp")
run("phsp_plot")
3 changes: 0 additions & 3 deletions checks/resolution/gen_toy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
toy.savetxt("data/toy.dat", config.get_dat_order())
phsp.savetxt("data/phsp.dat", config.get_dat_order())

phsp = config.generate_phsp(50000)
phsp.savetxt("data/phsp_plot.dat", config.get_dat_order())


config.plot_partial_wave(
data=[toy], phsp=[phsp], prefix="figure/toy_", plot_pull=True
Expand Down
72 changes: 1 addition & 71 deletions checks/resolution/plot_resolution.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,6 @@
import sys

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tf_pwa.config_loader import ConfigLoader
from tf_pwa.data import batch_call, data_index, data_split
from tf_pwa.histogram import Hist1D


def sum_resolution(amps, weights, size=1):
amps = tf.reshape(amps * weights, (-1, size))
amps = tf.reduce_sum(amps, axis=-1).numpy()
return amps


def main():
Expand All @@ -21,65 +9,7 @@ def main():
if len(sys.argv) > 1:
param = sys.argv[1]
config.set_params(param + ".json")
amp = config.get_amplitude()

data = config.get_data("data_origin")[0]
phsp = config.get_data("phsp_plot")[0]
phsp_re = config.get_data("phsp_plot_re")[0]

print("data loaded")
amps = batch_call(amp, phsp_re, 20000)
pw = batch_call(amp.partial_weight, phsp_re, 20000)

re_weight = phsp_re["weight"]
re_size = config.resolution_size
print(sum_resolution(tf.ones_like(re_weight), re_weight, re_size))
print(amps)
amps = sum_resolution(amps, re_weight, re_size)
print(np.argmax(amps), amps)

pw = [sum_resolution(i, re_weight, re_size) for i in pw]

m_idx = config.get_data_index("mass", "BC")
m_phsp = data_index(phsp, m_idx).numpy()
m_data = data_index(data, m_idx).numpy()

m_min, m_max = np.min(m_phsp), np.max(m_phsp)

scale = m_data.shape[0] / np.sum(amps)

get_hist = lambda m, w: Hist1D.histogram(
m, weights=w, range=(m_min, m_max), bins=120
)

data_hist = get_hist(m_data, None)
phsp_hist = get_hist(m_phsp, scale * amps)
pw_hist = []
for i in pw:
pw_hist.append(get_hist(m_phsp, scale * i))

ax2 = plt.subplot2grid((4, 1), (3, 0), rowspan=1)
ax = plt.subplot2grid((4, 1), (0, 0), rowspan=3, sharex=ax2)
data_hist.draw_error(ax, label="data")
phsp_hist.draw(ax, label="fit")

for i, j in zip(pw_hist, config.get_decay()):
i.draw_kde(
ax, label=str(j.inner[-1])
) # "/".join([str(k) for k in j.inner]))

(data_hist - phsp_hist).draw_pull(ax2)
ax.set_ylim((0, None))
ax.legend()
ax.set_yscale("symlog")
ax.set_ylabel("Events/{:.1f} MeV".format((m_max - m_min) * 10))
ax2.set_xlabel("M( R_BC )")
ax2.set_ylabel("pull")
# ax2.set_xlim((1.3, 1.7))
ax2.set_ylim((-5, 5))
plt.setp(ax.get_xticklabels(), visible=False)
print(param + "m_R_BC_fit.png")
plt.savefig("m_R_BC_fit_" + param + ".png")
config.plot_partial_wave(prefix="figure/" + param + "_", plot_pull=True)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion checks/resolution/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ python plot_function.py
python sample.py
python plot_resolution.py toy_params
python ../../fit.py -i toy_params.json
python plot_resolution.py
# python plot_resolution.py
27 changes: 2 additions & 25 deletions checks/resolution/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,39 +92,16 @@ def main():

decay_chain = config.get_decay(False).get_decay_chain("BC")

toy = config.get_data("data_origin")[0]
# print(toy)
# exit()
toy = config.get_data("data_rec")[0]

ha = HelicityAngle(decay_chain)
ms, costheta, phi = ha.find_variable(toy)
dat = ha.build_data(ms, costheta, phi)

# print(var)
# print(ha.build_data(*var))
p4, w = random_sample(config, decay_chain, toy)
# exit()

np.savetxt("data/data.dat", np.stack(p4).reshape((-1, 4)))
np.savetxt(
"data/data_w.dat", np.transpose(w).reshape((-1,))
) # np.ones(p4.reshape((-1,4)).shape[0]))

toy = config.get_data("phsp_plot")[0]
p4, w = random_sample(config, decay_chain, toy)

np.save("data/phsp_plot_re.npy", p4.reshape((-1, 4)))
np.savetxt(
"data/phsp_plot_re_w.dat",
np.transpose(w * toy.get_weight()).reshape((-1,)),
) # np.repeat(toy.get_weight(), config.resolution_size))

# toy = config.get_data("phsp_origin")[0]
# p4, w = random_sample(config, decay_chain, toy)


# np.savetxt("data/phsp_re.dat", p4.reshape((-1,4)))
# np.savetxt("data/phsp_re_w.dat", np.transpose(w * toy.get_weight()).reshape((-1,))) # np.repeat(toy.get_weight(), config.resolution_size))
np.savetxt("data/data_w.dat", np.transpose(w).reshape((-1,)))


if __name__ == "__main__":
Expand Down
45 changes: 37 additions & 8 deletions fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,17 @@ def fit(
return fit_result


def write_some_results(config, fit_result, save_root=False):
def write_some_results(config, fit_result, save_root=False, cpu_plot=False):
# plot partial wave distribution
config.plot_partial_wave(fit_result, plot_pull=True, save_root=save_root)
if cpu_plot:
with tf.device("CPU"):
config.plot_partial_wave(
fit_result, plot_pull=True, save_root=save_root
)
else:
config.plot_partial_wave(
fit_result, plot_pull=True, save_root=save_root
)

# calculate fit fractions
phsp_noeff = config.get_phsp_noeff()
Expand All @@ -197,14 +205,24 @@ def write_some_results(config, fit_result, save_root=False):
# chi2, ndf = config.cal_chi2(mass=["R_BC", "R_CD"], bins=[[2,2]]*4)


def write_some_results_combine(config, fit_result, save_root=False):
def write_some_results_combine(
config, fit_result, save_root=False, cpu_plot=False
):

from tf_pwa.applications import fit_fractions

for i, c in enumerate(config.configs):
c.plot_partial_wave(
fit_result, prefix="figure/s{}_".format(i), save_root=save_root
)
if cpu_plot:
with tf.device("CPU"):
c.plot_partial_wave(
fit_result,
prefix="figure/s{}_".format(i),
save_root=save_root,
)
else:
c.plot_partial_wave(
fit_result, prefix="figure/s{}_".format(i), save_root=save_root
)

for it, config_i in enumerate(config.configs):
print("########## fit fractions {}:".format(it))
Expand Down Expand Up @@ -255,6 +273,9 @@ def main():
parser.add_argument(
"--no-GPU", action="store_false", default=True, dest="has_gpu"
)
parser.add_argument(
"--CPU-plot", action="store_true", default=False, dest="cpu_plot"
)
parser.add_argument("-c", "--config", default="config.yml", dest="config")
parser.add_argument(
"-i", "--init_params", default="init_params.json", dest="init"
Expand Down Expand Up @@ -285,10 +306,18 @@ def main():
results.printer,
)
if isinstance(config, ConfigLoader):
write_some_results(config, fit_result, save_root=results.save_root)
write_some_results(
config,
fit_result,
save_root=results.save_root,
cpu_plot=results.cpu_plot,
)
else:
write_some_results_combine(
config, fit_result, save_root=results.save_root
config,
fit_result,
save_root=results.save_root,
cpu_plot=results.cpu_plot,
)


Expand Down
24 changes: 18 additions & 6 deletions tf_pwa/config_loader/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,19 @@ def get_phsp_noeff(self):
)
return self.get_data("phsp")[0]

def get_phsp_plot(self):
if "phsp_plot" in self.config["data"]:
assert len(self.config["data"]["phsp_plot"]) == len(
def get_phsp_plot(self, tail=""):
if "phsp_plot" + tail in self.config["data"]:
assert len(self.config["data"]["phsp_plot" + tail]) == len(
self.config["data"]["phsp"]
)
return self.get_data("phsp_plot")
return self.get_data("phsp")
return self.get_data("phsp_plot" + tail)
return self.get_data("phsp" + tail)

def get_data_rec(self, name):
ret = self.get_data(name + "_rec")
if ret is None:
ret = self.get_data(name)
return ret

def get_decay(self, full=True):
if full:
Expand Down Expand Up @@ -481,7 +487,13 @@ def _get_model(self, vm=None, name=""):
)
else:
model.append(
Model_cfit(amp, wb, bg_function, eff_function)
Model_cfit(
amp,
wb,
bg_function,
eff_function,
resolution_size=self.resolution_size,
)
)
elif "inmc" in self.config["data"]:
float_wmc = self.config["data"].get(
Expand Down
36 changes: 28 additions & 8 deletions tf_pwa/config_loader/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import os
import re
import warnings

import numpy as np
Expand Down Expand Up @@ -107,10 +108,11 @@ def particle_item():
return new_order

def get_weight_sign(self, idx):
negtive_idx = self.dic.get("negtive_idx", ["bg"])
negtive_idx = self.dic.get("negtive_idx", ["bg*"])
weight_sign = 1
if idx in negtive_idx:
weight_sign = -1
for i in negtive_idx:
if re.match(i, idx):
weight_sign = -1
return weight_sign

def get_data(self, idx) -> dict:
Expand Down Expand Up @@ -204,10 +206,16 @@ def load_weight_file(self, weight_files):
ret = []
if isinstance(weight_files, list):
for i in weight_files:
data = np.loadtxt(i).reshape((-1,))
if i.endswith(".npy"):
data = np.load(i).reshape((-1,))
else:
data = np.loadtxt(i).reshape((-1,))
ret.append(data)
elif isinstance(weight_files, str):
data = np.loadtxt(weight_files).reshape((-1,))
if weight_files.endswith(".npy"):
data = np.load(weight_files).reshape((-1,))
else:
data = np.loadtxt(weight_files).reshape((-1,))
ret.append(data)
else:
raise TypeError(
Expand Down Expand Up @@ -362,7 +370,7 @@ def get_data(self, idx) -> list:
charge = self.dic.get(idx + "_charge", None)
if charge is None:
charge = [None] * len(files)
elif not isinstance(charge[0], list):
elif not isinstance(charge, list):
charge = [charge]
ret = [
self.load_data(i, j, weight_sign, k)
Expand All @@ -377,13 +385,25 @@ def get_data(self, idx) -> list:
if isinstance(bg_value, str):
bg_value = [bg_value]
for i, file_name in enumerate(bg_value):
ret[i]["bg_value"] = np.reshape(np.loadtxt(file_name), (-1,))
ret[i]["bg_value"] = self.load_weight_file(file_name)
eff_value = self.dic.get(idx + "_eff_value", None)
if eff_value is not None:
if isinstance(eff_value, str):
eff_value = [eff_value]
for i, file_name in enumerate(eff_value):
ret[i]["eff_value"] = np.reshape(np.loadtxt(file_name), (-1,))
ret[i]["eff_value"] = self.load_weight_file(file_name)
extra_var = self.dic.get("extra_var", None)
if extra_var:
for i in extra_var:
idx_var = self.dic.get(idx + "_" + i, None)
if idx_var is not None:
for j, file_name in enumerate(idx_var):
ret[j][i] = self.load_weight_file(file_name)
else:
for j, k in enumerate(ret):
ret[j][i] = extra_var[i]["default"] * np.ones(
data_shape(k)
)
ret = self.process_scale(idx, ret)
return ret

Expand Down
Loading

0 comments on commit f3e7662

Please sign in to comment.