Skip to content

Commit

Permalink
feat: pre_trans for variable
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangyi15 committed Sep 14, 2023
1 parent 20ac860 commit de7c2c9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
11 changes: 11 additions & 0 deletions tf_pwa/config_loader/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def add_constraints(self, amp):
self.add_free_var_constraints(amp, constrains.get("free_var", []))
self.add_var_range_constraints(amp, constrains.get("var_range", {}))
self.add_var_equal_constraints(amp, constrains.get("var_equal", []))
self.add_pre_trans_constraints(amp, constrains.get("pre_trans", None))
self.add_gauss_constr_constraints(
amp, constrains.get("gauss_constr", {})
)
Expand Down Expand Up @@ -312,6 +313,16 @@ def add_var_equal_constraints(self, amp, dic=None):
print("same value:", k)
amp.vm.set_same(k)

def add_pre_trans_constraints(self, amp, dic=None):
    """Install pre-transformations on fit variables.

    :param amp: amplitude model whose variable manager (``amp.vm``)
        receives the transformations.
    :param dic: mapping of variable name -> transform configuration,
        or ``None`` to do nothing.
    """
    if dic is None:
        return
    # imported lazily so the config loader works without this module
    # unless pre_trans is actually configured
    from tf_pwa.transform import create_trans

    for var_name, trans_conf in dic.items():
        print("transform:", var_name, trans_conf)
        amp.vm.pre_trans[var_name] = create_trans(trans_conf)

def add_decay_constraints(self, amp, dic=None):
if dic is None:
dic = {}
Expand Down
5 changes: 5 additions & 0 deletions tf_pwa/tests/config_toy2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ particle:
constrains:
particle: null
decay: null
pre_trans:
"R_BC_mass":
model: linear
k: 1.0
b: 0.0

plot:
config:
Expand Down
16 changes: 13 additions & 3 deletions tf_pwa/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(self, name="", dtype=tf.float64):
self.complex_vars = {} # {name:polar(bool),...}
self.same_list = [] # [[name1,name2],...]
self.mask_vars = {}
self.pre_trans = {}

self.bnd_dic = {} # {name:(a,b),...}

Expand Down Expand Up @@ -534,15 +535,21 @@ def get(self, name, val_in_fit=True):
if name not in self.variables:
raise Exception("{} not found".format(name))
if not val_in_fit or name not in self.bnd_dic:
return self.variables[name].numpy() # tf.Variable
value = self.variables[name]
if name in self.pre_trans:
value = self.pre_trans[name](value)
return value.numpy() # tf.Variable
else:
return self.bnd_dic[name].get_y2x(self.variables[name].numpy())

def read(self, name):
    """Return the tensor value of variable ``name`` for use in the model.

    The raw ``tf.Variable`` is looked up first; if the variable is
    masked, its value is replaced by the (gradient-stopped) mask value
    cast to the variable's dtype; if a pre-transformation is registered,
    it is applied last, so masked values are transformed too.

    :param name: variable name; assumed present in ``self.variables``
        (a missing name raises ``KeyError`` — TODO confirm callers rely
        on that).
    :return: a tensor value, after optional mask and pre-transform.
    """
    val = self.variables[name]
    if name in self.mask_vars:
        # Fix the value to the mask; stop_gradient keeps the optimizer
        # from updating a masked variable through this read.
        val = tf.stop_gradient(tf.cast(self.mask_vars[name], val.dtype))
    if name in self.pre_trans:
        trans = self.pre_trans[name]
        val = trans(val)
    return val

def set(self, name, value, val_in_fit=True):
"""
Expand All @@ -554,6 +561,9 @@ def set(self, name, value, val_in_fit=True):
if val_in_fit and name in self.bnd_dic:
value = self.bnd_dic[name].get_x2y(value)
if name in self.variables:
if name in self.pre_trans:
trans = self.pre_trans[name]
value = trans.inverse(value)
self.variables[name].assign(value)
else:
warnings.warn("{} not found".format(name))
Expand Down

0 comments on commit de7c2c9

Please sign in to comment.