From de7c2c9eb3d4bf041e0684835eae83d59fdb3e9f Mon Sep 17 00:00:00 2001 From: jiangyi15 Date: Thu, 14 Sep 2023 23:21:28 +0800 Subject: [PATCH] feat: pre_trans for variable --- tf_pwa/config_loader/config_loader.py | 11 +++++++++++ tf_pwa/tests/config_toy2.yml | 5 +++++ tf_pwa/variable.py | 16 +++++++++++++--- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index 720ace0f..67506db2 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -264,6 +264,7 @@ def add_constraints(self, amp): self.add_free_var_constraints(amp, constrains.get("free_var", [])) self.add_var_range_constraints(amp, constrains.get("var_range", {})) self.add_var_equal_constraints(amp, constrains.get("var_equal", [])) + self.add_pre_trans_constraints(amp, constrains.get("pre_trans", None)) self.add_gauss_constr_constraints( amp, constrains.get("gauss_constr", {}) ) @@ -312,6 +313,16 @@ def add_var_equal_constraints(self, amp, dic=None): print("same value:", k) amp.vm.set_same(k) + def add_pre_trans_constraints(self, amp, dic=None): + if dic is None: + return + from tf_pwa.transform import create_trans + + for k, v in dic.items(): + print("transform:", k, v) + trans = create_trans(v) + amp.vm.pre_trans[k] = trans + def add_decay_constraints(self, amp, dic=None): if dic is None: dic = {} diff --git a/tf_pwa/tests/config_toy2.yml b/tf_pwa/tests/config_toy2.yml index 9be6ab72..24842adc 100644 --- a/tf_pwa/tests/config_toy2.yml +++ b/tf_pwa/tests/config_toy2.yml @@ -34,6 +34,11 @@ particle: constrains: particle: null decay: null + pre_trans: + "R_BC_mass": + model: linear + k: 1.0 + b: 0.0 plot: config: diff --git a/tf_pwa/variable.py b/tf_pwa/variable.py index 498afa7a..54524bd2 100644 --- a/tf_pwa/variable.py +++ b/tf_pwa/variable.py @@ -99,6 +99,7 @@ def __init__(self, name="", dtype=tf.float64): self.complex_vars = {} # {name:polar(bool),...} self.same_list = [] # [[name1,name2],...] self.mask_vars = {} + self.pre_trans = {} self.bnd_dic = {} # {name:(a,b),...} @@ -534,15 +535,21 @@ def get(self, name, val_in_fit=True): if name not in self.variables: raise Exception("{} not found".format(name)) if not val_in_fit or name not in self.bnd_dic: - return self.variables[name].numpy() # tf.Variable + value = self.variables[name] + if name in self.pre_trans: + value = self.pre_trans[name](value) + return value.numpy() # tf.Variable else: return self.bnd_dic[name].get_y2x(self.variables[name].numpy()) def read(self, name): val = self.variables[name] if name in self.mask_vars: - return tf.stop_gradient(tf.cast(self.mask_vars[name], val.dtype)) - return self.variables[name] + val = tf.stop_gradient(tf.cast(self.mask_vars[name], val.dtype)) + if name in self.pre_trans: + trans = self.pre_trans[name] + val = trans(val) + return val def set(self, name, value, val_in_fit=True): """ @@ -554,6 +561,9 @@ def set(self, name, value, val_in_fit=True): if val_in_fit and name in self.bnd_dic: value = self.bnd_dic[name].get_x2y(value) if name in self.variables: + if name in self.pre_trans: + trans = self.pre_trans[name] + value = trans.inverse(value) self.variables[name].assign(value) else: warnings.warn("{} not found".format(name))