From 1d4961b0cae71617a7129aa043c4af05cd5f25c7 Mon Sep 17 00:00:00 2001 From: wwdws1 <270410665@qq.com> Date: Wed, 23 Aug 2023 22:33:16 +0800 Subject: [PATCH 1/5] add option for legend position --- config.sample.yml | 4 +- tf_pwa/config_loader/config_loader.py | 10 ++++ tf_pwa/config_loader/plot.py | 71 ++++++++++++++++++++++----- tf_pwa/tests/config_toy.yml | 10 ++-- 4 files changed, 77 insertions(+), 18 deletions(-) diff --git a/config.sample.yml b/config.sample.yml index 24976639..b92e7a17 100644 --- a/config.sample.yml +++ b/config.sample.yml @@ -5,7 +5,6 @@ ## line starts with `##` is documnet. ## line starts with `#` is option. - ## data entry is the data file name list used for fit data: ## 4-momentum order of final particles in dat files [option] @@ -42,7 +41,6 @@ data: # data_charge: ["data/data4600_cc.dat"] # charge conjugation condition same as weight # cp_trans: True # when used charge conjugation as above, this do p -> -p for charge conjugation process. - ## `decay` describe the decay structure, each node can be a list of particle name decay: ## each entry is a list of direct decay particle, or a list or the list. @@ -136,6 +134,8 @@ plot: ## plot configuration TODO config: bins: 50 + legend_position: "default" + ## The position of legend, can be "default"="best", "left"="upper left", "right"="upper right", "middle"="upper center" and "outside", this option can also be set in each plot item. ## invariant mass mass: R_CD: diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index 4700874d..db159c89 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -1133,6 +1133,10 @@ def get_mass_vars(self): units = v.get("units", "GeV") bins = v.get("bins", self.defaults_config.get("bins", 50)) legend = v.get("legend", self.defaults_config.get("legend", True)) + legend_position = v.get( + "legend_position", + self.defaults_config.get("legend_position", "best"), + ) yscale = v.get( "yscale", self.defaults_config.get("yscale", "linear") ) @@ -1146,6 +1150,7 @@ def get_mass_vars(self): "m", ), "legend": legend, + "legend_position": legend_position, "range": xrange, "bins": bins, "trans": trans, @@ -1199,6 +1204,10 @@ def get_angle_vars(self, is_align=False): legend = v.get( "legend", self.defaults_config.get("legend", False) ) + legend_position = v.get( + "legend_position", + self.defaults_config.get("legend_position", "best"), + ) yscale = v.get( "yscale", self.defaults_config.get("yscale", "linear") ) @@ -1225,6 +1234,7 @@ def get_angle_vars(self, is_align=False): "bins": bins, "range": xrange, "legend": legend, + "legend_position": legend_position, "yscale": yscale, } diff --git a/tf_pwa/config_loader/plot.py b/tf_pwa/config_loader/plot.py index 8ae49748..fab29dac 100644 --- a/tf_pwa/config_loader/plot.py +++ b/tf_pwa/config_loader/plot.py @@ -297,12 +297,14 @@ def plot_partial_wave( has_legend = conf.get("legend", False) xrange = conf.get("range", None) bins = conf.get("bins", None) + legend_position = conf.get("legend_position", "left") units = conf.get("units", "") yscale = conf.get("yscale", "linear") plot_var_dic[name] = { "display": display, "upper_ylim": upper_ylim, "legend": has_legend, + "legend_position": legend_position, "idx": idx, "trans": trans, "range": xrange, @@ -369,7 +371,6 @@ def plot_partial_wave( **kwargs, ) else: - for dt, mc, sb, w_bkg, i in zip( data, phsp, bg, ws_bkg, range(self._Ngroup) ): @@ -602,7 +603,6 @@ def _plot_partial_wave( ref_amp=None, **kwargs ): - # cmap = plt.get_cmap("jet") # N = 10 # colors = [cmap(float(i) / (N+1)) for i in range(1, N+1)] @@ -622,6 +622,7 @@ def _plot_partial_wave( display = plot_var_dic[name]["display"] upper_ylim = plot_var_dic[name]["upper_ylim"] has_legend = plot_var_dic[name]["legend"] + legend_position = plot_var_dic[name]["legend_position"] bins = plot_var_dic[name]["bins"] units = plot_var_dic[name]["units"] xrange = plot_var_dic[name]["range"] @@ -636,7 +637,10 @@ def _plot_partial_wave( ) fig = plt.figure() if plot_delta or plot_pull: - ax = plt.subplot2grid((4, 1), (0, 0), rowspan=3) + if legend_position == "outside": + ax = plt.subplot2grid((4, 6), (0, 0), rowspan=3, colspan=5) + else: + ax = plt.subplot2grid((4, 1), (0, 0), rowspan=3) else: ax = fig.add_subplot(1, 1, 1) @@ -728,13 +732,53 @@ def _plot_partial_wave( ax.set_xlim(xrange) ax.set_yscale(yscale) if has_legend: - leg = ax.legend( - legends, - legends_label, - frameon=False, - labelspacing=0.1, - borderpad=0.0, - ) + if legend_position == "best" or "default": + leg = ax.legend( + legends, + legends_label, + frameon=False, + labelspacing=0.1, + borderpad=0.0, + loc="best", + ) + elif legend_position == "left" or "upper left": + leg = ax.legend( + legends, + legends_label, + frameon=False, + labelspacing=0.1, + borderpad=0.0, + loc="upper left", + ) + elif legend_position == "middle" or "upper center": + leg = ax.legend( + legends, + legends_label, + frameon=False, + labelspacing=0.1, + borderpad=0.0, + loc="upper center", + ) + elif legend_position == "right" or "upper right": + leg = ax.legend( + legends, + legends_label, + frameon=False, + labelspacing=0.1, + borderpad=0.0, + loc="upper right", + ) + elif legend_position == "outside": + leg = ax.legend( + legends, + legends_label, + frameon=False, + fontsize="small", + labelspacing=0.1, + borderpad=0.0, + bbox_to_anchor=(1.02, 0.5), + loc=6, + ) if nll is None: ax.set_title(display, fontsize="xx-large") else: @@ -748,7 +792,10 @@ def _plot_partial_wave( ax.set_ylabel("Events/{:.3f}{}".format(ywidth, units)) if plot_delta or plot_pull: plt.setp(ax.get_xticklabels(), visible=False) - ax2 = plt.subplot2grid((4, 1), (3, 0), rowspan=1) + if legend_position == "outside": + ax2 = plt.subplot2grid((4, 6), (3, 0), rowspan=1, colspan=5) + else: + ax2 = plt.subplot2grid((4, 1), (3, 0), rowspan=1) # y_err = fit_y - data_y # if plot_pull: # _epsilon = 1e-10 @@ -856,7 +903,6 @@ def _2d_plot( color_first=True, **kwargs ): - twodplot = self.config["plot"].get("2Dplot", {}) for k, i in twodplot.items(): if "&" not in k: @@ -970,7 +1016,6 @@ def _2d_plot_v2( color_first=True, **kwargs ): - twodplot = self.config["plot"].get("2Dplot", {}) for k, v in twodplot.items(): if "&" in k: diff --git a/tf_pwa/tests/config_toy.yml b/tf_pwa/tests/config_toy.yml index 7d253ed3..fbeeff2e 100644 --- a/tf_pwa/tests/config_toy.yml +++ b/tf_pwa/tests/config_toy.yml @@ -32,13 +32,17 @@ constrains: decay: null plot: + config: + legend_position: "outside" mass: - R_BC: { display: "$M_{BC}$" } - R_BD: { display: "$M_{BD}$" } - R_CD: { display: "$M_{CD}$" } + R_BC: { display: "$M_{BC}$", legend_position: "left" } + R_BD: { display: "$M_{BD}$", legend_position: "middle" } + R_CD: { display: "$M_{CD}$", legend_position: "right" } angle: R_BC/B: cos(beta): display: "cos $\\theta$" + legend: True + legend_position: "default" alpha: display: "$\\phi$" From fe44e325272c25c1ec8404a683c16e03caea25e6 Mon Sep 17 00:00:00 2001 From: wwdws1 <270410665@qq.com> Date: Wed, 23 Aug 2023 22:36:44 +0800 Subject: [PATCH 2/5] fix: default legend position from left to best --- tf_pwa/config_loader/plot.py | 80 +++++++++--------------------------- 1 file changed, 20 insertions(+), 60 deletions(-) diff --git a/tf_pwa/config_loader/plot.py b/tf_pwa/config_loader/plot.py index fab29dac..7522db55 100644 --- a/tf_pwa/config_loader/plot.py +++ b/tf_pwa/config_loader/plot.py @@ -47,9 +47,7 @@ def default_color_generator(color_first): if color_first: style = itertools.product(marker, linestyles, colors) else: - style = _reverse( - itertools.product(marker, colors, linestyles), (0, 2, 1) - ) + style = _reverse(itertools.product(marker, colors, linestyles), (0, 2, 1)) return style @@ -111,10 +109,7 @@ def _get_cfit_bg(self, data, phsp, batch=65000): nbg = ndata * w w_bg = batch_call_numpy(bg_f, phsp_i, batch) * phsp_i.get_weight() phsp_weight.append(-w_bg / np.sum(w_bg) * nbg) - ret = [ - data_replace(phsp_i, "weight", w) - for phsp_i, w in zip(phsp, phsp_weight) - ] + ret = [data_replace(phsp_i, "weight", w) for phsp_i, w in zip(phsp, phsp_weight)] return ret @@ -126,10 +121,7 @@ def _get_cfit_eff_phsp(self, phsp, batch=65000): w_eff = batch_call_numpy(eff_f, phsp_i, batch) * phsp_i.get_weight() phsp_weight.append(w_eff) - ret = [ - data_replace(phsp_i, "weight", w) - for phsp_i, w in zip(phsp, phsp_weight) - ] + ret = [data_replace(phsp_i, "weight", w) for phsp_i, w in zip(phsp, phsp_weight)] return ret @@ -280,9 +272,7 @@ def plot_partial_wave( phsp_rec = _get_cfit_eff_phsp(self, phsp_rec, batch) amp = self.get_amplitude() self._Ngroup = len(data) - ws_bkg = [ - None if bg_i is None else bg_i.get("weight", None) for bg_i in bg - ] + ws_bkg = [None if bg_i is None else bg_i.get("weight", None) for bg_i in bg] # ws_bkg, ws_inmc = self._get_bg_weight(data, bg) if chains_id_method is not None: self.chains_id_method = chains_id_method @@ -297,7 +287,7 @@ def plot_partial_wave( has_legend = conf.get("legend", False) xrange = conf.get("range", None) bins = conf.get("bins", None) - legend_position = conf.get("legend_position", "left") + legend_position = conf.get("legend_position", "best") units = conf.get("units", "") yscale = conf.get("yscale", "linear") plot_var_dic[name] = { @@ -509,9 +499,7 @@ def _cal_partial_wave( phsp_dict["MC_total_fit"] = cut_phsp * phsp_weights # MC total weight if ref_amp is not None: - phsp_dict["MC_total_fit_ref"] = ( - cut_phsp * total_weight_ref * norm_frac_ref - ) + phsp_dict["MC_total_fit_ref"] = cut_phsp * total_weight_ref * norm_frac_ref if bg is not None: bg_weight = -w_bkg bg_dict["sideband_weights"] = ( @@ -532,9 +520,7 @@ def _cal_partial_wave( idx = plot_var_dic[name]["idx"] trans = lambda x: np.reshape(plot_var_dic[name]["trans"](x), (-1,)) - data_i = batch_call_numpy( - lambda x: trans(data_index(x, idx)), data, batch - ) + data_i = batch_call_numpy(lambda x: trans(data_index(x, idx)), data, batch) if idx[-1] == "m": tmp_idx = list(idx) tmp_idx[-1] = "p" @@ -557,9 +543,7 @@ def _cal_partial_wave( phsp_dict[name + "_MC"] = phsp_i # MC if bg is not None: - bg_i = batch_call_numpy( - lambda x: trans(data_index(x, idx)), bg, batch - ) + bg_i = batch_call_numpy(lambda x: trans(data_index(x, idx)), bg, batch) bg_dict[name + "_sideband"] = bg_i # sideband data_dict = data_to_numpy(data_dict) phsp_dict = data_to_numpy(phsp_dict) @@ -647,9 +631,7 @@ def _plot_partial_wave( legends = [] legends_label = [] - le = data_hist.draw_error( - ax, fmt=".", zorder=-2, label="data", color="black" - ) + le = data_hist.draw_error(ax, fmt=".", zorder=-2, label="data", color="black") legends.append(le) legends_label.append("data") @@ -666,12 +648,8 @@ def _plot_partial_wave( ) if bg_dict: - bg_hist = Hist1D.histogram( - bg_i, weights=bg_weight, range=xrange, bins=bins - ) - le = bg_hist.draw_bar( - ax, label="back ground", alpha=0.5, color="grey" - ) + bg_hist = Hist1D.histogram(bg_i, weights=bg_weight, range=xrange, bins=bins) + le = bg_hist.draw_bar(ax, label="back ground", alpha=0.5, color="grey") fitted_hist = fitted_hist + bg_hist if ref_amp is not None: fitted_hist_ref = fitted_hist_ref + bg_hist @@ -704,9 +682,7 @@ def _plot_partial_wave( # marker, ls, color = line["marker"], line["linestyle"], line["color"] le3 = hist_i.draw_kde(ax, **kwargs) else: - le3 = hist_i.draw_kde( - ax, fmt=curve_style, label=label, linewidth=1 - ) + le3 = hist_i.draw_kde(ax, fmt=curve_style, label=label, linewidth=1) else: if curve_style is None: line = style.get_style(name_i) @@ -782,13 +758,9 @@ def _plot_partial_wave( if nll is None: ax.set_title(display, fontsize="xx-large") else: - ax.set_title( - "{}: -lnL= {:.2f}".format(display, nll), fontsize="xx-large" - ) + ax.set_title("{}: -lnL= {:.2f}".format(display, nll), fontsize="xx-large") ax.set_xlabel(display + units) - ywidth = np.mean( - data_hist.bin_width - ) # (max(data_x) - min(data_x)) / bins + ywidth = np.mean(data_hist.bin_width) # (max(data_x) - min(data_x)) / bins ax.set_ylabel("Events/{:.3f}{}".format(ywidth, units)) if plot_delta or plot_pull: plt.setp(ax.get_xticklabels(), visible=False) @@ -956,9 +928,7 @@ def plot_axis(): if bg_dict: bg_1 = bg_dict[var1 + "_sideband"] bg_2 = bg_dict[var2 + "_sideband"] - plt.scatter( - bg_1, bg_2, s=1, c="g", alpha=0.8, label="sideband" - ) + plt.scatter(bg_1, bg_2, s=1, c="g", alpha=0.8, label="sideband") plot_axis() plt.legend() plt.savefig(prefix + k + "_bkg") @@ -969,9 +939,7 @@ def plot_axis(): # fit pdf if "fitted" in plot_figs: phsp_weights = phsp_dict["MC_total_fit"] - plt.hist2d( - phsp_1, phsp_2, bins=100, weights=phsp_weights, cmin=1e-12 - ) + plt.hist2d(phsp_1, phsp_2, bins=100, weights=phsp_weights, cmin=1e-12) plot_axis() plt.colorbar() plt.savefig(prefix + k + "_fitted") @@ -1103,9 +1071,7 @@ def plot_axis(): if bg_dict: bg_1 = var_x_f(**get_var(bg_dict, "_sideband")) bg_2 = var_y_f(**get_var(bg_dict, "_sideband")) - plt.scatter( - bg_1, bg_2, s=1, c="g", alpha=0.8, label="sideband" - ) + plt.scatter(bg_1, bg_2, s=1, c="g", alpha=0.8, label="sideband") plot_axis() plt.savefig(prefix + k + "_bkg") plt.clf() @@ -1170,9 +1136,7 @@ def hist_error(data, bins=50, xrange=None, weights=1.0, kind="poisson"): return data_x, data_y, data_err -def hist_line( - data, weights, bins, xrange=None, inter=1, kind="UnivariateSpline" -): +def hist_line(data, weights, bins, xrange=None, inter=1, kind="UnivariateSpline"): """interpolate data from hostgram into a line >>> import numpy as np @@ -1187,9 +1151,7 @@ def hist_line( return interp_hist(x, y, num=num, kind=kind) -def hist_line_step( - data, weights, bins, xrange=None, inter=1, kind="quadratic" -): +def hist_line_step(data, weights, bins, xrange=None, inter=1, kind="quadratic"): """ >>> import numpy as np @@ -1219,9 +1181,7 @@ def export_legend(ax, filename="legend.pdf", ncol=1): ) fig = legend.figure fig.canvas.draw() - bbox = legend.get_window_extent().transformed( - fig.dpi_scale_trans.inverted() - ) + bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) fig.savefig(filename, dpi="figure", bbox_inches=bbox) plt.close(fig2) plt.close(fig) From dcd280535f4e8018d8456ea59b58eaedabe0d77f Mon Sep 17 00:00:00 2001 From: wwdws1 <270410665@qq.com> Date: Wed, 23 Aug 2023 23:05:29 +0800 Subject: [PATCH 3/5] format code --- config.sample.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/config.sample.yml b/config.sample.yml index b92e7a17..76fea1c6 100644 --- a/config.sample.yml +++ b/config.sample.yml @@ -135,7 +135,8 @@ plot: config: bins: 50 legend_position: "default" - ## The position of legend, can be "default"="best", "left"="upper left", "right"="upper right", "middle"="upper center" and "outside", this option can also be set in each plot item. + ## The position of legend, can be "default"="best", "left"="upper left", "right"="upper right", "middle"="upper center" and "outside". + ## This option can also be set in each plot item. ## invariant mass mass: R_CD: From 4d98325aab5037cf8a4cc56476b51287fb2b1679 Mon Sep 17 00:00:00 2001 From: wwdws1 <270410665@qq.com> Date: Wed, 23 Aug 2023 23:38:34 +0800 Subject: [PATCH 4/5] format code 2 --- tf_pwa/config_loader/plot.py | 78 +++++++++++++++++++++++++++--------- 1 file changed, 59 insertions(+), 19 deletions(-) diff --git a/tf_pwa/config_loader/plot.py b/tf_pwa/config_loader/plot.py index 7522db55..b5ef14bd 100644 --- a/tf_pwa/config_loader/plot.py +++ b/tf_pwa/config_loader/plot.py @@ -47,7 +47,9 @@ def default_color_generator(color_first): if color_first: style = itertools.product(marker, linestyles, colors) else: - style = _reverse(itertools.product(marker, colors, linestyles), (0, 2, 1)) + style = _reverse( + itertools.product(marker, colors, linestyles), (0, 2, 1) + ) return style @@ -109,7 +111,10 @@ def _get_cfit_bg(self, data, phsp, batch=65000): nbg = ndata * w w_bg = batch_call_numpy(bg_f, phsp_i, batch) * phsp_i.get_weight() phsp_weight.append(-w_bg / np.sum(w_bg) * nbg) - ret = [data_replace(phsp_i, "weight", w) for phsp_i, w in zip(phsp, phsp_weight)] + ret = [ + data_replace(phsp_i, "weight", w) + for phsp_i, w in zip(phsp, phsp_weight) + ] return ret @@ -121,7 +126,10 @@ def _get_cfit_eff_phsp(self, phsp, batch=65000): w_eff = batch_call_numpy(eff_f, phsp_i, batch) * phsp_i.get_weight() phsp_weight.append(w_eff) - ret = [data_replace(phsp_i, "weight", w) for phsp_i, w in zip(phsp, phsp_weight)] + ret = [ + data_replace(phsp_i, "weight", w) + for phsp_i, w in zip(phsp, phsp_weight) + ] return ret @@ -272,7 +280,9 @@ def plot_partial_wave( phsp_rec = _get_cfit_eff_phsp(self, phsp_rec, batch) amp = self.get_amplitude() self._Ngroup = len(data) - ws_bkg = [None if bg_i is None else bg_i.get("weight", None) for bg_i in bg] + ws_bkg = [ + None if bg_i is None else bg_i.get("weight", None) for bg_i in bg + ] # ws_bkg, ws_inmc = self._get_bg_weight(data, bg) if chains_id_method is not None: self.chains_id_method = chains_id_method @@ -499,7 +509,9 @@ def _cal_partial_wave( phsp_dict["MC_total_fit"] = cut_phsp * phsp_weights # MC total weight if ref_amp is not None: - phsp_dict["MC_total_fit_ref"] = cut_phsp * total_weight_ref * norm_frac_ref + phsp_dict["MC_total_fit_ref"] = ( + cut_phsp * total_weight_ref * norm_frac_ref + ) if bg is not None: bg_weight = -w_bkg bg_dict["sideband_weights"] = ( @@ -520,7 +532,9 @@ def _cal_partial_wave( idx = plot_var_dic[name]["idx"] trans = lambda x: np.reshape(plot_var_dic[name]["trans"](x), (-1,)) - data_i = batch_call_numpy(lambda x: trans(data_index(x, idx)), data, batch) + data_i = batch_call_numpy( + lambda x: trans(data_index(x, idx)), data, batch + ) if idx[-1] == "m": tmp_idx = list(idx) tmp_idx[-1] = "p" @@ -543,7 +557,9 @@ def _cal_partial_wave( phsp_dict[name + "_MC"] = phsp_i # MC if bg is not None: - bg_i = batch_call_numpy(lambda x: trans(data_index(x, idx)), bg, batch) + bg_i = batch_call_numpy( + lambda x: trans(data_index(x, idx)), bg, batch + ) bg_dict[name + "_sideband"] = bg_i # sideband data_dict = data_to_numpy(data_dict) phsp_dict = data_to_numpy(phsp_dict) @@ -631,7 +647,9 @@ def _plot_partial_wave( legends = [] legends_label = [] - le = data_hist.draw_error(ax, fmt=".", zorder=-2, label="data", color="black") + le = data_hist.draw_error( + ax, fmt=".", zorder=-2, label="data", color="black" + ) legends.append(le) legends_label.append("data") @@ -648,8 +666,12 @@ def _plot_partial_wave( ) if bg_dict: - bg_hist = Hist1D.histogram(bg_i, weights=bg_weight, range=xrange, bins=bins) - le = bg_hist.draw_bar(ax, label="back ground", alpha=0.5, color="grey") + bg_hist = Hist1D.histogram( + bg_i, weights=bg_weight, range=xrange, bins=bins + ) + le = bg_hist.draw_bar( + ax, label="back ground", alpha=0.5, color="grey" + ) fitted_hist = fitted_hist + bg_hist if ref_amp is not None: fitted_hist_ref = fitted_hist_ref + bg_hist @@ -682,7 +704,9 @@ def _plot_partial_wave( # marker, ls, color = line["marker"], line["linestyle"], line["color"] le3 = hist_i.draw_kde(ax, **kwargs) else: - le3 = hist_i.draw_kde(ax, fmt=curve_style, label=label, linewidth=1) + le3 = hist_i.draw_kde( + ax, fmt=curve_style, label=label, linewidth=1 + ) else: if curve_style is None: line = style.get_style(name_i) @@ -758,9 +782,13 @@ def _plot_partial_wave( if nll is None: ax.set_title(display, fontsize="xx-large") else: - ax.set_title("{}: -lnL= {:.2f}".format(display, nll), fontsize="xx-large") + ax.set_title( + "{}: -lnL= {:.2f}".format(display, nll), fontsize="xx-large" + ) ax.set_xlabel(display + units) - ywidth = np.mean(data_hist.bin_width) # (max(data_x) - min(data_x)) / bins + ywidth = np.mean( + data_hist.bin_width + ) # (max(data_x) - min(data_x)) / bins ax.set_ylabel("Events/{:.3f}{}".format(ywidth, units)) if plot_delta or plot_pull: plt.setp(ax.get_xticklabels(), visible=False) @@ -928,7 +956,9 @@ def plot_axis(): if bg_dict: bg_1 = bg_dict[var1 + "_sideband"] bg_2 = bg_dict[var2 + "_sideband"] - plt.scatter(bg_1, bg_2, s=1, c="g", alpha=0.8, label="sideband") + plt.scatter( + bg_1, bg_2, s=1, c="g", alpha=0.8, label="sideband" + ) plot_axis() plt.legend() plt.savefig(prefix + k + "_bkg") @@ -939,7 +969,9 @@ def plot_axis(): # fit pdf if "fitted" in plot_figs: phsp_weights = phsp_dict["MC_total_fit"] - plt.hist2d(phsp_1, phsp_2, bins=100, weights=phsp_weights, cmin=1e-12) + plt.hist2d( + phsp_1, phsp_2, bins=100, weights=phsp_weights, cmin=1e-12 + ) plot_axis() plt.colorbar() plt.savefig(prefix + k + "_fitted") @@ -1071,7 +1103,9 @@ def plot_axis(): if bg_dict: bg_1 = var_x_f(**get_var(bg_dict, "_sideband")) bg_2 = var_y_f(**get_var(bg_dict, "_sideband")) - plt.scatter(bg_1, bg_2, s=1, c="g", alpha=0.8, label="sideband") + plt.scatter( + bg_1, bg_2, s=1, c="g", alpha=0.8, label="sideband" + ) plot_axis() plt.savefig(prefix + k + "_bkg") plt.clf() @@ -1136,7 +1170,9 @@ def hist_error(data, bins=50, xrange=None, weights=1.0, kind="poisson"): return data_x, data_y, data_err -def hist_line(data, weights, bins, xrange=None, inter=1, kind="UnivariateSpline"): +def hist_line( + data, weights, bins, xrange=None, inter=1, kind="UnivariateSpline" +): """interpolate data from hostgram into a line >>> import numpy as np @@ -1151,7 +1187,9 @@ def hist_line(data, weights, bins, xrange=None, inter=1, kind="UnivariateSpline" return interp_hist(x, y, num=num, kind=kind) -def hist_line_step(data, weights, bins, xrange=None, inter=1, kind="quadratic"): +def hist_line_step( + data, weights, bins, xrange=None, inter=1, kind="quadratic" +): """ >>> import numpy as np @@ -1181,7 +1219,9 @@ def export_legend(ax, filename="legend.pdf", ncol=1): ) fig = legend.figure fig.canvas.draw() - bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) + bbox = legend.get_window_extent().transformed( + fig.dpi_scale_trans.inverted() + ) fig.savefig(filename, dpi="figure", bbox_inches=bbox) plt.close(fig2) plt.close(fig) From 2f281ba2aa8dc78bf04b29fa52697843c1919e49 Mon Sep 17 00:00:00 2001 From: wwdws1 <270410665@qq.com> Date: Thu, 24 Aug 2023 04:52:44 +0800 Subject: [PATCH 5/5] Simplify the 'legend location' options --- config.sample.yml | 6 ++-- tf_pwa/config_loader/config_loader.py | 16 ++++----- tf_pwa/config_loader/plot.py | 48 ++++++--------------------- tf_pwa/tests/config_toy.yml | 10 +++--- 4 files changed, 26 insertions(+), 54 deletions(-) diff --git a/config.sample.yml b/config.sample.yml index 76fea1c6..7d889a76 100644 --- a/config.sample.yml +++ b/config.sample.yml @@ -134,9 +134,8 @@ plot: ## plot configuration TODO config: bins: 50 - legend_position: "default" - ## The position of legend, can be "default"="best", "left"="upper left", "right"="upper right", "middle"="upper center" and "outside". - ## This option can also be set in each plot item. + legend_outside: True + ## Make the legend outside the plot, default is False (inside) ## invariant mass mass: R_CD: @@ -145,6 +144,7 @@ plot: display: "$M_{CD}$" # upper_ylim: 200 # trans: x*x + legend_outside: False R_BD: display: "$M_{BD}$" R_BC: diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index db159c89..4ba970e4 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -1133,9 +1133,9 @@ def get_mass_vars(self): units = v.get("units", "GeV") bins = v.get("bins", self.defaults_config.get("bins", 50)) legend = v.get("legend", self.defaults_config.get("legend", True)) - legend_position = v.get( - "legend_position", - self.defaults_config.get("legend_position", "best"), + legend_outside = v.get( + "legend_outside", + self.defaults_config.get("legend_outside", False), ) yscale = v.get( "yscale", self.defaults_config.get("yscale", "linear") @@ -1150,7 +1150,7 @@ def get_mass_vars(self): "m", ), "legend": legend, - "legend_position": legend_position, + "legend_outside": legend_outside, "range": xrange, "bins": bins, "trans": trans, @@ -1204,9 +1204,9 @@ def get_angle_vars(self, is_align=False): legend = v.get( "legend", self.defaults_config.get("legend", False) ) - legend_position = v.get( - "legend_position", - self.defaults_config.get("legend_position", "best"), + legend_outside = v.get( + "legend_outside", + self.defaults_config.get("legend_outside", False), ) yscale = v.get( "yscale", self.defaults_config.get("yscale", "linear") @@ -1234,7 +1234,7 @@ def get_angle_vars(self, is_align=False): "bins": bins, "range": xrange, "legend": legend, - "legend_position": legend_position, + "legend_outside": legend_outside, "yscale": yscale, } diff --git a/tf_pwa/config_loader/plot.py b/tf_pwa/config_loader/plot.py index b5ef14bd..9590cf3c 100644 --- a/tf_pwa/config_loader/plot.py +++ b/tf_pwa/config_loader/plot.py @@ -297,14 +297,14 @@ def plot_partial_wave( has_legend = conf.get("legend", False) xrange = conf.get("range", None) bins = conf.get("bins", None) - legend_position = conf.get("legend_position", "best") + legend_outside = conf.get("legend_outside", False) units = conf.get("units", "") yscale = conf.get("yscale", "linear") plot_var_dic[name] = { "display": display, "upper_ylim": upper_ylim, "legend": has_legend, - "legend_position": legend_position, + "legend_outside": legend_outside, "idx": idx, "trans": trans, "range": xrange, @@ -622,7 +622,7 @@ def _plot_partial_wave( display = plot_var_dic[name]["display"] upper_ylim = plot_var_dic[name]["upper_ylim"] has_legend = plot_var_dic[name]["legend"] - legend_position = plot_var_dic[name]["legend_position"] + legend_outside = plot_var_dic[name]["legend_outside"] bins = plot_var_dic[name]["bins"] units = plot_var_dic[name]["units"] xrange = plot_var_dic[name]["range"] @@ -637,7 +637,7 @@ def _plot_partial_wave( ) fig = plt.figure() if plot_delta or plot_pull: - if legend_position == "outside": + if legend_outside: ax = plt.subplot2grid((4, 6), (0, 0), rowspan=3, colspan=5) else: ax = plt.subplot2grid((4, 1), (0, 0), rowspan=3) @@ -732,52 +732,24 @@ def _plot_partial_wave( ax.set_xlim(xrange) ax.set_yscale(yscale) if has_legend: - if legend_position == "best" or "default": - leg = ax.legend( - legends, - legends_label, - frameon=False, - labelspacing=0.1, - borderpad=0.0, - loc="best", - ) - elif legend_position == "left" or "upper left": - leg = ax.legend( - legends, - legends_label, - frameon=False, - labelspacing=0.1, - borderpad=0.0, - loc="upper left", - ) - elif legend_position == "middle" or "upper center": - leg = ax.legend( - legends, - legends_label, - frameon=False, - labelspacing=0.1, - borderpad=0.0, - loc="upper center", - ) - elif legend_position == "right" or "upper right": + if legend_outside: leg = ax.legend( legends, legends_label, frameon=False, + fontsize="small", labelspacing=0.1, borderpad=0.0, - loc="upper right", + bbox_to_anchor=(1.02, 0.5), + loc=6, ) - elif legend_position == "outside": + else: leg = ax.legend( legends, legends_label, frameon=False, - fontsize="small", labelspacing=0.1, borderpad=0.0, - bbox_to_anchor=(1.02, 0.5), - loc=6, ) if nll is None: ax.set_title(display, fontsize="xx-large") @@ -792,7 +764,7 @@ def _plot_partial_wave( ax.set_ylabel("Events/{:.3f}{}".format(ywidth, units)) if plot_delta or plot_pull: plt.setp(ax.get_xticklabels(), visible=False) - if legend_position == "outside": + if legend_outside: ax2 = plt.subplot2grid((4, 6), (3, 0), rowspan=1, colspan=5) else: ax2 = plt.subplot2grid((4, 1), (3, 0), rowspan=1) diff --git a/tf_pwa/tests/config_toy.yml b/tf_pwa/tests/config_toy.yml index fbeeff2e..e1e55160 100644 --- a/tf_pwa/tests/config_toy.yml +++ b/tf_pwa/tests/config_toy.yml @@ -33,16 +33,16 @@ constrains: plot: config: - legend_position: "outside" + legend_outside: True mass: - R_BC: { display: "$M_{BC}$", legend_position: "left" } - R_BD: { display: "$M_{BD}$", legend_position: "middle" } - R_CD: { display: "$M_{CD}$", legend_position: "right" } + R_BC: { display: "$M_{BC}$" } + R_BD: { display: "$M_{BD}$" } + R_CD: { display: "$M_{CD}$" } angle: R_BC/B: cos(beta): display: "cos $\\theta$" legend: True - legend_position: "default" + legend_outside: False alpha: display: "$\\phi$"