From de7c2c9eb3d4bf041e0684835eae83d59fdb3e9f Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Thu, 14 Sep 2023 23:21:28 +0800 Subject: [PATCH 1/8] feat: pre_trans for variable --- tf_pwa/config_loader/config_loader.py | 11 +++++++++++ tf_pwa/tests/config_toy2.yml | 5 +++++ tf_pwa/variable.py | 16 +++++++++++++--- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index 720ace0f..67506db2 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -264,6 +264,7 @@ def add_constraints(self, amp): self.add_free_var_constraints(amp, constrains.get("free_var", [])) self.add_var_range_constraints(amp, constrains.get("var_range", {})) self.add_var_equal_constraints(amp, constrains.get("var_equal", [])) + self.add_pre_trans_constraints(amp, constrains.get("pre_trans", None)) self.add_gauss_constr_constraints( amp, constrains.get("gauss_constr", {}) ) @@ -312,6 +313,16 @@ def add_var_equal_constraints(self, amp, dic=None): print("same value:", k) amp.vm.set_same(k) + def add_pre_trans_constraints(self, amp, dic=None): + if dic is None: + return + from tf_pwa.transform import create_trans + + for k, v in dic.items(): + print("transform:", k, v) + trans = create_trans(v) + amp.vm.pre_trans[k] = trans + def add_decay_constraints(self, amp, dic=None): if dic is None: dic = {} diff --git a/tf_pwa/tests/config_toy2.yml b/tf_pwa/tests/config_toy2.yml index 9be6ab72..24842adc 100644 --- a/tf_pwa/tests/config_toy2.yml +++ b/tf_pwa/tests/config_toy2.yml @@ -34,6 +34,11 @@ particle: constrains: particle: null decay: null + pre_trans: + "R_BC_mass": + model: linear + k: 1.0 + b: 0.0 plot: config: diff --git a/tf_pwa/variable.py b/tf_pwa/variable.py index 498afa7a..54524bd2 100644 --- a/tf_pwa/variable.py +++ b/tf_pwa/variable.py @@ -99,6 +99,7 @@ def __init__(self, name="", dtype=tf.float64): self.complex_vars = {} # {name:polar(bool),...} self.same_list = [] # [[name1,name2],...] 
self.mask_vars = {} + self.pre_trans = {} self.bnd_dic = {} # {name:(a,b),...} @@ -534,15 +535,21 @@ def get(self, name, val_in_fit=True): if name not in self.variables: raise Exception("{} not found".format(name)) if not val_in_fit or name not in self.bnd_dic: - return self.variables[name].numpy() # tf.Variable + value = self.variables[name] + if name in self.pre_trans: + value = self.pre_trans[name](value) + return value.numpy() # tf.Variable else: return self.bnd_dic[name].get_y2x(self.variables[name].numpy()) def read(self, name): val = self.variables[name] if name in self.mask_vars: - return tf.stop_gradient(tf.cast(self.mask_vars[name], val.dtype)) - return self.variables[name] + val = tf.stop_gradient(tf.cast(self.mask_vars[name], val.dtype)) + if name in self.pre_trans: + trans = self.pre_trans[name] + val = trans(val) + return val def set(self, name, value, val_in_fit=True): """ @@ -554,6 +561,9 @@ def set(self, name, value, val_in_fit=True): if val_in_fit and name in self.bnd_dic: value = self.bnd_dic[name].get_x2y(value) if name in self.variables: + if name in self.pre_trans: + trans = self.pre_trans[name] + value = trans.inverse(value) self.variables[name].assign(value) else: warnings.warn("{} not found".format(name)) From ca56654e5923665c0b05e0b16c3ebe6a937d8b95 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Fri, 15 Sep 2023 00:11:27 +0800 Subject: [PATCH 2/8] fixed: ci error --- tf_pwa/transform.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tf_pwa/transform.py diff --git a/tf_pwa/transform.py b/tf_pwa/transform.py new file mode 100644 index 00000000..c4726f0e --- /dev/null +++ b/tf_pwa/transform.py @@ -0,0 +1,39 @@ +from .config import create_config + +set_trans, get_trans, register_trans = create_config() + +T = "Tensor" + + +class BaseTransform: + def __call__(self, x: T) -> T: + return self.call(x) + + def call(self, x: T) -> T: + raise NotImplementedError() + + def inverse(self, y: T) -> T: + raise NotImplementedError() + + +def create_trans(item: dict) -> BaseTransform: + cls = get_trans(item.get("model", "default")) + obj = cls(**item) + return obj + + +@register_trans("default") +@register_trans("linear") +class LinearTrans(BaseTransform): + def __init__( + self, k: float = 1.0, b: float = 0.0, model: str = "", **kwargs + ): + self.k = k + self.b = b + self.model = model + + def call(self, x: T) -> T: + return self.k * x + self.b + + def inverse(self, x: T) -> T: + return (x - self.b) / self.k From d201626a193da7af2aaceba8656c7c835adbef32 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Fri, 15 Sep 2023 21:22:18 +0800 Subject: [PATCH 3/8] refactor: use from_trans to constrains two variables --- tf_pwa/config_loader/config_loader.py | 16 ++++++++++++++++ tf_pwa/tests/config_toy2.yml | 7 ++++--- tf_pwa/transform.py | 9 ++++----- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index 67506db2..542fb8a5 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -265,6 +265,9 @@ def add_constraints(self, amp): self.add_var_range_constraints(amp, constrains.get("var_range", {})) self.add_var_equal_constraints(amp, constrains.get("var_equal", [])) self.add_pre_trans_constraints(amp, constrains.get("pre_trans", None)) + self.add_from_trans_constraints( + amp, constrains.get("from_trans", None) + ) self.add_gauss_constr_constraints( amp, constrains.get("gauss_constr", {}) ) @@ -323,6 +326,19 @@ def 
add_pre_trans_constraints(self, amp, dic=None): trans = create_trans(v) amp.vm.pre_trans[k] = trans + def add_from_trans_constraints(self, amp, dic=None): + if dic is None: + return + var_equal = [] + pre_trans = {} + for k, v in dic.items(): + x = v.pop("x", None) + if x is not None: + var_equal.append([k, x]) + pre_trans[k] = v + self.add_pre_trans_constraints(amp, pre_trans) + self.add_var_equal_constraints(amp, var_equal) + def add_decay_constraints(self, amp, dic=None): if dic is None: dic = {} diff --git a/tf_pwa/tests/config_toy2.yml b/tf_pwa/tests/config_toy2.yml index 24842adc..17ef8530 100644 --- a/tf_pwa/tests/config_toy2.yml +++ b/tf_pwa/tests/config_toy2.yml @@ -34,11 +34,12 @@ particle: constrains: particle: null decay: null - pre_trans: - "R_BC_mass": + from_trans: + "R_BD_mass": + x: R_CD_mass model: linear k: 1.0 - b: 0.0 + b: 0.01 plot: config: diff --git a/tf_pwa/transform.py b/tf_pwa/transform.py index c4726f0e..4cb42584 100644 --- a/tf_pwa/transform.py +++ b/tf_pwa/transform.py @@ -17,20 +17,19 @@ def inverse(self, y: T) -> T: def create_trans(item: dict) -> BaseTransform: - cls = get_trans(item.get("model", "default")) + model = item.pop("model", "default") + cls = get_trans(model) obj = cls(**item) + obj._model_name = model return obj @register_trans("default") @register_trans("linear") class LinearTrans(BaseTransform): - def __init__( - self, k: float = 1.0, b: float = 0.0, model: str = "", **kwargs - ): + def __init__(self, k: float = 1.0, b: float = 0.0, **kwargs): self.k = k self.b = b - self.model = model def call(self, x: T) -> T: return self.k * x + self.b From a98227410ec9bc68edb041d61abcda1e613231ec Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sat, 16 Sep 2023 22:33:52 +0800 Subject: [PATCH 4/8] refactor: support list of x for multi vars --- tf_pwa/config_loader/config_loader.py | 11 ++++++++++- tf_pwa/transform.py | 20 ++++++++++++++++---- tf_pwa/variable.py | 7 ++++--- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index 542fb8a5..77f319ca 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -323,6 +323,7 @@ def add_pre_trans_constraints(self, amp, dic=None): for k, v in dic.items(): print("transform:", k, v) + v["x"] = v.get("x", k) trans = create_trans(v) amp.vm.pre_trans[k] = trans @@ -334,7 +335,15 @@ def add_from_trans_constraints(self, amp, dic=None): for k, v in dic.items(): x = v.pop("x", None) if x is not None: - var_equal.append([k, x]) + if isinstance(x, list): + var_equal.append([k, x[0]]) + elif isinstance(x, str): + var_equal.append([k, x]) + else: + raise TypeError("x should be str or list") + else: + x = k + v["x"] = x pre_trans[k] = v self.add_pre_trans_constraints(amp, pre_trans) self.add_var_equal_constraints(amp, var_equal) diff --git a/tf_pwa/transform.py b/tf_pwa/transform.py index 4cb42584..c36c5b34 100644 --- a/tf_pwa/transform.py +++ b/tf_pwa/transform.py @@ -6,14 +6,23 @@ class BaseTransform: - def __call__(self, x: T) -> T: + def __call__(self, dic: dict) -> T: + x = self.read(dic) return self.call(x) + def read(self, x: dict) -> T: + if isinstance(self.x, (list, tuple)): + return [x[i] for i in self.x] + elif isinstance(self.x, str): + return x[self.x] + else: + raise TypeError("only str of list of str is supported for x") + def call(self, x: T) -> T: raise NotImplementedError() def inverse(self, y: T) -> T: - raise NotImplementedError() + return None def create_trans(item: dict) -> 
BaseTransform: @@ -27,11 +36,14 @@ def create_trans(item: dict) -> BaseTransform: @register_trans("default") @register_trans("linear") class LinearTrans(BaseTransform): - def __init__(self, k: float = 1.0, b: float = 0.0, **kwargs): + def __init__( + self, x: (list | str), k: float = 1.0, b: float = 0.0, **kwargs + ): + self.x = x self.k = k self.b = b - def call(self, x: T) -> T: + def call(self, x) -> T: return self.k * x + self.b def inverse(self, x: T) -> T: diff --git a/tf_pwa/variable.py b/tf_pwa/variable.py index 54524bd2..1ec0546f 100644 --- a/tf_pwa/variable.py +++ b/tf_pwa/variable.py @@ -537,7 +537,7 @@ def get(self, name, val_in_fit=True): if not val_in_fit or name not in self.bnd_dic: value = self.variables[name] if name in self.pre_trans: - value = self.pre_trans[name](value) + value = self.pre_trans[name](self.variables) return value.numpy() # tf.Variable else: return self.bnd_dic[name].get_y2x(self.variables[name].numpy()) @@ -548,7 +548,7 @@ def read(self, name): val = tf.stop_gradient(tf.cast(self.mask_vars[name], val.dtype)) if name in self.pre_trans: trans = self.pre_trans[name] - val = trans(val) + val = trans(self.variables) return val def set(self, name, value, val_in_fit=True): @@ -564,7 +564,8 @@ def set(self, name, value, val_in_fit=True): if name in self.pre_trans: trans = self.pre_trans[name] value = trans.inverse(value) - self.variables[name].assign(value) + if value is not None: + self.variables[name].assign(value) else: warnings.warn("{} not found".format(name)) From e07969b33ccd744bb8dafe5e3d8d04ef9ebefcde Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sat, 16 Sep 2023 22:43:04 +0800 Subject: [PATCH 5/8] fixed: type|type --- tf_pwa/transform.py | 2 +- tf_pwa/variable.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tf_pwa/transform.py b/tf_pwa/transform.py index c36c5b34..53c863b6 100644 --- a/tf_pwa/transform.py +++ b/tf_pwa/transform.py @@ -37,7 +37,7 @@ def create_trans(item: dict) -> BaseTransform: @register_trans("linear") class LinearTrans(BaseTransform): def __init__( - self, x: (list | str), k: float = 1.0, b: float = 0.0, **kwargs + self, x: "list | str", k: float = 1.0, b: float = 0.0, **kwargs ): self.x = x self.k = k diff --git a/tf_pwa/variable.py b/tf_pwa/variable.py index 1ec0546f..60b0bec5 100644 --- a/tf_pwa/variable.py +++ b/tf_pwa/variable.py @@ -643,13 +643,13 @@ def get_all_dic(self, trainable_only=False): dic = {} if trainable_only: for i in self.trainable_vars: - val = self.variables[i].numpy() + val = self.read(i).numpy() # if i in self.bnd_dic: # val = self.bnd_dic[i].get_y2x(val) dic[i] = val else: for i in self.variables: - val = self.variables[i].numpy() + val = self.read(i).numpy() # if i in self.bnd_dic: # val = self.bnd_dic[i].get_y2x(val) dic[i] = val From b2bc256b2b5435592a21ec4832d7dfc8197296cc Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sun, 17 Sep 2023 22:50:23 +0800 Subject: [PATCH 6/8] ci: add more tests --- tf_pwa/config_loader/config_loader.py | 12 ++++++------ tf_pwa/tests/test_variable.py | 27 +++++++++++++++++++++++++++ tf_pwa/transform.py | 5 ++++- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index 77f319ca..56126001 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -335,18 +335,18 @@ def add_from_trans_constraints(self, amp, dic=None): for k, v in dic.items(): x = v.pop("x", None) if x is not None: - if isinstance(x, list): - 
var_equal.append([k, x[0]]) - elif isinstance(x, str): - var_equal.append([k, x]) + if isinstance(x, list) and k != x[0]: + var_equal.append([x[0], k]) + elif isinstance(x, str) and x != k: + var_equal.append([x, k]) else: raise TypeError("x should be str or list") else: x = k v["x"] = x pre_trans[k] = v - self.add_pre_trans_constraints(amp, pre_trans) - self.add_var_equal_constraints(amp, var_equal) + ConfigLoader.add_pre_trans_constraints(self, amp, pre_trans) + ConfigLoader.add_var_equal_constraints(self, amp, var_equal) def add_decay_constraints(self, amp, dic=None): if dic is None: diff --git a/tf_pwa/tests/test_variable.py b/tf_pwa/tests/test_variable.py index 8fd6b25c..db0aa11a 100644 --- a/tf_pwa/tests/test_variable.py +++ b/tf_pwa/tests/test_variable.py @@ -142,3 +142,30 @@ def test_rename(): vm.rename_var("d", "c", True) assert vm.get("cr") == 2 assert vm.get("ci") == 3 + + +def test_transform(): + from tf_pwa.config_loader import ConfigLoader + from tf_pwa.transform import BaseTransform, register_trans + + @register_trans("__test1") + class Atrans(BaseTransform): + def call(self, x): + return x[0] + x[1] + + class Tmp: + pass + + tmp = Tmp() + with variable_scope() as vm: + tmp.vm = vm + Variable("a", value=1.0) + Variable("b", value=1.0) + Variable("c", value=1.0) + ConfigLoader.add_from_trans_constraints( + None, tmp, {"a": {"x": ["c", "b"], "model": "__test1"}} + ) + assert np.allclose(vm.get("a"), 2.0) + vm.set("a", 1.0) + assert vm.get_all_dic(False) == {"a": 2.0, "b": 1.0, "c": 1.0} + print(vm.get_all_dic(True)) diff --git a/tf_pwa/transform.py b/tf_pwa/transform.py index 53c863b6..79ce8734 100644 --- a/tf_pwa/transform.py +++ b/tf_pwa/transform.py @@ -6,6 +6,9 @@ class BaseTransform: + def __init__(self, x: "list | str", **kwargs): + self.x = x + def __call__(self, dic: dict) -> T: x = self.read(dic) return self.call(x) @@ -39,7 +42,7 @@ class LinearTrans(BaseTransform): def __init__( self, x: "list | str", k: float = 1.0, b: float = 0.0, **kwargs ): - self.x = x + super().__init__(x) self.k = k self.b = b From 53a627e318f6d680654f6a2357a9b87a57bb1bbc Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sun, 17 Sep 2023 23:57:14 +0800 Subject: [PATCH 7/8] fixed: ci error --- tf_pwa/variable.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tf_pwa/variable.py b/tf_pwa/variable.py index 60b0bec5..e4a432db 100644 --- a/tf_pwa/variable.py +++ b/tf_pwa/variable.py @@ -511,8 +511,8 @@ def same_real(name_list): self.trainable_vars.remove(name) else: # if one is untrainable, the others will all be untrainable - var = self.variables.get(name, None) - if var is not None: + var2 = self.variables.get(name, None) + if var2 is not None: if name_list[0] in self.trainable_vars: self.trainable_vars.remove(name_list[0]) for name in name_list: From cf340bbc92406c44bb7f3442446d0b7aa9c00594 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Mon, 18 Sep 2023 21:31:39 +0800 Subject: [PATCH 8/8] misc: from trans: when name not in vm --- tf_pwa/config_loader/config_loader.py | 8 +++++++- tf_pwa/tests/test_variable.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index 56126001..918bd4f0 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -332,12 +332,15 @@ def add_from_trans_constraints(self, amp, dic=None): return var_equal = [] pre_trans = {} + new_var = [] for k, v in dic.items(): x = v.pop("x", None) if x is not None: if isinstance(x, 
list) and k != x[0]: + new_var += x var_equal.append([x[0], k]) elif isinstance(x, str) and x != k: + new_var.append(x) var_equal.append([x, k]) else: raise TypeError("x should be str or list") @@ -345,8 +348,11 @@ def add_from_trans_constraints(self, amp, dic=None): x = k v["x"] = x pre_trans[k] = v - ConfigLoader.add_pre_trans_constraints(self, amp, pre_trans) + for i in new_var: + if i not in amp.vm.variables: + amp.vm.add_real_var(i) ConfigLoader.add_var_equal_constraints(self, amp, var_equal) + ConfigLoader.add_pre_trans_constraints(self, amp, pre_trans) def add_decay_constraints(self, amp, dic=None): if dic is None: diff --git a/tf_pwa/tests/test_variable.py b/tf_pwa/tests/test_variable.py index db0aa11a..db4ab691 100644 --- a/tf_pwa/tests/test_variable.py +++ b/tf_pwa/tests/test_variable.py @@ -161,10 +161,10 @@ class Tmp: tmp.vm = vm Variable("a", value=1.0) Variable("b", value=1.0) - Variable("c", value=1.0) ConfigLoader.add_from_trans_constraints( None, tmp, {"a": {"x": ["c", "b"], "model": "__test1"}} ) + vm.set("c", 1.0) assert np.allclose(vm.get("a"), 2.0) vm.set("a", 1.0) assert vm.get_all_dic(False) == {"a": 2.0, "b": 1.0, "c": 1.0}
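
Usage sketch (illustrative, not part of the committed files): the snippet below mirrors the test_transform test added in PATCH 6/8 and adjusted in PATCH 8/8, and shows how a custom transform model is registered and wired through a from_trans constraint without a full config file. The model name "__doc_example" and the FakeAmp holder class are placeholders invented for this sketch; everything else uses only APIs introduced or exercised in the patches above (register_trans, BaseTransform, Variable, variable_scope, ConfigLoader.add_from_trans_constraints, VarsManager.pre_trans).

    import numpy as np

    from tf_pwa.config_loader import ConfigLoader
    from tf_pwa.transform import BaseTransform, register_trans
    from tf_pwa.variable import Variable, variable_scope


    @register_trans("__doc_example")  # placeholder model name for this sketch
    class SumTrans(BaseTransform):
        def call(self, x):
            # self.x is ["c", "b"]; BaseTransform.read() has already looked the
            # names up in the VarsManager dict, so x arrives here as a list.
            return x[0] + x[1]


    class FakeAmp:
        # minimal stand-in: add_from_trans_constraints only touches amp.vm
        pass


    amp = FakeAmp()
    with variable_scope() as vm:
        amp.vm = vm
        Variable("a", value=1.0)
        Variable("b", value=1.0)
        Variable("c", value=1.0)
        # equivalent YAML under "constrains:":
        #   from_trans:
        #     "a": {x: ["c", "b"], model: "__doc_example"}
        ConfigLoader.add_from_trans_constraints(
            None, amp, {"a": {"x": ["c", "b"], "model": "__doc_example"}}
        )
        # "a" is tied to "c" via set_same and read through the transform,
        # so vm.get("a") evaluates c + b with the current variable values.
        assert np.allclose(vm.get("a"), 2.0)

Because BaseTransform.inverse() returns None by default (PATCH 4/8), VarsManager.set() skips the assignment for such derived variables, so a multi-input transform like this one is effectively read-only; only models that implement inverse(), such as the built-in linear transform, allow writing a value back through the transform.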