
Tv warmstart #1493

Merged · 50 commits · Sep 18, 2023
Changes from 2 commits
8fa44c3
TV warm_start file
MargaretDuff Aug 3, 2023
7b73566
Naming
MargaretDuff Aug 3, 2023
d39387b
Removed hasstarted and replaced with a property and setter for p2
MargaretDuff Aug 7, 2023
647dd41
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Aug 8, 2023
76c9d0f
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Aug 8, 2023
7d539fe
Updated docstrings
MargaretDuff Aug 8, 2023
bc7ae27
Testing in the p2 setter
MargaretDuff Aug 8, 2023
2d6a48d
Try again with the test on the setter
MargaretDuff Aug 8, 2023
8698a5a
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Aug 9, 2023
85f9a54
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Aug 9, 2023
c63355b
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Aug 9, 2023
0a46d9a
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Aug 9, 2023
ee6895c
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Aug 9, 2023
af5f9f8
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Aug 9, 2023
b494a98
Removing comments referencing lines in the algorithm
MargaretDuff Aug 9, 2023
8643abd
Changelog and developer list
MargaretDuff Aug 9, 2023
f422900
Changed commenting to clear up ROF/FGP confusion
MargaretDuff Aug 9, 2023
579643d
Changes to ohow p1 is multiplied by multip - Edo's comments
MargaretDuff Aug 9, 2023
6ac35bc
fix for tau as image
paskino Aug 10, 2023
9ef73b6
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Aug 11, 2023
ee6ffd9
Changed defaults to True
MargaretDuff Aug 11, 2023
9e9f727
Merge branch 'tv_warmstart' of github.com:MargaretDuff/CIL-margaret i…
MargaretDuff Aug 11, 2023
624aa43
Default max iterations =10 and started unit tests
MargaretDuff Aug 14, 2023
e882e2d
Added test to check default p2 value
MargaretDuff Aug 14, 2023
70de544
Neatening block function tests
MargaretDuff Aug 14, 2023
72350c5
Changed tolerances on testing
MargaretDuff Aug 14, 2023
24b2e4e
Changes to unittests after discussion with Gemma
MargaretDuff Aug 17, 2023
be508d6
Spelling
MargaretDuff Aug 17, 2023
126545b
TV warmstart unittests
MargaretDuff Aug 17, 2023
314107d
Merge branch 'master' into tv_warmstart
MargaretDuff Aug 17, 2023
34de0d3
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 13, 2023
df50ae3
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 13, 2023
e7bc6ed
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 13, 2023
9abaec9
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 13, 2023
fa98b9b
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 13, 2023
2a0027b
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 13, 2023
cb126b0
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 13, 2023
6499a43
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 13, 2023
995239e
Updated documentation
MargaretDuff Sep 13, 2023
c1fec46
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 14, 2023
f4eba6a
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 14, 2023
3109a9f
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 14, 2023
6dc1da2
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 14, 2023
8a43601
Update Wrappers/Python/test/test_functions.py
MargaretDuff Sep 14, 2023
5e4edee
Update Wrappers/Python/test/test_functions.py
MargaretDuff Sep 14, 2023
244586b
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 14, 2023
6f22d38
Update Wrappers/Python/test/test_functions.py
MargaretDuff Sep 14, 2023
e6dd3de
Update Wrappers/Python/cil/optimisation/functions/TotalVariation.py
MargaretDuff Sep 14, 2023
3414073
Changes requested by Edo
MargaretDuff Sep 14, 2023
0b48716
Changed warmstart to warm_start in test functions
MargaretDuff Sep 14, 2023
74 changes: 47 additions & 27 deletions Wrappers/Python/cil/optimisation/functions/TotalVariation.py
@@ -18,6 +18,7 @@
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
# Claire Delplancke (University of Bath)


from cil.optimisation.functions import Function, IndicatorBox, MixedL21Norm, MixedL11Norm
from cil.optimisation.operators import GradientOperator
import numpy as np
@@ -91,6 +92,11 @@ class TotalVariation(Function):

.. math:: \underset{u}{\mathrm{argmin}} \frac{1}{2\frac{\tau}{1+\gamma\tau}}\|u - \frac{b}{1+\gamma\tau}\|^{2} + \mathrm{TV}(u)

warmstart : :obj:`boolean`, default = False
If set to True, the FGP algorithm used to calculate the TV proximal is initialised with the final value from the previous iteration rather than with zero.
This typically allows the max_iteration value to be reduced to 5-10 iterations.

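The warm-start option described above amounts to caching the dual variable between proximal calls. A minimal sketch of that state pattern, with hypothetical names and a placeholder inner update standing in for the real FGP iterations (this is not the CIL implementation):

```python
class WarmStartProx:
    """Sketch: keep the dual variable from the previous proximal call
    and reuse it as the starting point of the next call, instead of
    always restarting from zero."""

    def __init__(self, warmstart=False):
        self.warmstart = warmstart
        self._p2 = None                        # cached dual variable

    def initial_dual(self, size):
        if self.warmstart and self._p2 is not None:
            return list(self._p2)              # resume from previous call
        return [0.0] * size                    # cold start at zero

    def proximal(self, x):
        p2 = self.initial_dual(len(x))
        # placeholder inner update standing in for the FGP iterations
        p2 = [0.5 * (p + xi) for p, xi in zip(p2, x)]
        self._p2 = p2
        return p2
```

With `warmstart=False` every call restarts from zero; with `warmstart=True` each call resumes from the cached dual, which is why fewer inner iterations are needed.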

Note
----

@@ -111,7 +117,7 @@ class TotalVariation(Function):

>>> alpha = 2.0
>>> TV = TotalVariation()
>>> sol = TV.proxima(b, tau = alpha)
>>> sol = TV.proximal(b, tau = alpha)

Examples
--------
@@ -122,7 +128,7 @@ class TotalVariation(Function):

>>> alpha = 2.0
>>> TV = TotalVariation(isotropic=False, lower=1.0, upper=2.0)
>>> sol = TV.proxima(b, tau = alpha)
>>> sol = TV.proximal(b, tau = alpha)


Examples
@@ -133,7 +139,7 @@ class TotalVariation(Function):
>>> alpha = 2.0
>>> gamma = 1e-3
>>> TV = alpha * TotalVariation(isotropic=False, strong_convexity_constant=gamma)
>>> sol = TV.proxima(b, tau = 1.0)
>>> sol = TV.proximal(b, tau = 1.0)

"""

@@ -148,7 +154,8 @@ def __init__(self,
isotropic = True,
split = False,
info = False,
strong_convexity_constant = 0):
strong_convexity_constant = 0,
warmstart=False):


super(TotalVariation, self).__init__(L = None)
@@ -190,6 +197,11 @@ def __init__(self,
# splitting Gradient
self.split = split

# warm-start
self.warmstart = warmstart
if self.warmstart:
self.hasstarted = False

# Strong convexity for TV
self.strong_convexity_constant = strong_convexity_constant

@@ -234,7 +246,7 @@ def __call__(self, x):
def proximal(self, x, tau, out = None):

r""" Returns the proximal operator of the TotalVariation function at :code:`x` ."""

self.tau = tau # introduced for testing

if self.strong_convexity_constant>0:

@@ -274,11 +286,19 @@ def _fista_on_dual_rof(self, x, tau, out = None):
self.calculate_Lipschitz()

# initialise
t = 1
t = 1 # line 2

p1 = self.gradient.range_geometry().allocate(0) # dual variable - the value allocated here is not used - it is overwritten during the iterations
if not self.warmstart:
self.p2 = self.gradient.range_geometry().allocate(0) # previous dual variable - needs loading for warm start!
tmp_q = self.gradient.range_geometry().allocate(0) # should be equal to self.p2 - i.e. needs loading for warm start
else:
if not self.hasstarted:
self.p2 = self.gradient.range_geometry().allocate(0) # previous dual variable - needs loading for warm start!
self.hasstarted = True
tmp_q = self.p2.copy() # should be equal to self.p2 - i.e. needs loading for warm start


p1 = self.gradient.range_geometry().allocate(0)
p2 = self.gradient.range_geometry().allocate(0)
tmp_q = self.gradient.range_geometry().allocate(0)

# multiply tau by -1 * regularisation_parameter here so it's not recomputed every iteration
# when tau is an array this is done inplace so reverted at the end
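The in-place scaling described in these comments can be sketched as follows, with hypothetical helper names (the real code inlines this logic): `-reg` is folded into `tau` once so the product is not recomputed every iteration, and when `tau` is an array (an image) the scaling mutates it in place, so it must be reverted before returning.

```python
import numpy as np

def scale_tau(tau, reg):
    """Fold -reg into tau once, up front. A scalar tau yields a new
    value; an array tau is scaled in place (no copy of a full image)
    and must be reverted by the caller before returning."""
    if isinstance(tau, np.ndarray):
        tau *= -reg            # in place: same object mutated
        return tau
    return -reg * tau          # new scalar

def revert_tau(tau, reg):
    """Undo the in-place scaling so the caller's tau is unchanged."""
    if isinstance(tau, np.ndarray):
        tau /= -reg
    return tau
```

The design choice trades a small amount of bookkeeping (the revert step) for avoiding both a per-iteration multiply and an extra image-sized allocation.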
@@ -294,43 +314,43 @@ def _fista_on_dual_rof(self, x, tau, out = None):
out = self.gradient.domain_geometry().allocate(0)

should_break = False
for k in range(self.iterations):
for k in range(self.iterations): # line 3 in Algorithm 1 of "Multicontrast MRI Reconstruction with Structure-Guided Total Variation", Ehrhardt & Betcke, 2016.

t0 = t
self.gradient.adjoint(tmp_q, out = out)
out.sapyb(tau_reg_neg, x, 1.0, out=out)
self.projection_C(out, tau=None, out = out)
self.gradient.adjoint(tmp_q, out = out) # line 4
out.sapyb(tau_reg_neg, x, 1.0, out=out) # line 4
self.projection_C(out, tau=None, out = out) # line 4

if should_break:
break

self.gradient.direct(out, out=p1)
self.gradient.direct(out, out=p1) # line 4

multip = (-self.L)/tau_reg_neg
p1.multiply(multip,out=p1)
multip = (-self.L)/tau_reg_neg # line 4
p1.multiply(multip,out=p1) # line 4/5

tmp_q += p1
tmp_q += p1 # line 5

if self.tolerance is not None and k%5==0:
if self.tolerance is not None and k%5==0: # testing convergence criterion
error = p1.norm()
error /= tmp_q.norm()
if error <= self.tolerance:
should_break = True

# Depending on the case, isotropic or anisotropic, the proximal conjugate of the MixedL21Norm (isotropic case),
# or the proximal conjugate of the MixedL11Norm (anisotropic case) is computed.
self.func.proximal_conjugate(tmp_q, 1.0, out=p1)
self.func.proximal_conjugate(tmp_q, 1.0, out=p1) # line 5

t = (1 + np.sqrt(1 + 4 * t0 ** 2)) / 2
t = (1 + np.sqrt(1 + 4 * t0 ** 2)) / 2 # line 6

p1.subtract(p2, out=tmp_q)
tmp_q *= (t0-1)/t
tmp_q += p1
p1.subtract(self.p2, out=tmp_q) # line 7
tmp_q *= (t0-1)/t # line 7
tmp_q += p1 # line 7

#switch p1 and p2 references
#switch p1 and self.p2 references
tmp = p1
p1 = p2
p2 = tmp
p1 = self.p2
self.p2 = tmp
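The "line 6" and "line 7" steps annotated above are the standard FISTA momentum update. A sketch under the assumption that `p1` is the current dual variable and `p2` the previous one (hypothetical function name; the real code works in place on CIL containers):

```python
import numpy as np

def fista_momentum(p1, p2, t0):
    """Update the step parameter t (line 6), then extrapolate the new
    dual variable p1 against the previous one p2 (line 7)."""
    t = (1 + np.sqrt(1 + 4 * t0 ** 2)) / 2     # line 6
    tmp_q = p1 + ((t0 - 1) / t) * (p1 - p2)    # line 7
    return tmp_q, t
```

On the first pass t0 = 1, so the extrapolation weight (t0 - 1)/t is zero and tmp_q equals p1; the momentum only kicks in from the second iteration onwards.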

# Print stopping information (iterations and tolerance error) of FGP_TV
if self.info:
@@ -375,4 +395,4 @@ def __rmul__(self, scalar):
if not isinstance (scalar, Number):
raise TypeError("scalar: Expected a number, got {}".format(type(scalar)))
self.regularisation_parameter *= scalar
return self
return self