Skip to content

Commit

Permalink
Regparam grid bug fix (#477)
Browse files Browse the repository at this point in the history
* Update for 3.12

* Updated workflows and changelog

* Upload to 3.12

* Fix bug in regparam grid search

Regparam would never build the grid correctly. Now using grid or Brent is automatically determined from number of elements in the regparamrange.

* Add extra error messages

* Update changelog

* Updated Example

* Updated test

The test has been updated. The previous convergence criteria was unreliable and only worked based on a coincidence.

* Prepare For Release

* Remove duplicate python version
  • Loading branch information
HKaras authored Jul 15, 2024
1 parent 534495e commit 178249e
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_scheduled.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.8, 3.9, "3.10", "3.11", "3.12","3.12"]
python-version: [3.8, 3.9, "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v3

Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2019-2023 Luis Fabregas, Stefan Stoll, Gunnar Jeschke, and other contributors
Copyright (c) 2019-2024 Luis Fabregas, Stefan Stoll, Gunnar Jeschke, and other contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
41 changes: 27 additions & 14 deletions deerlab/selregparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import math as m
import deerlab as dl

def selregparam(y, A, solver, method='aic', algorithm='brent', noiselvl=None,
searchrange=[1e-8,1e2],regop=None, weights=None, full_output=False, candidates=None):
def selregparam(y, A, solver, method='aic', algorithm='auto', noiselvl=None,
searchrange=[1e-8,1e2],regop=None, weights=None, full_output=False):
r"""
Selection of optimal regularization parameter based on a selection criterion.
Expand Down Expand Up @@ -52,6 +52,8 @@ def selregparam(y, A, solver, method='aic', algorithm='brent', noiselvl=None,
* ``'ncp'`` - Normalized Cumulative Periodogram (NCP)
* ``'gml'`` - Generalized Maximum Likelihood (GML)
* ``'mcl'`` - Mallows' C_L (MCL)
If ``'lr'`` or ``'lc'`` is specified, the search algorithm is automatically set to ``'grid'``.
weights : array_like, optional
Array of weighting coefficients for the individual datasets in global fitting.
Expand All @@ -60,18 +62,16 @@ def selregparam(y, A, solver, method='aic', algorithm='brent', noiselvl=None,
algorithm : string, optional
Search algorithm:
* ``'auto'`` - Automatically, selects algrothium based on the searchrange. If the searchrange has two elements its set to ``'brent'`` otherwise to ``'grid'``.
* ``'grid'`` - Grid-search, slow.
* ``'brent'`` - Brent-algorithm, fast.
The default is ``'brent'``.
searchrange : two-element list, optional
Search range for the optimization of the regularization parameter with the ``'brent'`` algorithm.
If not specified the default search range defaults to ``[1e-8,1e2]``.
The default is ``'auto'``.
candidates : list, optional
List or array of candidate regularization parameter values to be evaluated with the ``'grid'`` algorithm.
If not specified, these are automatically computed from a grid within ``searchrange``.
searchrange : list, optional
Either the search range for the optimization of the regularization parameter with the ``'brent'`` algorithm.
Or if more than two values are specified, then it is interpreted as candidates for the ``'grid'`` algorithm.
If not specified the default search range defaults to ``[1e-8,1e2]`` and the ``'brent'`` algorithm.
regop : 2D array_like, optional
Regularization operator matrix, the default is the second-order differential operator.
Expand Down Expand Up @@ -108,6 +108,13 @@ def selregparam(y, A, solver, method='aic', algorithm='brent', noiselvl=None,
# If multiple datasets are passed, concatenate the datasets and kernels
y, A, weights,_,__, noiselvl = dl.utils.parse_multidatasets(y, A, weights, noiselvl)

if algorithm == 'auto' and len(searchrange) == 2:
algorithm = 'brent'
elif algorithm == 'auto' and len(searchrange) > 2:
algorithm = 'grid'
elif algorithm == 'auto' and len(searchrange) < 2:
raise ValueError("`searchrange` must have at least two elements if `algorithm` is set to `'auto'")

# The L-curve criteria require a grid-evaluation
if method == 'lr' or method == 'lc':
algorithm = 'grid'
Expand All @@ -121,9 +128,11 @@ def selregparam(y, A, solver, method='aic', algorithm='brent', noiselvl=None,
evalalpha = lambda alpha: _evalalpha(alpha, y, A, L, solver, method, noiselvl, weights)

# Evaluate functional over search range, using specified search method
if algorithm == 'brent':
if algorithm == 'brent':

# Search boundaries
if len(searchrange) != 2:
raise ValueError("Search range must have two elements for the 'brent' algorithm.")
lga_min = m.log10(searchrange[0])
lga_max = m.log10(searchrange[1])

Expand All @@ -148,10 +157,10 @@ def register_ouputs(optout):
elif algorithm=='grid':

# Get range of potential alpha values candidates
if candidates is None:
if len(searchrange) == 2:
alphaCandidates = 10**np.linspace(np.log10(searchrange[0]),np.log10(searchrange[1]),60)
else:
alphaCandidates = np.atleast_1d(candidates)
alphaCandidates = np.atleast_1d(searchrange)

# Evaluate the full grid of alpha-candidates
functional,residuals,penalties,alphas_evaled = tuple(zip(*[evalalpha(alpha) for alpha in alphaCandidates]))
Expand All @@ -176,6 +185,10 @@ def register_ouputs(optout):

# Find minimum of the selection functional
alphaOpt = alphaCandidates[np.argmin(functional)]
functional = np.array(functional)
residuals = np.array(residuals)
penalties = np.array(penalties)
alphas_evaled = np.array(alphas_evaled)
else:
raise KeyError("Search method not found. Must be either 'brent' or 'grid'.")

Expand Down
4 changes: 2 additions & 2 deletions deerlab/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,8 +602,8 @@ def linear_problem(y,A,optimize_alpha,alpha):
if optimize_alpha:
linsolver_result = lambda AtA, Aty: parseResult(linSolver(AtA, Aty))
output = dl.selregparam((y-yfrozen)[mask], Ared[mask,:], linsolver_result, regparam,
weights=weights[mask], regop=L, candidates=regparamrange,
noiselvl=noiselvl,searchrange=regparamrange,full_output=True)
weights=weights[mask], regop=L, noiselvl=noiselvl,
searchrange=regparamrange,full_output=True)
alpha = output[0]
alpha_stats['alphas_evaled'] = output[1]
alpha_stats['functional'] = output[2]
Expand Down
5 changes: 3 additions & 2 deletions docsrc/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ Release Notes
- |fix| : Something which was not working as expected or leading to errors has been fixed.
- |api| : This will require changes in your scripts or code.

Release ``v1.1.3`` - Ongoing
Release ``v1.1.3`` - July 2024
------------------------------------------
- |fix| : Removes unnecessary files from the docs
- |efficiency| : Improves the performance of the ``dipolarkernel`` function by 10-30% (:pr:`473`), by caching the interpolation of he effective dipolar evolution time vector.
- |fix| : Add support for Python 3.12
- |api| : Removes the `candidates` input from `selregparam` and integrates its function into `regparamrange`.
- |fix| : Adds support for Numpy 2.0
- |fix| : Add support for Python 3.12


Release ``v1.1.2`` - November 2023
Expand Down
54 changes: 52 additions & 2 deletions examples/intermediate/ex_selregparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
t = t + tmin

# Distance vector
r = np.linspace(1.5,7,50) # nm
r = np.linspace(1.5,7,100) # nm

# Construct the model
Vmodel = dl.dipolarmodel(t,r, experiment=dl.ex_4pdeer(tau1,tau2, pathways=[1]))
Expand Down Expand Up @@ -101,7 +101,7 @@

# %%
# Over and Under selection of the regularisation parameter
# --------------------------------------------------------
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Here we will demonstrate the effect of selecting a regularisation parameter
# that is either too small or too large.

Expand Down Expand Up @@ -145,8 +145,58 @@
# As we can see when the regularisation parameter is too small we still get a high
# quality fit in the time domain, however, our distance domain data is now way too
# spikey and non-physical.

# In contrast when the regularisation parameter is too large we struggle to get
# a good fit, however, we get a much smoother distance distribution.
# This could have been seen from the selection functional above. The effect of
# lower regularisation parameter had a smaller effect on the functional than the
# effect of going to a larger one.

# Plotting the full L-Curve
# ++++++++++++++++++++++++++
# Normally the selection of the regularisation parameter is done through a Brent optimisation
# algorithm. This results in the L-Curve being sparsely sampled. However, the full L-Curve can be
# generated by evaluating the functional at specified regularisation paramters.

# This has the disadvantage of being computationally expensive, however, it can be useful.

regparamrange = 10**np.linspace(np.log10(1e-6),np.log10(1e1),60) # Builds the range of regularisation parameters
results_grid= dl.fit(Vmodel,Vexp,regparam='bic',regparamrange=regparamrange)
print(results_grid)

fig, axs =plt.subplots(1,3, figsize=(9,4),width_ratios=(1,1,0.1))
alphas = results_grid.regparam_stats['alphas_evaled'][:]
funcs = results_grid.regparam_stats['functional'][:]


idx = np.argsort(alphas)

axs[0].semilogx(alphas[idx], funcs[idx],marker='.',ls='-')
axs[0].set_title(r"$\alpha$ selection functional");
axs[0].set_xlabel("Regularisation Parameter")
axs[0].set_ylabel("Functional Value ")

# Just the final L-Curve
x = results_grid.regparam_stats['residuals']
y = results_grid.regparam_stats['penalties']
idx = np.argsort(x)


axs[1].loglog(x[idx],y[idx])

n_points = results_grid.regparam_stats['alphas_evaled'].shape[-1]
lams = results_grid.regparam_stats['alphas_evaled']
norm = mpl.colors.LogNorm(vmin=lams[:].min(), vmax=lams.max())
for i in range(n_points):
axs[1].plot(x[i], y[i],marker = '.', ms=8, color=cmap(norm(lams[i])))

i_optimal = np.argmin(np.abs(lams - results_grid.regparam))
axs[1].annotate(fr"$\alpha =$ {results_grid.regparam:.2g}", xy = (x[i_optimal],y[i_optimal]),arrowprops=dict(facecolor='black', shrink=0.05, width=5), xytext=(20, 20),textcoords='offset pixels')
fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),cax=axs[2])
axs[1].set_ylabel("Penalties")
axs[2].set_ylabel("Regularisation Parameter")
axs[1].set_xlabel("Residuals")
axs[1].set_title("L-Curve");
fig.tight_layout()

# %%
9 changes: 5 additions & 4 deletions test/test_selregparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ def test_unconstrained(dataset,design_matrix,regularization_matrix):
#=======================================================================
def test_manual_candidates(dataset,design_matrix,regularization_matrix):
"Check that the alpha-search range can be manually passed"
alphas = np.linspace(-8,2,60)
alpha_manual = np.log10(selregparam(dataset,design_matrix,cvxnnls,method='aic',candidates=alphas,regop=regularization_matrix))
alpha_auto = np.log10(selregparam(dataset,design_matrix,cvxnnls,method='aic',regop=regularization_matrix))
assert abs(alpha_manual-alpha_auto)<1e-4
alphas = np.logspace(-8,3,60,base=10)
alpha_manual = np.log10(selregparam(dataset,design_matrix,cvxnnls,method='aic',searchrange=alphas,regop=regularization_matrix))
alpha_auto = np.log10(selregparam(dataset,design_matrix,cvxnnls,method='aic',searchrange=(1e-8,1e-3), regop=regularization_matrix))

assert abs(alpha_manual-alpha_auto)<1
#=======================================================================

#=======================================================================
Expand Down

0 comments on commit 178249e

Please sign in to comment.