Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Histogram fill #140

Merged
merged 3 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions tf_pwa/amp/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class InterpLinearNpy(InterpolationParticle):
>>> import numpy as np
>>> from tf_pwa.utils import plot_particle_model
>>> a = tempfile.mktemp(".npy")
>>> m = np.linspace(0.2, 0.9)
>>> m = np.linspace(0.2, 0.9, 51)
>>> mi = m[::5]
>>> np.save(a, np.stack([mi, np.cos(mi*5), np.sin(mi*5)], axis=-1))
>>> axs = plot_particle_model("linear_npy", {"file": a})
Expand All @@ -113,11 +113,16 @@ class InterpLinearNpy(InterpolationParticle):
"""

def __init__(self, *args, **kwargs):
self.input_file = kwargs.get("file")
self.data = np.load(self.input_file)
self.data = self.get_data(**kwargs)
points = self.data[:, 0]
kwargs["points"] = points
super().__init__(*args, **kwargs)
kwargs["points"] = points.tolist()
super(InterpLinearNpy, self).__init__(*args, **kwargs)
self.delta = np.concatenate([[1e20], points[1:] - points[:-1], [1e20]])
self.x_left = np.concatenate([[points[-1] - 1], points])

def get_data(self, **kwargs):
self.input_file = kwargs.get("file")
return np.load(self.input_file)

def init_params(self):
pass
Expand All @@ -129,20 +134,19 @@ def get_point_values(self):

def interp(self, m):
x, p_r, p_i = self.get_point_values()
bin_idx = tf.raw_ops.Bucketize(input=m, boundaries=x)
bin_idx = (bin_idx) % (len(self.bound) + 1)
bin_idx = tf.raw_ops.Bucketize(input=m, boundaries=self.points)
ret_r_r = tf.gather(p_r[1:], bin_idx)
ret_i_r = tf.gather(p_i[1:], bin_idx)
ret_r_l = tf.gather(p_r[:-1], bin_idx)
ret_i_l = tf.gather(p_i[:-1], bin_idx)
delta = np.concatenate([[1e20], x[1:] - x[:-1], [1e20]])
x_left = np.concatenate([[x[0] - 1], x])
delta = tf.gather(delta, bin_idx)
x_left = tf.gather(x_left, bin_idx)
delta = tf.gather(self.delta, bin_idx)
x_left = tf.gather(self.x_left, bin_idx)
step = (m - x_left) / delta
a = step * (ret_r_r - ret_r_l) + ret_r_l
b = step * (ret_i_r - ret_i_l) + ret_i_l
return tf.complex(a, b)
ret = tf.complex(a, b)
cut = (bin_idx <= 0) | (bin_idx >= self.delta.shape[-1] - 1)
return tf.where(cut, tf.zeros_like(ret), ret)


@register_particle("linear_txt")
Expand All @@ -158,20 +162,17 @@ class InterpLinearTxt(InterpLinearNpy):
>>> import numpy as np
>>> from tf_pwa.utils import plot_particle_model
>>> a = tempfile.mktemp(".txt")
>>> m = np.linspace(0.2, 0.9)
>>> m = np.linspace(0.2, 0.9, 51)
>>> mi = m[::5]
>>> np.savetxt(a, np.stack([mi, np.cos(mi*5), np.sin(mi*5)], axis=-1))
>>> axs = plot_particle_model("linear_txt", {"file": a})
>>> _ = axs[3].plot(np.cos(m*5), np.sin(m*5), "--")

"""

def __init__(self, *args, **kwargs):
def get_data(self, **kwargs):
self.input_file = kwargs.get("file")
self.data = np.loadtxt(self.input_file)
points = self.data[:, 0]
kwargs["points"] = points
super(InterpLinearNpy, self).__init__(*args, **kwargs)
return np.loadtxt(self.input_file)


@register_particle("interp")
Expand Down
6 changes: 5 additions & 1 deletion tf_pwa/config_loader/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,9 +1014,13 @@ def set_params(self, params, neglect_params=None):
if neglect_params is None:
neglect_params = self._neglect_when_set_params
if len(neglect_params) != 0:
# warnings.warn("Neglect {} when setting params.".format(neglect_params))
for v in params:
if v in self._neglect_when_set_params:
warnings.warn(
"Neglect {} when setting params.".format(
neglect_params
)
)
del ret[v]
amplitude.set_params(ret)
return True
Expand Down
27 changes: 1 addition & 26 deletions tf_pwa/config_loader/multi_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,32 +227,7 @@ def get_params(self, trainable_only=False):

def set_params(self, params, neglect_params=None):
_amps = self.get_amplitudes()
if isinstance(params, str):
if params == "":
return False
try:
with open(params) as f:
params = yaml.safe_load(f)
except Exception as e:
print(e)
return False
if hasattr(params, "params"):
params = params.params
if isinstance(params, dict):
if "value" in params:
params = params["value"]
ret = params.copy()
if neglect_params is None:
neglect_params = self._neglect_when_set_params
if len(neglect_params) != 0:
warnings.warn(
"Neglect {} when setting params.".format(neglect_params)
)
for v in params:
if v in self._neglect_when_set_params:
del ret[v]
self.vm.set_all(ret)
return True
self.configs[0].set_params(params, neglect_params=neglect_params)

@contextlib.contextmanager
def params_trans(self):
Expand Down
39 changes: 29 additions & 10 deletions tf_pwa/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,17 @@ def draw(self, ax=plt, **kwargs):
ret = []
for i in draw_type.split("+"):
ret.append(self.draw(ax=ax, type=i, **kwargs))
elif draw_type == "hist":
ret = self.draw_hist(ax=ax, **kwargs)
elif draw_type == "bar":
ret = self.draw_bar(ax=ax, **kwargs)
elif draw_type == "kde":
ret = self.draw_kde(ax=ax, **kwargs)
elif draw_type == "error":
ret = self.draw_error(ax=ax, **kwargs)
elif draw_type == "line":
ret = self.draw_line(ax=ax, **kwargs)
elif draw_type in [
"hist",
"bar",
"kde",
"error",
"line",
"fill",
"stepfill",
]:
draw_fun = getattr(self, "draw_" + draw_type)
ret = draw_fun(ax=ax, **kwargs)
else:
raise NotImplementedError()
return ret
Expand Down Expand Up @@ -130,6 +131,24 @@ def draw_kde(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs):
else:
return ax.plot(x, kde(x), color=color, **kwargs)

def draw_fill(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs):
color = kwargs.pop("color", self._cached_color)
m = self.bin_center
bw = self.bin_width * bin_scale
kde = weighted_kde(m, self.count, bw, kind)
x = np.linspace(
self.binning[0], self.binning[-1], self.count.shape[0] * 10
)
return ax.fill_between(
x, kde(x), np.zeros_like(x), color=color, **kwargs
)

def draw_stepfill(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs):
color = kwargs.pop("color", self._cached_color)
x = np.repeat(self.binning, 2)
y = np.concatenate([[0], np.repeat(self.count, 2), [0]])
return ax.fill_between(x, y, np.zeros_like(x), color=color, **kwargs)

def draw_pull(self, ax=plt, **kwargs):
with np.errstate(divide="ignore", invalid="ignore"):
y_error = np.where(self.error == 0, 0, self.count / self.error)
Expand Down
2 changes: 2 additions & 0 deletions tf_pwa/tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def test_hist1d():
hist.draw(type="line+bar")
hist.draw_kde(ax, kind="gauss", color="blue")
hist.draw_kde(ax, kind="cauchy", color="red")
(hist * 0.5).draw(ax, type="fill", facecolor="none", hatch="///")
(hist * 0.4).draw_stepfill(ax, facecolor="none", hatch="\\")
(0.1 * hist + hist * 0.1).draw_bar(ax)
hist.draw_error(ax)
hist2 = Hist1D.histogram(
Expand Down
Loading