Skip to content

Commit

Permalink
_proximal_numpy accepts BDC
Browse files Browse the repository at this point in the history
  • Loading branch information
paskino committed Oct 17, 2023
1 parent 3a387ac commit 2d29001
Showing 1 changed file with 26 additions and 19 deletions.
45 changes: 26 additions & 19 deletions Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,37 +79,44 @@ def _proximal_step_numpy(arr, tau):
# DataContainer or BlockDataContainer
tmp = tau.abs()

try:
# if tau is a DataContainer
arr /= tmp
res = arr - 1
res.maximum(0.0, out=res)
res /= arr

arr *= tmp

resarray = res.as_array()
resarray[np.isnan(resarray)] = 0
res.fill(resarray)
return res
except AttributeError:
if isinstance(tau, BlockDataContainer):
# if tau is a BlockDataContainer
# This is the old 21.3.1 CIL code for the Proximal of MixedL21Norm
# https://github.com/TomographicImaging/CIL/blob/6edf48aee6e1dcf81a933d4d57dea661509fcba3/Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py#L92-L104

tmp = (arr/tau).pnorm(2)
res = (tmp - 1)
res.maximum(0.0, out=res)
res.multiply(arr, out=res)
res = arr * res
res.divide(tmp, out=res)

for el in res.containers:
_remove_nans_bdc(res, 0)

elarray = el.as_array()
elarray[np.isnan(elarray)]=0
el.fill(elarray)
return res
else:
# if tau is a DataContainer
arr /= tmp
res = arr - 1
res.maximum(0.0, out=res)
res /= arr

arr *= tmp

_remove_nans(res, 0)
return res

def _remove_nans(res, value):
resarray = res.as_array()
resarray[np.isnan(resarray)] = value
res.fill(resarray)
# return res
def _remove_nans_bdc(bdc, value):
for el in bdc.containers:
if isinstance(el, BlockDataContainer):
# recursive call
_remove_nans_bdc(el, value)
else:
_remove_nans(el, value)

class MixedL21Norm(Function):

Expand Down

0 comments on commit 2d29001

Please sign in to comment.