From ca871bbdb1cb1ed9cdb0333cdd8f1a342147fc78 Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Fri, 23 Feb 2024 14:36:40 +0800 Subject: [PATCH 1/3] fixed: avoid create variable when fixed --- tf_pwa/variable.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tf_pwa/variable.py b/tf_pwa/variable.py index f81c57fd..488d7e7d 100644 --- a/tf_pwa/variable.py +++ b/tf_pwa/variable.py @@ -424,8 +424,8 @@ def set_fix(self, name, value=None, unfix=False): else: if name in self.bnd_dic: value = self.bnd_dic[name].get_y2x(value) - var = tf.Variable(value, dtype=self.dtype, trainable=unfix) - self.variables[name] = var + self.variables[name].assign(value) + self.variables[name]._trainable = unfix if unfix: if name in self.trainable_vars: warnings.warn("{} has been freed already!".format(name)) From 4f1e8401c9c6c2b7fdadda575f1a160444c28ccb Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Fri, 23 Feb 2024 15:16:00 +0800 Subject: [PATCH 2/3] fixed: duplicate setting of same var --- tf_pwa/variable.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tf_pwa/variable.py b/tf_pwa/variable.py index 488d7e7d..78de19cf 100644 --- a/tf_pwa/variable.py +++ b/tf_pwa/variable.py @@ -421,6 +421,8 @@ def set_fix(self, name, value=None, unfix=False): """ if value is None: value = self.variables[name].value + if callable(value): + value = value() else: if name in self.bnd_dic: value = self.bnd_dic[name].get_y2x(value) @@ -503,18 +505,27 @@ def set_same(self, name_list, cplx=False): :param cplx: Boolean. Whether the variables are complex or real. """ tmp_list = [] + head_list = [] for name in name_list: for add_list in self.same_list: if name not in self.variables: continue if name in add_list: tmp_list += add_list + head_list += [add_list[0]] self.same_list.remove(add_list) break + # use head to avoid duplicate setting + new_name_list = head_list + for i in name_list: + if i not in tmp_list: + new_name_list.append(i) + + # remove duplicate items for i in tmp_list: if i not in name_list: - name_list.append(i) # 去掉重复元素 + name_list.append(i) def same_real(name_list): name_list = [i for i in name_list if i in self.variables] @@ -534,10 +545,10 @@ def same_real(name_list): self.variables[name] = var if cplx: - same_real([name + "r" for name in name_list]) - same_real([name + "i" for name in name_list]) + same_real([name + "r" for name in new_name_list]) + same_real([name + "i" for name in new_name_list]) else: - same_real(name_list) + same_real(new_name_list) self.same_list.append(name_list) def get(self, name, val_in_fit=True): From 8326e878464097dc0324ca8daedc16f41b603f3b Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Sat, 24 Feb 2024 10:28:10 +0800 Subject: [PATCH 3/3] ci: imporve test coverage --- tf_pwa/tests/test_variable.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tf_pwa/tests/test_variable.py b/tf_pwa/tests/test_variable.py index db4ab691..2cc6ac09 100644 --- a/tf_pwa/tests/test_variable.py +++ b/tf_pwa/tests/test_variable.py @@ -144,6 +144,25 @@ def test_rename(): assert vm.get("ci") == 3 +def test_same(): + with variable_scope() as vm: + Variable("a", value=1.0) + Variable("b", value=1.0) + Variable("c", value=1.0) + vm.set_same(["a", "b"]) + vm.set_same(["b", "c"]) + assert len(vm.trainable_vars) == 1 + + +def test_fixed(): + with variable_scope() as vm: + Variable("a", value=1.0) + Variable("b", value=1.0) + vm.set_fix("a", 1.0) + vm.set_fix("b", unfix=True) + assert len(vm.trainable_vars) == 1 + + def test_transform(): from tf_pwa.config_loader import ConfigLoader from tf_pwa.transform import BaseTransform, register_trans