Skip to content

Commit

Permalink
Merge pull request jiangyi15#141 from jiangyi15/fix_same
Browse files Browse the repository at this point in the history
Fixed some bugs of variables
  • Loading branch information
jiangyi15 authored Feb 24, 2024
2 parents 8225a1b + 8326e87 commit 81d5ac3
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
19 changes: 19 additions & 0 deletions tf_pwa/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,25 @@ def test_rename():
assert vm.get("ci") == 3


def test_same():
with variable_scope() as vm:
Variable("a", value=1.0)
Variable("b", value=1.0)
Variable("c", value=1.0)
vm.set_same(["a", "b"])
vm.set_same(["b", "c"])
assert len(vm.trainable_vars) == 1


def test_fixed():
with variable_scope() as vm:
Variable("a", value=1.0)
Variable("b", value=1.0)
vm.set_fix("a", 1.0)
vm.set_fix("b", unfix=True)
assert len(vm.trainable_vars) == 1


def test_transform():
from tf_pwa.config_loader import ConfigLoader
from tf_pwa.transform import BaseTransform, register_trans
Expand Down
23 changes: 17 additions & 6 deletions tf_pwa/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,11 +421,13 @@ def set_fix(self, name, value=None, unfix=False):
"""
if value is None:
value = self.variables[name].value
if callable(value):
value = value()
else:
if name in self.bnd_dic:
value = self.bnd_dic[name].get_y2x(value)
var = tf.Variable(value, dtype=self.dtype, trainable=unfix)
self.variables[name] = var
self.variables[name].assign(value)
self.variables[name]._trainable = unfix
if unfix:
if name in self.trainable_vars:
warnings.warn("{} has been freed already!".format(name))
Expand Down Expand Up @@ -503,18 +505,27 @@ def set_same(self, name_list, cplx=False):
:param cplx: Boolean. Whether the variables are complex or real.
"""
tmp_list = []
head_list = []
for name in name_list:
for add_list in self.same_list:
if name not in self.variables:
continue
if name in add_list:
tmp_list += add_list
head_list += [add_list[0]]
self.same_list.remove(add_list)
break

# use head to avoid duplicate setting
new_name_list = head_list
for i in name_list:
if i not in tmp_list:
new_name_list.append(i)

# remove duplicate items
for i in tmp_list:
if i not in name_list:
name_list.append(i) # 去掉重复元素
name_list.append(i)

def same_real(name_list):
name_list = [i for i in name_list if i in self.variables]
Expand All @@ -534,10 +545,10 @@ def same_real(name_list):
self.variables[name] = var

if cplx:
same_real([name + "r" for name in name_list])
same_real([name + "i" for name in name_list])
same_real([name + "r" for name in new_name_list])
same_real([name + "i" for name in new_name_list])
else:
same_real(name_list)
same_real(new_name_list)
self.same_list.append(name_list)

def get(self, name, val_in_fit=True):
Expand Down

0 comments on commit 81d5ac3

Please sign in to comment.