Skip to content

Commit

Permalink
feat: plot 2d pull directly
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangyi15 committed Aug 25, 2023
1 parent 11f56b3 commit 7778b32
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 35 deletions.
220 changes: 185 additions & 35 deletions tf_pwa/config_loader/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def plot_partial_wave(
chains_id_method=None,
phsp_rec=None,
cut_function=lambda x: 1,
plot_function=None,
**kwargs
):
"""
Expand All @@ -251,6 +252,8 @@ def plot_partial_wave(
:param linestyle_file: legend linestyle configuration file name (YAML format), string (such as "legend.yml")
"""
if plot_function is None:
plot_function = self._plot_partial_wave

if params is None:
params = {}
Expand Down Expand Up @@ -329,13 +332,13 @@ def plot_partial_wave(
cut_function=cut_function,
**kwargs,
)
self._plot_partial_wave(
plot_function(
data_dict,
phsp_dict,
bg_dict,
prefix,
plot_var_dic,
chain_property,
prefix=prefix,
plot_var_dic=plot_var_dic,
chain_property=chain_property,
nll=nll,
**kwargs,
)
Expand All @@ -360,13 +363,13 @@ def plot_partial_wave(
cut_function=cut_function,
**kwargs,
)
self._plot_partial_wave(
plot_function(
data_dict,
phsp_dict,
bg_dict,
prefix + "d{}_".format(i),
plot_var_dic,
chain_property,
prefix=prefix + "d{}_".format(i),
plot_var_dic=plot_var_dic,
chain_property=chain_property,
nll=nll,
**kwargs,
)
Expand Down Expand Up @@ -414,13 +417,13 @@ def plot_partial_wave(
phsps_dict[ct] = np.concatenate(phsps_dict[ct])
for ct in bgs_dict:
bgs_dict[ct] = np.concatenate(bgs_dict[ct])
self._plot_partial_wave(
plot_function(
datas_dict,
phsps_dict,
bgs_dict,
prefix + "com_",
plot_var_dic,
chain_property,
prefix=prefix + "com_",
plot_var_dic=plot_var_dic,
chain_property=chain_property,
nll=nll,
**kwargs,
)
Expand Down Expand Up @@ -970,6 +973,31 @@ def _plot_var_name(name):
raise TypeError("not string or list")


def build_read_var_function(all_var, where={}):
vari = [sym.simplify(i) for i in all_var]
used_var = []
var_index = []
all_symbols = set()
for i in vari:
all_symbols = all_symbols | i.free_symbols
all_symbols = tuple(all_symbols)

for i in all_symbols:
var_index.append(str(i))
used_var.append(where.get(str(i), str(i)))

used_var = [_plot_var_name(i) for i in used_var]

def get_var(dic, tail):
ret = []
for i in used_var:
ret.append(dic[i + tail])
return dict(zip(var_index, ret))

var_f = [sym.lambdify(all_symbols, i, modules="numpy") for i in vari]
return var_f, get_var


@ConfigLoader.register_function()
def _2d_plot_v2(
self,
Expand All @@ -995,32 +1023,13 @@ def _2d_plot_v2(
if "&" in k:
continue
assert ("x" in v) and ("y" in v)

var_x = sym.simplify(v["x"])
var_y = sym.simplify(v["y"])
where = v.get("where", {})
used_var = []
var_index = []
for i in var_x.free_symbols | var_y.free_symbols:
var_index.append(str(i))
used_var.append(where.get(str(i), str(i)))

used_var = [_plot_var_name(i) for i in used_var]

def get_var(dic, tail):
ret = []
for i in used_var:
ret.append(dic[i + tail])
return dict(zip(var_index, ret))

var_x_f = sym.lambdify(
tuple(var_x.free_symbols | var_y.free_symbols),
var_x,
modules="numpy",
)
var_y_f = sym.lambdify(
tuple(var_x.free_symbols | var_y.free_symbols),
var_y,
modules="numpy",

(var_x_f, var_y_f), get_var = build_read_var_function(
[var_x, var_y], where
)

data_1 = var_x_f(**get_var(data_dict, ""))
Expand Down Expand Up @@ -1124,6 +1133,147 @@ def plot_axis():
print("Finish plotting 2D fitted " + prefix + k)


@ConfigLoader.register_function()
def get_dalitz(config, a, b):
decay = config.get_decay(False)
da = decay.get_decay_chain(a)
db = decay.get_decay_chain(b)
pa = decay.get_particle(a)
pb = decay.get_particle(b)

for i in da:
if pa in i.outs:
topa = i.core
if pa == i.core:
outs_a = i.outs
for i in db:
if pb in i.outs:
topb = i.core
if pb == i.core:
outs_b = i.outs
same_finals = [i for i in outs_a if i in db.outs]
p1 = [i for i in outs_a if i not in same_finals]
p3 = [i for i in outs_b if i not in same_finals]
check = ((topa == topb),)
check = check and len(same_finals) == 1
check = check and len(p1) == 1
check = check and len(p3) == 1
if not check:
return None
p0, p1, p2, p3 = topa, p1[0], same_finals[0], p3[0]
p0, p1, p2, p3 = [
config.get_decay().get_particle(str(i)) for i in [p0, p1, p2, p3]
]
m0, m1, m2, m3 = map(lambda x: x.get_mass(), [p0, p1, p2, p3])
return m0, m1, m2, m3


@ConfigLoader.register_function()
def get_dalitz_boundary(config, a, b, N=1000):
dalitz = get_dalitz(config, a, b)
assert dalitz is not None, "not valid daliz plot"
m0, m1, m2, m3 = dalitz
# print(m0, m1, m2, m3)
from tf_pwa.angle import kine_min_max

s12_min, s12_max = float(m1 + m2), float(m0 - m3)
s12 = np.linspace(s12_min**2, s12_max**2, N)
s23_min, s23_max = kine_min_max(s12, *map(float, [m0, m1, m2, m3]))
return s12, np.stack([s23_min, s23_max], axis=-1)


@ConfigLoader.register_function()
def plot_adaptive_2dpull(
config, var1, var2, binning=[[2, 2]] * 3, ax=plt, where={}, cut_zero=True
):
import matplotlib as mpl
import matplotlib.colors as mcolors
import matplotlib.patches as mpathes

from tf_pwa.adaptive_bins import AdaptiveBound

def plot_function_2dpull(
data_dict, phsp_dict, bg_dict, plot_var_dic, **kwargs
):
nonlocal ax
if cut_zero:
cut = data_dict["data_weights"] != 0
else:
cut = np.ones(data_dict["data_weights"].shape, dtype=np.bool)
(var_x_f, var_y_f), get_var = build_read_var_function(
[var1, var2], where=where
)
x = var_x_f(**get_var(data_dict, ""))[cut]
y = var_y_f(**get_var(data_dict, ""))[cut]
w = data_dict["data_weights"][cut]
x_phsp = var_x_f(**get_var(phsp_dict, "_MC"))
y_phsp = var_y_f(**get_var(phsp_dict, "_MC"))
w_phsp = phsp_dict["MC_total_fit"]
data_cut = np.array([x, y])
adapter = AdaptiveBound(data_cut, binning)
phsps = adapter.split_data(np.array([x_phsp, y_phsp, w_phsp]))
datas = adapter.split_data(np.array([x, y, w]))
if bg_dict != {}:
x_bg = var_x_f(**get_var(bg_dict, "_sideband"))
y_bg = var_y_f(**get_var(bg_dict, "_sideband"))
w_bg = bg_dict["sideband_weights"]
bgs = adapter.split_data(np.array([x_bg, y_bg, w_bg]))
bound = adapter.get_bounds()
numbers = []
pulls = []
int_norm = 1
for i, bnd in enumerate(bound):
min_x, min_y = bnd[0]
max_x, max_y = bnd[1]
ndata = np.sum(datas[i][2])
nmc = np.sum(phsps[i][2])
if bg_dict != {}:
nmc += np.sum(bgs[i][2])
numbers.append((ndata, nmc))
pulls.append((ndata - nmc) / np.sqrt(nmc))

max_weight = max(np.max(np.abs(pulls)), 5)

my_cmap = plt.get_cmap("jet")
if ax == plt:
ax = plt.gca() # fig, ax = plt.subplots()
ax.scatter(x, y, s=1, c="black")
for i, bnd in enumerate(bound):
min_x, min_y = bnd[0]
max_x, max_y = bnd[1]
# print(weights[i]) # max_weight)
rect = mpathes.Rectangle(
(min_x, min_y),
max_x - min_x,
max_y - min_y,
linewidth=1,
facecolor=my_cmap(
pulls[i] / max_weight / 2 + 0.5
), # max_weight),
edgecolor="none", # black",
zorder=-1,
) # cmap(weights[i]/max_weight))
ax.add_patch(rect)

normal = mpl.colors.Normalize(vmin=-max_weight, vmax=max_weight)
im = mpl.cm.ScalarMappable(norm=normal, cmap=my_cmap)
# ax.colorbar(im)
ax.get_figure().colorbar(im)
ax.title(
"$\\chi^2/Nbins={:.2f}/{}$".format(
np.sum(np.abs(pulls) ** 2), len(bound)
)
)
ax.set_xlim([np.min(x_phsp), np.max(x_phsp)])
ax.set_ylim([np.min(y_phsp), np.max(y_phsp)])
ax.set_xlabel(var1)
ax.set_ylabel(var2)

config.plot_partial_wave(
plot_function=plot_function_2dpull, combine_plot=True
)


def hist_error(data, bins=50, xrange=None, weights=1.0, kind="poisson"):
if not hasattr(weights, "__len__"):
weights = [weights] * data.__len__()
Expand Down
9 changes: 9 additions & 0 deletions tf_pwa/tests/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,12 @@ def test_cp_particles():
config = ConfigLoader(f"{this_dir}/config_self_cp.yml")
phsp = config.generate_phsp(100)
config.get_amplitude()(phsp)


def test_plot_2dpull(toy_config):
import matplotlib.pyplot as plt

toy_config.plot_adaptive_2dpull("m_R_BC**2", "m_R_CD**2")
a, b = toy_config.get_dalitz_boundary("R_BC", "R_CD")
plt.plot(a, b, color="red")
plt.savefig("adptive_2d.png")

0 comments on commit 7778b32

Please sign in to comment.