From c22a20ae6718c4045eaf77bdf100bd7d1260826e Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sat, 17 Feb 2024 21:28:35 +0800 Subject: [PATCH 1/3] feat: histogram with fill_between --- tf_pwa/histogram.py | 14 ++++++++++++++ tf_pwa/tests/test_histogram.py | 1 + 2 files changed, 15 insertions(+) diff --git a/tf_pwa/histogram.py b/tf_pwa/histogram.py index 261ff83..a11cd58 100644 --- a/tf_pwa/histogram.py +++ b/tf_pwa/histogram.py @@ -99,6 +99,8 @@ def draw(self, ax=plt, **kwargs): ret = self.draw_error(ax=ax, **kwargs) elif draw_type == "line": ret = self.draw_line(ax=ax, **kwargs) + elif draw_type == "fill": + ret = self.draw_fill(ax=ax, **kwargs) else: raise NotImplementedError() return ret @@ -130,6 +132,18 @@ def draw_kde(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs): else: return ax.plot(x, kde(x), color=color, **kwargs) + def draw_fill(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs): + color = kwargs.pop("color", self._cached_color) + m = self.bin_center + bw = self.bin_width * bin_scale + kde = weighted_kde(m, self.count, bw, kind) + x = np.linspace( + self.binning[0], self.binning[-1], self.count.shape[0] * 10 + ) + return ax.fill_between( + x, kde(x), np.zeros_like(x), color=color, **kwargs + ) + def draw_pull(self, ax=plt, **kwargs): with np.errstate(divide="ignore", invalid="ignore"): y_error = np.where(self.error == 0, 0, self.count / self.error) diff --git a/tf_pwa/tests/test_histogram.py b/tf_pwa/tests/test_histogram.py index ecf4bae..fdd5cc1 100644 --- a/tf_pwa/tests/test_histogram.py +++ b/tf_pwa/tests/test_histogram.py @@ -15,6 +15,7 @@ def test_hist1d(): hist.draw(type="line+bar") hist.draw_kde(ax, kind="gauss", color="blue") hist.draw_kde(ax, kind="cauchy", color="red") + (hist * 0.5).draw(ax, type="fill", facecolor="none", hatch="///") (0.1 * hist + hist * 0.1).draw_bar(ax) hist.draw_error(ax) hist2 = Hist1D.histogram( From 685ffa60ad600eb665a0abb6e8944d0380f84463 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sat, 17 Feb 2024 23:13:58 +0800 Subject: [PATCH 2/3] refactor: linear_npy and set MultiConfig.set_params same as ConfigLoader --- tf_pwa/amp/interpolation.py | 37 ++++++++++++++------------- tf_pwa/config_loader/config_loader.py | 6 ++++- tf_pwa/config_loader/multi_config.py | 27 +------------------ 3 files changed, 25 insertions(+), 45 deletions(-) diff --git a/tf_pwa/amp/interpolation.py b/tf_pwa/amp/interpolation.py index 77af08b..7687d23 100644 --- a/tf_pwa/amp/interpolation.py +++ b/tf_pwa/amp/interpolation.py @@ -104,7 +104,7 @@ class InterpLinearNpy(InterpolationParticle): >>> import numpy as np >>> from tf_pwa.utils import plot_particle_model >>> a = tempfile.mktemp(".npy") - >>> m = np.linspace(0.2, 0.9) + >>> m = np.linspace(0.2, 0.9, 51) >>> mi = m[::5] >>> np.save(a, np.stack([mi, np.cos(mi*5), np.sin(mi*5)], axis=-1)) >>> axs = plot_particle_model("linear_npy", {"file": a}) @@ -113,11 +113,16 @@ class InterpLinearNpy(InterpolationParticle): """ def __init__(self, *args, **kwargs): - self.input_file = kwargs.get("file") - self.data = np.load(self.input_file) + self.data = self.get_data(**kwargs) points = self.data[:, 0] - kwargs["points"] = points - super().__init__(*args, **kwargs) + kwargs["points"] = points.tolist() + super(InterpLinearNpy, self).__init__(*args, **kwargs) + self.delta = np.concatenate([[1e20], points[1:] - points[:-1], [1e20]]) + self.x_left = np.concatenate([[points[-1] - 1], points]) + + def get_data(self, **kwargs): + self.input_file = kwargs.get("file") + return np.load(self.input_file) def init_params(self): pass @@ -129,20 +134,19 @@ def get_point_values(self): def interp(self, m): x, p_r, p_i = self.get_point_values() - bin_idx = tf.raw_ops.Bucketize(input=m, boundaries=x) - bin_idx = (bin_idx) % (len(self.bound) + 1) + bin_idx = tf.raw_ops.Bucketize(input=m, boundaries=self.points) ret_r_r = tf.gather(p_r[1:], bin_idx) ret_i_r = tf.gather(p_i[1:], bin_idx) ret_r_l = tf.gather(p_r[:-1], bin_idx) ret_i_l = tf.gather(p_i[:-1], bin_idx) - delta = np.concatenate([[1e20], x[1:] - x[:-1], [1e20]]) - x_left = np.concatenate([[x[0] - 1], x]) - delta = tf.gather(delta, bin_idx) - x_left = tf.gather(x_left, bin_idx) + delta = tf.gather(self.delta, bin_idx) + x_left = tf.gather(self.x_left, bin_idx) step = (m - x_left) / delta a = step * (ret_r_r - ret_r_l) + ret_r_l b = step * (ret_i_r - ret_i_l) + ret_i_l - return tf.complex(a, b) + ret = tf.complex(a, b) + cut = (bin_idx <= 0) | (bin_idx >= self.delta.shape[-1] - 1) + return tf.where(cut, tf.zeros_like(ret), ret) @register_particle("linear_txt") @@ -158,7 +162,7 @@ class InterpLinearTxt(InterpLinearNpy): >>> import numpy as np >>> from tf_pwa.utils import plot_particle_model >>> a = tempfile.mktemp(".txt") - >>> m = np.linspace(0.2, 0.9) + >>> m = np.linspace(0.2, 0.9, 51) >>> mi = m[::5] >>> np.savetxt(a, np.stack([mi, np.cos(mi*5), np.sin(mi*5)], axis=-1)) >>> axs = plot_particle_model("linear_txt", {"file": a}) @@ -166,12 +170,9 @@ class InterpLinearTxt(InterpLinearNpy): """ - def __init__(self, *args, **kwargs): + def get_data(self, **kwargs): self.input_file = kwargs.get("file") - self.data = np.loadtxt(self.input_file) - points = self.data[:, 0] - kwargs["points"] = points - super(InterpLinearNpy, self).__init__(*args, **kwargs) + return np.loadtxt(self.input_file) @register_particle("interp") diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index 75d8d06..fdae923 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -1014,9 +1014,13 @@ def set_params(self, params, neglect_params=None): if neglect_params is None: neglect_params = self._neglect_when_set_params if len(neglect_params) != 0: - # warnings.warn("Neglect {} when setting params.".format(neglect_params)) for v in params: if v in self._neglect_when_set_params: + warnings.warn( + "Neglect {} when setting params.".format( + neglect_params + ) + ) del ret[v] amplitude.set_params(ret) return True diff --git a/tf_pwa/config_loader/multi_config.py b/tf_pwa/config_loader/multi_config.py index 7600c90..b5ca1b6 100644 --- a/tf_pwa/config_loader/multi_config.py +++ b/tf_pwa/config_loader/multi_config.py @@ -227,32 +227,7 @@ def get_params(self, trainable_only=False): def set_params(self, params, neglect_params=None): _amps = self.get_amplitudes() - if isinstance(params, str): - if params == "": - return False - try: - with open(params) as f: - params = yaml.safe_load(f) - except Exception as e: - print(e) - return False - if hasattr(params, "params"): - params = params.params - if isinstance(params, dict): - if "value" in params: - params = params["value"] - ret = params.copy() - if neglect_params is None: - neglect_params = self._neglect_when_set_params - if len(neglect_params) != 0: - warnings.warn( - "Neglect {} when setting params.".format(neglect_params) - ) - for v in params: - if v in self._neglect_when_set_params: - del ret[v] - self.vm.set_all(ret) - return True + self.configs[0].set_params(params, neglect_params=neglect_params) @contextlib.contextmanager def params_trans(self): From 56ceea2df9c604c2f12f879313f1771e9fd936b5 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sat, 17 Feb 2024 23:28:02 +0800 Subject: [PATCH 3/3] feat: stepfill for Hist1D --- tf_pwa/histogram.py | 29 +++++++++++++++++------------ tf_pwa/tests/test_histogram.py | 1 + 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/tf_pwa/histogram.py b/tf_pwa/histogram.py index a11cd58..23534e4 100644 --- a/tf_pwa/histogram.py +++ b/tf_pwa/histogram.py @@ -89,18 +89,17 @@ def draw(self, ax=plt, **kwargs): ret = [] for i in draw_type.split("+"): ret.append(self.draw(ax=ax, type=i, **kwargs)) - elif draw_type == "hist": - ret = self.draw_hist(ax=ax, **kwargs) - elif draw_type == "bar": - ret = self.draw_bar(ax=ax, **kwargs) - elif draw_type == "kde": - ret = self.draw_kde(ax=ax, **kwargs) - elif draw_type == "error": - ret = self.draw_error(ax=ax, **kwargs) - elif draw_type == "line": - ret = self.draw_line(ax=ax, **kwargs) - elif draw_type == "fill": - ret = self.draw_fill(ax=ax, **kwargs) + elif draw_type in [ + "hist", + "bar", + "kde", + "error", + "line", + "fill", + "stepfill", + ]: + draw_fun = getattr(self, "draw_" + draw_type) + ret = draw_fun(ax=ax, **kwargs) else: raise NotImplementedError() return ret @@ -144,6 +143,12 @@ def draw_fill(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs): x, kde(x), np.zeros_like(x), color=color, **kwargs ) + def draw_stepfill(self, ax=plt, kind="gauss", bin_scale=1.0, **kwargs): + color = kwargs.pop("color", self._cached_color) + x = np.repeat(self.binning, 2) + y = np.concatenate([[0], np.repeat(self.count, 2), [0]]) + return ax.fill_between(x, y, np.zeros_like(x), color=color, **kwargs) + def draw_pull(self, ax=plt, **kwargs): with np.errstate(divide="ignore", invalid="ignore"): y_error = np.where(self.error == 0, 0, self.count / self.error) diff --git a/tf_pwa/tests/test_histogram.py b/tf_pwa/tests/test_histogram.py index fdd5cc1..fb2ca0c 100644 --- a/tf_pwa/tests/test_histogram.py +++ b/tf_pwa/tests/test_histogram.py @@ -16,6 +16,7 @@ def test_hist1d(): hist.draw_kde(ax, kind="gauss", color="blue") hist.draw_kde(ax, kind="cauchy", color="red") (hist * 0.5).draw(ax, type="fill", facecolor="none", hatch="///") + (hist * 0.4).draw_stepfill(ax, facecolor="none", hatch="\\") (0.1 * hist + hist * 0.1).draw_bar(ax) hist.draw_error(ax) hist2 = Hist1D.histogram(