From 2d290013bdf5b64c7b4e7a18506ffb63126dde53 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 17 Oct 2023 16:25:06 +0100 Subject: [PATCH] _proximal_numpy accepts BDC --- .../optimisation/functions/MixedL21Norm.py | 45 +++++++++++-------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py b/Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py index b9e4f2208..3c79632bf 100644 --- a/Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py +++ b/Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py @@ -79,20 +79,7 @@ def _proximal_step_numpy(arr, tau): # DataContainer or BlockDataContainer tmp = tau.abs() - try: - # if tau is a DataContainer - arr /= tmp - res = arr - 1 - res.maximum(0.0, out=res) - res /= arr - - arr *= tmp - - resarray = res.as_array() - resarray[np.isnan(resarray)] = 0 - res.fill(resarray) - return res - except AttributeError: + if isinstance(tau, BlockDataContainer): # if tau is a BlockDataContainer # This is the old 21.3.1 CIL code for the Proximal of MixedL21Norm # https://github.com/TomographicImaging/CIL/blob/6edf48aee6e1dcf81a933d4d57dea661509fcba3/Wrappers/Python/cil/optimisation/functions/MixedL21Norm.py#L92-L104 @@ -100,16 +87,36 @@ def _proximal_step_numpy(arr, tau): tmp = (arr/tau).pnorm(2) res = (tmp - 1) res.maximum(0.0, out=res) - res.multiply(arr, out=res) + res = arr * res res.divide(tmp, out=res) - for el in res.containers: + _remove_nans_bdc(res, 0) - elarray = el.as_array() - elarray[np.isnan(elarray)]=0 - el.fill(elarray) + return res + else: + # if tau is a DataContainer + arr /= tmp + res = arr - 1 + res.maximum(0.0, out=res) + res /= arr + + arr *= tmp + _remove_nans(res, 0) return res + +def _remove_nans(res, value): + resarray = res.as_array() + resarray[np.isnan(resarray)] = value + res.fill(resarray) + # return res +def _remove_nans_bdc(bdc, value): + for el in bdc.containers: + if isinstance(el, BlockDataContainer): + # recursive call + _remove_nans_bdc(el, value) + else: + _remove_nans(el, value) class MixedL21Norm(Function):