
Commit

Merge pull request #97 from jiangyi15/grad_scale
feat: grad_scale option for fit
jiangyi15 authored Aug 20, 2023
2 parents b8e892c + 05fa8ae commit f5db1ef
Showing 3 changed files with 15 additions and 8 deletions.
4 changes: 4 additions & 0 deletions tf_pwa/config_loader/config_loader.py
@@ -667,6 +667,8 @@ def fit(
         jac=True,
         print_init_nll=True,
         callback=None,
+        grad_scale=1.0,
+        gtol=1e-3,
     ):
         if data is None and phsp is None:
             data, phsp, bg, inmc = self.get_all_data()
@@ -701,6 +703,8 @@ def fit(
             maxiter=maxiter,
             jac=jac,
             callback=callback,
+            grad_scale=grad_scale,
+            gtol=gtol,
         )
         if self.fit_params.hess_inv is not None:
             self.inv_he = self.fit_params.hess_inv
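Both keywords are simply forwarded from ConfigLoader.fit to fit_scipy. A minimal usage sketch, assuming the usual tf_pwa workflow with a config.yml in the working directory (only the grad_scale and gtol keywords come from this commit):

    from tf_pwa.config_loader import ConfigLoader

    config = ConfigLoader("config.yml")
    # Scale the NLL and its gradient seen by the optimizer by 0.1 and
    # tighten the BFGS stopping threshold; both are passed on to fit_scipy.
    fit_result = config.fit(grad_scale=0.1, gtol=1e-4)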
14 changes: 8 additions & 6 deletions tf_pwa/fit.py
@@ -196,6 +196,8 @@ def fit_scipy(
     jac=True,
     callback=None,
     standard_complex=True,
+    grad_scale=1.0,
+    gtol=1e-3,
 ):
     """
@@ -268,7 +270,7 @@ def callback(x):
         fcn.vm.set_bound(bounds_dict)

         f_g = fcn.vm.trans_fcn_grad(fcn.nll_grad)
-        f_g = Cached_FG(f_g)
+        f_g = Cached_FG(f_g, grad_scale=grad_scale)
         # print(f_g)
         x0 = np.array(fcn.vm.get_all_val(True))
         # print(x0, fcn.vm.get_all_dic())
@@ -281,7 +283,7 @@ def callback(x):
                 method=method,
                 jac=True,
                 callback=callback,
-                options={"disp": 1, "gtol": 1e-3, "maxiter": maxiter},
+                options={"disp": 1, "gtol": gtol, "maxiter": maxiter},
             )
         except LargeNumberError:
             return except_result(fcn, x0.shape[0])
@@ -293,7 +295,7 @@ def callback(x):
                 method=method,
                 jac=jac,
                 callback=callback,
-                options={"disp": 1, "gtol": 1e-3, "maxiter": maxiter},
+                options={"disp": 1, "gtol": gtol, "maxiter": maxiter},
             )
         except LargeNumberError:
             return except_result(fcn, x0.shape[0])
@@ -305,7 +307,7 @@ def callback(x):
                 method=method,
                 jac=True,
                 callback=callback,
-                options={"disp": 1, "gtol": 1e-3, "maxiter": maxiter},
+                options={"disp": 1, "gtol": gtol, "maxiter": maxiter},
             )
         except LargeNumberError:
             return except_result(fcn, x0.shape[0])
@@ -319,7 +321,7 @@ def callback(x):
             method=method,
             jac=True,
             callback=callback,
-            options={"disp": 1, "gtol": 1e-2, "maxiter": maxiter},
+            options={"disp": 1, "gtol": gtol * 10, "maxiter": maxiter},
         )
         if hasattr(s, "hess_inv"):
             edm = np.dot(np.dot(s.hess_inv, s.jac), s.jac)
@@ -339,7 +341,7 @@ def callback(x):
     ndf = s.x.shape[0]
     min_nll = s.fun
     success = s.success
-    hess_inv = fcn.vm.trans_error_matrix(s.hess_inv, s.x)
+    hess_inv = fcn.vm.trans_error_matrix(s.hess_inv / grad_scale, s.x)
     fcn.vm.remove_bound()

     xn = fcn.vm.get_all_val()
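Since the optimizer now runs on grad_scale * NLL, the inverse Hessian it reports is divided by grad_scale before fcn.vm.trans_error_matrix builds the error matrix. The toy check below is not part of the commit; it only illustrates, with plain scipy, that scaling an objective by a constant changes the reported hess_inv by the inverse of that constant:

    import numpy as np
    from scipy.optimize import minimize

    def nll(x):
        return 0.5 * x[0] ** 2            # true Hessian = 1

    def grad(x):
        return x                          # gradient of 0.5 * x**2

    s = 0.1                               # stands in for grad_scale
    plain = minimize(nll, np.array([1.0]), jac=grad, method="BFGS")
    scaled = minimize(lambda x: s * nll(x), np.array([1.0]),
                      jac=lambda x: s * grad(x), method="BFGS")

    print(plain.hess_inv)                 # approximately [[1.]]
    print(scaled.hess_inv)                # approximately [[1. / s]] = [[10.]]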
5 changes: 3 additions & 2 deletions tf_pwa/fit_improve.py
@@ -21,11 +21,12 @@ class LineSearchWarning(RuntimeWarning):


 class Cached_FG:
-    def __init__(self, f_g):
+    def __init__(self, f_g, grad_scale=1.0):
         self.f_g = f_g
         self.cached_fun = 0
         self.cached_grad = 0
         self.ncall = 0
+        self.grad_scale = grad_scale

     def __call__(self, x):
         f = self.fun(x)
@@ -38,7 +39,7 @@ def __call__(self, x):
                 new_x[i] -= 1e-6
                 f2, _ = self.f_g(new_x)
                 self.cached_grad[i] = (f1 - f2) / 2e-6
-        return f, self.cached_grad
+        return self.grad_scale * f, self.grad_scale * self.cached_grad

     def fun(self, x):
         f, g = self.f_g(x)
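Cached_FG wraps a callable that returns a (value, gradient) pair and now multiplies both by grad_scale before handing them to the optimizer. A small sketch of that contract, assuming the wrapped callable returns a finite gradient so the numerical fallback inside __call__ is not triggered (the lambda is just a stand-in objective):

    import numpy as np
    from tf_pwa.fit_improve import Cached_FG

    # Wrapped callable must return (value, gradient).
    f_g = Cached_FG(lambda x: (np.sum(x ** 2), 2 * x), grad_scale=0.5)
    val, grad = f_g(np.array([1.0, 2.0]))
    print(val, grad)   # 2.5 [1. 2.] -- value and gradient both scaled by 0.5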
