diff --git a/tf_pwa/config_loader/plot.py b/tf_pwa/config_loader/plot.py index c4d8eeee..e359cc19 100644 --- a/tf_pwa/config_loader/plot.py +++ b/tf_pwa/config_loader/plot.py @@ -226,6 +226,7 @@ def plot_partial_wave( chains_id_method=None, phsp_rec=None, cut_function=lambda x: 1, + plot_function=None, **kwargs ): """ @@ -251,6 +252,8 @@ def plot_partial_wave( :param linestyle_file: legend linestyle configuration file name (YAML format), string (such as "legend.yml") """ + if plot_function is None: + plot_function = self._plot_partial_wave if params is None: params = {} @@ -329,13 +332,13 @@ def plot_partial_wave( cut_function=cut_function, **kwargs, ) - self._plot_partial_wave( + plot_function( data_dict, phsp_dict, bg_dict, - prefix, - plot_var_dic, - chain_property, + prefix=prefix, + plot_var_dic=plot_var_dic, + chain_property=chain_property, nll=nll, **kwargs, ) @@ -360,13 +363,13 @@ def plot_partial_wave( cut_function=cut_function, **kwargs, ) - self._plot_partial_wave( + plot_function( data_dict, phsp_dict, bg_dict, - prefix + "d{}_".format(i), - plot_var_dic, - chain_property, + prefix=prefix + "d{}_".format(i), + plot_var_dic=plot_var_dic, + chain_property=chain_property, nll=nll, **kwargs, ) @@ -414,13 +417,13 @@ def plot_partial_wave( phsps_dict[ct] = np.concatenate(phsps_dict[ct]) for ct in bgs_dict: bgs_dict[ct] = np.concatenate(bgs_dict[ct]) - self._plot_partial_wave( + plot_function( datas_dict, phsps_dict, bgs_dict, - prefix + "com_", - plot_var_dic, - chain_property, + prefix=prefix + "com_", + plot_var_dic=plot_var_dic, + chain_property=chain_property, nll=nll, **kwargs, ) @@ -970,6 +973,31 @@ def _plot_var_name(name): raise TypeError("not string or list") +def build_read_var_function(all_var, where={}): + vari = [sym.simplify(i) for i in all_var] + used_var = [] + var_index = [] + all_symbols = set() + for i in vari: + all_symbols = all_symbols | i.free_symbols + all_symbols = tuple(all_symbols) + + for i in all_symbols: + var_index.append(str(i)) + used_var.append(where.get(str(i), str(i))) + + used_var = [_plot_var_name(i) for i in used_var] + + def get_var(dic, tail): + ret = [] + for i in used_var: + ret.append(dic[i + tail]) + return dict(zip(var_index, ret)) + + var_f = [sym.lambdify(all_symbols, i, modules="numpy") for i in vari] + return var_f, get_var + + @ConfigLoader.register_function() def _2d_plot_v2( self, @@ -995,32 +1023,13 @@ def _2d_plot_v2( if "&" in k: continue assert ("x" in v) and ("y" in v) + var_x = sym.simplify(v["x"]) var_y = sym.simplify(v["y"]) where = v.get("where", {}) - used_var = [] - var_index = [] - for i in var_x.free_symbols | var_y.free_symbols: - var_index.append(str(i)) - used_var.append(where.get(str(i), str(i))) - - used_var = [_plot_var_name(i) for i in used_var] - - def get_var(dic, tail): - ret = [] - for i in used_var: - ret.append(dic[i + tail]) - return dict(zip(var_index, ret)) - - var_x_f = sym.lambdify( - tuple(var_x.free_symbols | var_y.free_symbols), - var_x, - modules="numpy", - ) - var_y_f = sym.lambdify( - tuple(var_x.free_symbols | var_y.free_symbols), - var_y, - modules="numpy", + + (var_x_f, var_y_f), get_var = build_read_var_function( + [var_x, var_y], where ) data_1 = var_x_f(**get_var(data_dict, "")) @@ -1124,6 +1133,147 @@ def plot_axis(): print("Finish plotting 2D fitted " + prefix + k) +@ConfigLoader.register_function() +def get_dalitz(config, a, b): + decay = config.get_decay(False) + da = decay.get_decay_chain(a) + db = decay.get_decay_chain(b) + pa = decay.get_particle(a) + pb = decay.get_particle(b) + + for i in da: + if pa in i.outs: + topa = i.core + if pa == i.core: + outs_a = i.outs + for i in db: + if pb in i.outs: + topb = i.core + if pb == i.core: + outs_b = i.outs + same_finals = [i for i in outs_a if i in db.outs] + p1 = [i for i in outs_a if i not in same_finals] + p3 = [i for i in outs_b if i not in same_finals] + check = ((topa == topb),) + check = check and len(same_finals) == 1 + check = check and len(p1) == 1 + check = check and len(p3) == 1 + if not check: + return None + p0, p1, p2, p3 = topa, p1[0], same_finals[0], p3[0] + p0, p1, p2, p3 = [ + config.get_decay().get_particle(str(i)) for i in [p0, p1, p2, p3] + ] + m0, m1, m2, m3 = map(lambda x: x.get_mass(), [p0, p1, p2, p3]) + return m0, m1, m2, m3 + + +@ConfigLoader.register_function() +def get_dalitz_boundary(config, a, b, N=1000): + dalitz = get_dalitz(config, a, b) + assert dalitz is not None, "not valid daliz plot" + m0, m1, m2, m3 = dalitz + # print(m0, m1, m2, m3) + from tf_pwa.angle import kine_min_max + + s12_min, s12_max = float(m1 + m2), float(m0 - m3) + s12 = np.linspace(s12_min**2, s12_max**2, N) + s23_min, s23_max = kine_min_max(s12, *map(float, [m0, m1, m2, m3])) + return s12, np.stack([s23_min, s23_max], axis=-1) + + +@ConfigLoader.register_function() +def plot_adaptive_2dpull( + config, var1, var2, binning=[[2, 2]] * 3, ax=plt, where={}, cut_zero=True +): + import matplotlib as mpl + import matplotlib.colors as mcolors + import matplotlib.patches as mpathes + + from tf_pwa.adaptive_bins import AdaptiveBound + + def plot_function_2dpull( + data_dict, phsp_dict, bg_dict, plot_var_dic, **kwargs + ): + nonlocal ax + if cut_zero: + cut = data_dict["data_weights"] != 0 + else: + cut = np.ones(data_dict["data_weights"].shape, dtype=np.bool) + (var_x_f, var_y_f), get_var = build_read_var_function( + [var1, var2], where=where + ) + x = var_x_f(**get_var(data_dict, ""))[cut] + y = var_y_f(**get_var(data_dict, ""))[cut] + w = data_dict["data_weights"][cut] + x_phsp = var_x_f(**get_var(phsp_dict, "_MC")) + y_phsp = var_y_f(**get_var(phsp_dict, "_MC")) + w_phsp = phsp_dict["MC_total_fit"] + data_cut = np.array([x, y]) + adapter = AdaptiveBound(data_cut, binning) + phsps = adapter.split_data(np.array([x_phsp, y_phsp, w_phsp])) + datas = adapter.split_data(np.array([x, y, w])) + if bg_dict != {}: + x_bg = var_x_f(**get_var(bg_dict, "_sideband")) + y_bg = var_y_f(**get_var(bg_dict, "_sideband")) + w_bg = bg_dict["sideband_weights"] + bgs = adapter.split_data(np.array([x_bg, y_bg, w_bg])) + bound = adapter.get_bounds() + numbers = [] + pulls = [] + int_norm = 1 + for i, bnd in enumerate(bound): + min_x, min_y = bnd[0] + max_x, max_y = bnd[1] + ndata = np.sum(datas[i][2]) + nmc = np.sum(phsps[i][2]) + if bg_dict != {}: + nmc += np.sum(bgs[i][2]) + numbers.append((ndata, nmc)) + pulls.append((ndata - nmc) / np.sqrt(nmc)) + + max_weight = max(np.max(np.abs(pulls)), 5) + + my_cmap = plt.get_cmap("jet") + if ax == plt: + ax = plt.gca() # fig, ax = plt.subplots() + ax.scatter(x, y, s=1, c="black") + for i, bnd in enumerate(bound): + min_x, min_y = bnd[0] + max_x, max_y = bnd[1] + # print(weights[i]) # max_weight) + rect = mpathes.Rectangle( + (min_x, min_y), + max_x - min_x, + max_y - min_y, + linewidth=1, + facecolor=my_cmap( + pulls[i] / max_weight / 2 + 0.5 + ), # max_weight), + edgecolor="none", # black", + zorder=-1, + ) # cmap(weights[i]/max_weight)) + ax.add_patch(rect) + + normal = mpl.colors.Normalize(vmin=-max_weight, vmax=max_weight) + im = mpl.cm.ScalarMappable(norm=normal, cmap=my_cmap) + # ax.colorbar(im) + ax.get_figure().colorbar(im) + ax.title( + "$\\chi^2/Nbins={:.2f}/{}$".format( + np.sum(np.abs(pulls) ** 2), len(bound) + ) + ) + ax.set_xlim([np.min(x_phsp), np.max(x_phsp)]) + ax.set_ylim([np.min(y_phsp), np.max(y_phsp)]) + ax.set_xlabel(var1) + ax.set_ylabel(var2) + + config.plot_partial_wave( + plot_function=plot_function_2dpull, combine_plot=True + ) + + def hist_error(data, bins=50, xrange=None, weights=1.0, kind="poisson"): if not hasattr(weights, "__len__"): weights = [weights] * data.__len__() diff --git a/tf_pwa/tests/test_full.py b/tf_pwa/tests/test_full.py index df36a65d..b7c6218d 100644 --- a/tf_pwa/tests/test_full.py +++ b/tf_pwa/tests/test_full.py @@ -357,3 +357,12 @@ def test_cp_particles(): config = ConfigLoader(f"{this_dir}/config_self_cp.yml") phsp = config.generate_phsp(100) config.get_amplitude()(phsp) + + +def test_plot_2dpull(toy_config): + import matplotlib.pyplot as plt + + toy_config.plot_adaptive_2dpull("m_R_BC**2", "m_R_CD**2") + a, b = toy_config.get_dalitz_boundary("R_BC", "R_CD") + plt.plot(a, b, color="red") + plt.savefig("adptive_2d.png")