Skip to content

Commit

Permalink
Merge pull request #102 from wwdws1/leg_pos
Browse files Browse the repository at this point in the history
feat: new option for changing the position of legend
  • Loading branch information
jiangyi15 authored Aug 24, 2023
2 parents da19b41 + 2f281ba commit 8818cb7
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 15 deletions.
5 changes: 3 additions & 2 deletions config.sample.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
## line starts with `##` is documnet.
## line starts with `#` is option.


## data entry is the data file name list used for fit
data:
## 4-momentum order of final particles in dat files [option]
Expand Down Expand Up @@ -42,7 +41,6 @@ data:
# data_charge: ["data/data4600_cc.dat"] # charge conjugation condition same as weight
# cp_trans: True # when used charge conjugation as above, this do p -> -p for charge conjugation process.


## `decay` describe the decay structure, each node can be a list of particle name
decay:
## each entry is a list of direct decay particle, or a list or the list.
Expand Down Expand Up @@ -136,6 +134,8 @@ plot:
## plot configuration TODO
config:
bins: 50
legend_outside: True
## Make the legend outside the plot, default is False (inside)
## invariant mass
mass:
R_CD:
Expand All @@ -144,6 +144,7 @@ plot:
display: "$M_{CD}$"
# upper_ylim: 200
# trans: x*x
legend_outside: False
R_BD:
display: "$M_{BD}$"
R_BC:
Expand Down
10 changes: 10 additions & 0 deletions tf_pwa/config_loader/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,10 @@ def get_mass_vars(self):
units = v.get("units", "GeV")
bins = v.get("bins", self.defaults_config.get("bins", 50))
legend = v.get("legend", self.defaults_config.get("legend", True))
legend_outside = v.get(
"legend_outside",
self.defaults_config.get("legend_outside", False),
)
yscale = v.get(
"yscale", self.defaults_config.get("yscale", "linear")
)
Expand All @@ -1146,6 +1150,7 @@ def get_mass_vars(self):
"m",
),
"legend": legend,
"legend_outside": legend_outside,
"range": xrange,
"bins": bins,
"trans": trans,
Expand Down Expand Up @@ -1199,6 +1204,10 @@ def get_angle_vars(self, is_align=False):
legend = v.get(
"legend", self.defaults_config.get("legend", False)
)
legend_outside = v.get(
"legend_outside",
self.defaults_config.get("legend_outside", False),
)
yscale = v.get(
"yscale", self.defaults_config.get("yscale", "linear")
)
Expand All @@ -1225,6 +1234,7 @@ def get_angle_vars(self, is_align=False):
"bins": bins,
"range": xrange,
"legend": legend,
"legend_outside": legend_outside,
"yscale": yscale,
}

Expand Down
43 changes: 30 additions & 13 deletions tf_pwa/config_loader/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,14 @@ def plot_partial_wave(
has_legend = conf.get("legend", False)
xrange = conf.get("range", None)
bins = conf.get("bins", None)
legend_outside = conf.get("legend_outside", False)
units = conf.get("units", "")
yscale = conf.get("yscale", "linear")
plot_var_dic[name] = {
"display": display,
"upper_ylim": upper_ylim,
"legend": has_legend,
"legend_outside": legend_outside,
"idx": idx,
"trans": trans,
"range": xrange,
Expand Down Expand Up @@ -369,7 +371,6 @@ def plot_partial_wave(
**kwargs,
)
else:

for dt, mc, sb, w_bkg, i in zip(
data, phsp, bg, ws_bkg, range(self._Ngroup)
):
Expand Down Expand Up @@ -602,7 +603,6 @@ def _plot_partial_wave(
ref_amp=None,
**kwargs
):

# cmap = plt.get_cmap("jet")
# N = 10
# colors = [cmap(float(i) / (N+1)) for i in range(1, N+1)]
Expand All @@ -622,6 +622,7 @@ def _plot_partial_wave(
display = plot_var_dic[name]["display"]
upper_ylim = plot_var_dic[name]["upper_ylim"]
has_legend = plot_var_dic[name]["legend"]
legend_outside = plot_var_dic[name]["legend_outside"]
bins = plot_var_dic[name]["bins"]
units = plot_var_dic[name]["units"]
xrange = plot_var_dic[name]["range"]
Expand All @@ -636,7 +637,10 @@ def _plot_partial_wave(
)
fig = plt.figure()
if plot_delta or plot_pull:
ax = plt.subplot2grid((4, 1), (0, 0), rowspan=3)
if legend_outside:
ax = plt.subplot2grid((4, 6), (0, 0), rowspan=3, colspan=5)
else:
ax = plt.subplot2grid((4, 1), (0, 0), rowspan=3)
else:
ax = fig.add_subplot(1, 1, 1)

Expand Down Expand Up @@ -728,13 +732,25 @@ def _plot_partial_wave(
ax.set_xlim(xrange)
ax.set_yscale(yscale)
if has_legend:
leg = ax.legend(
legends,
legends_label,
frameon=False,
labelspacing=0.1,
borderpad=0.0,
)
if legend_outside:
leg = ax.legend(
legends,
legends_label,
frameon=False,
fontsize="small",
labelspacing=0.1,
borderpad=0.0,
bbox_to_anchor=(1.02, 0.5),
loc=6,
)
else:
leg = ax.legend(
legends,
legends_label,
frameon=False,
labelspacing=0.1,
borderpad=0.0,
)
if nll is None:
ax.set_title(display, fontsize="xx-large")
else:
Expand All @@ -748,7 +764,10 @@ def _plot_partial_wave(
ax.set_ylabel("Events/{:.3f}{}".format(ywidth, units))
if plot_delta or plot_pull:
plt.setp(ax.get_xticklabels(), visible=False)
ax2 = plt.subplot2grid((4, 1), (3, 0), rowspan=1)
if legend_outside:
ax2 = plt.subplot2grid((4, 6), (3, 0), rowspan=1, colspan=5)
else:
ax2 = plt.subplot2grid((4, 1), (3, 0), rowspan=1)
# y_err = fit_y - data_y
# if plot_pull:
# _epsilon = 1e-10
Expand Down Expand Up @@ -856,7 +875,6 @@ def _2d_plot(
color_first=True,
**kwargs
):

twodplot = self.config["plot"].get("2Dplot", {})
for k, i in twodplot.items():
if "&" not in k:
Expand Down Expand Up @@ -970,7 +988,6 @@ def _2d_plot_v2(
color_first=True,
**kwargs
):

twodplot = self.config["plot"].get("2Dplot", {})
for k, v in twodplot.items():
if "&" in k:
Expand Down
4 changes: 4 additions & 0 deletions tf_pwa/tests/config_toy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ constrains:
decay: null

plot:
config:
legend_outside: True
mass:
R_BC: { display: "$M_{BC}$" }
R_BD: { display: "$M_{BD}$" }
Expand All @@ -40,5 +42,7 @@ plot:
R_BC/B:
cos(beta):
display: "cos $\\theta$"
legend: True
legend_outside: False
alpha:
display: "$\\phi$"

0 comments on commit 8818cb7

Please sign in to comment.