Skip to content

Commit

Permalink
Merge branch 'master' into fix_IndicatorBox
Browse files Browse the repository at this point in the history
Signed-off-by: Edoardo Pasca <edo.paskino@gmail.com>
  • Loading branch information
paskino authored Aug 2, 2023
2 parents 756c609 + e7e137b commit 9cdf620
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 24 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@

* x.x.x
- fix bug in IndicatorBox proximal_conjugate
- allow CCPi Regulariser functions for not CIL object
- Add norm for CompositionOperator.
- Refactor SIRT algorithm to make it more computationally and memory efficient
- Optimisation in L2NormSquared
- Fix for show_geometry bug for 2D data

* 23.0.1
- Fix bug with NikonReader requiring ROI to be set in constructor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@


from cil.framework import DataOrder
from cil.framework import DataContainer
from cil.optimisation.functions import Function
import numpy as np
import warnings
Expand Down Expand Up @@ -269,7 +270,7 @@ def __rmul__(self, scalar):
self.alpha *= scalar
return self
def check_input(self, input):
if input.geometry.length > 3:
if len(input.shape) > 3:
raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length))

class TGV(RegulariserFunction):
Expand Down Expand Up @@ -350,9 +351,9 @@ def __rmul__(self, scalar):
# f = alpha * f

def check_input(self, input):
if len(input.dimension_labels) == 2:
if len(input.shape) == 2:
self.LipshitzConstant = 12
elif len(input.dimension_labels) == 3:
elif len(input.shape) == 3:
self.LipshitzConstant = 16 # Vaggelis to confirm
else:
raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length))
Expand Down Expand Up @@ -430,7 +431,7 @@ def __rmul__(self, scalar):
return self

def check_input(self, input):
if input.geometry.length > 3:
if len(input.shape) > 3:
raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length))

class TNV(RegulariserFunction):
Expand All @@ -455,9 +456,9 @@ def __call__(self,x):
return np.nan

def proximal_numpy(self, in_arr, tau):
if in_arr.ndim != 3:
# https://github.com/vais-ral/CCPi-Regularisation-Toolkit/blob/413c6001003c6f1272aeb43152654baaf0c8a423/src/Python/src/cpu_regularisers.pyx#L584-L588
raise ValueError('Only 3D data is supported. Passed data has {} dimensions'.format(in_arr.ndim))
# remove any dimension of size 1
in_arr = np.squeeze(in_arr)

res = regularisers.TNV(in_arr,
self.alpha * tau,
self.max_iteration,
Expand All @@ -480,8 +481,14 @@ def __rmul__(self, scalar):

def check_input(self, input):
'''TNV requires 2D+channel data with the first dimension as the channel dimension'''
DataOrder.check_order_for_engine('cil', input.geometry)
if ( input.geometry.channels == 1 ) or ( not input.geometry.length == 3) :
raise ValueError('TNV requires 2D+channel data. Got {}'.format(input.geometry.dimension_labels))
if isinstance(input, DataContainer):
DataOrder.check_order_for_engine('cil', input.geometry)
if ( input.geometry.channels == 1 ) or ( not input.geometry.ndim == 3) :
raise ValueError('TNV requires 2D+channel data. Got {}'.format(input.geometry.dimension_labels))
else:
# if it is not a CIL DataContainer we assume that the data is passed in the correct order
# discard any dimension of size 1
if sum(1 for i in input.shape if i!=1) != 3:
raise ValueError('TNV requires 3D data (with channel as first axis). Got {}'.format(input.shape))


28 changes: 23 additions & 5 deletions Wrappers/Python/cil/recon/FBP.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,11 +514,12 @@ def _setup_PO_for_chunks(self, num_slices):
self.data_slice = ag_slice.allocate()
self.operator = self._PO_class(ig_slice,ag_slice)

def _process_chunk(self, i, step, out):
def _process_chunk(self, i, step):
self.data_slice.fill(np.squeeze(self.input.array[:,i:i+step,:]))
if not self.filter_inplace:
self._pre_filtering(self.data_slice)
out.array[i:i+step,:,:] = self.operator.adjoint(self.data_slice).array[:]

return self.operator.adjoint(self.data_slice).array


def run(self, out=None, verbose=1):
Expand Down Expand Up @@ -569,15 +570,32 @@ def run(self, out=None, verbose=1):
#process dataset by requested chunk size
self._setup_PO_for_chunks(self.slices_per_chunk)
for i in range(0, tot_slices-remainder, self.slices_per_chunk):
self._process_chunk(i, self.slices_per_chunk, ret)

if 'bottom' in self.acquisition_geometry.config.panel.origin:
start = i
end = i + self.slices_per_chunk
else:
start = tot_slices -i - self.slices_per_chunk
end = tot_slices - i

ret.array[start:end,:,:] = self._process_chunk(i, self.slices_per_chunk)

if verbose:
pbar.update(1)

#process excess rows
if remainder:
i = tot_slices-remainder
self._setup_PO_for_chunks(remainder)
self._process_chunk(i, remainder, ret)

if 'bottom' in self.acquisition_geometry.config.panel.origin:
start = tot_slices-remainder
end = tot_slices
else:
start = 0
end = remainder

ret.array[start:end,:,:] = self._process_chunk(i, remainder)

if verbose:
pbar.update(1)

Expand Down
4 changes: 2 additions & 2 deletions Wrappers/Python/cil/utilities/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,12 +845,12 @@ def display_detector(self):

#mark data origin
if 'right' in self.acquisition_geometry.config.panel.origin:
if 'bottom' in self.acquisition_geometry.config.panel.origin:
if self.ndim==2 or 'bottom' in self.acquisition_geometry.config.panel.origin:
pix0 = det[0]
else:
pix0 = det[3]
else:
if 'bottom' in self.acquisition_geometry.config.panel.origin:
if self.ndim==2 or 'bottom' in self.acquisition_geometry.config.panel.origin:
pix0 = det[1]
else:
pix0 = det[2]
Expand Down
132 changes: 126 additions & 6 deletions Wrappers/Python/test/test_SIRF.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,31 @@
from cil.optimisation.functions import TotalVariation, L2NormSquared, KullbackLeibler
from cil.optimisation.algorithms import FISTA

import os
from cil.plugins.ccpi_regularisation.functions import FGP_TV, TGV, TNV, FGP_dTV
from cil.utilities.display import show2D

from testclass import CCPiTestClass
from utils import has_nvidia, has_ccpi_regularisation, initialise_tests

initialise_tests()

try:
import sirf.STIR as pet
import sirf.Gadgetron as mr
import sirf.Reg as reg
from sirf.Utilities import examples_data_path

has_sirf = True
except ImportError as ie:
has_sirf = False

if has_ccpi_regularisation:
from ccpi.filters import regularisers
from cil.plugins.ccpi_regularisation.functions import FGP_TV, TGV, FGP_dTV, TNV



class KullbackLeiblerSIRF(object):

def setUp(self):
Expand Down Expand Up @@ -295,7 +308,9 @@ def test_BlockDataContainer_with_SIRF_DataContainer_divide(self):
bdc = BlockDataContainer(image1, image2)
bdc1 = bdc.divide(1.)

self.assertBlockDataContainerEqual(bdc , bdc1)
# self.assertBlockDataContainerEqual(bdc , bdc1)
np.testing.assert_allclose(bdc.get_item(0).as_array(), bdc1.get_item(0).as_array())
np.testing.assert_allclose(bdc.get_item(1).as_array(), bdc1.get_item(1).as_array())


@unittest.skipUnless(has_sirf, "Has SIRF")
Expand All @@ -316,7 +331,9 @@ def test_BlockDataContainer_with_SIRF_DataContainer_multiply(self):
bdc = BlockDataContainer(image1, image2)
bdc1 = bdc.multiply(1.)

self.assertBlockDataContainerEqual(bdc , bdc1)
# self.assertBlockDataContainerEqual(bdc , bdc1)
np.testing.assert_allclose(bdc.get_item(0).as_array(), bdc1.get_item(0).as_array())
np.testing.assert_allclose(bdc.get_item(1).as_array(), bdc1.get_item(1).as_array())


@unittest.skipUnless(has_sirf, "Has SIRF")
Expand All @@ -340,7 +357,9 @@ def test_BlockDataContainer_with_SIRF_DataContainer_add(self):

bdc = BlockDataContainer(image1, image2)

self.assertBlockDataContainerEqual(bdc , bdc1)
np.testing.assert_allclose(bdc.get_item(0).as_array(), bdc1.get_item(0).as_array())
np.testing.assert_allclose(bdc.get_item(1).as_array(), bdc1.get_item(1).as_array())
# self.assertBlockDataContainerEqual(bdc , bdc1)


@unittest.skipUnless(has_sirf, "Has SIRF")
Expand All @@ -360,10 +379,111 @@ def test_BlockDataContainer_with_SIRF_DataContainer_subtract(self):

bdc = BlockDataContainer(image1, image2)

self.assertBlockDataContainerEqual(bdc , bdc1)


# self.assertBlockDataContainerEqual(bdc , bdc1)
np.testing.assert_allclose(bdc.get_item(0).as_array(), bdc1.get_item(0).as_array())
np.testing.assert_allclose(bdc.get_item(1).as_array(), bdc1.get_item(1).as_array())



class CCPiRegularisationWithSIRFTests():

def setUpFGP_TV(self, max_iteration=100, alpha=1.):
return alpha*FGP_TV(max_iteration=max_iteration)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_FGP_TV_call_works(self):
regulariser = self.setUpFGP_TV()
output_number = regulariser(self.image1)
self.assertTrue(True)
# TODO: test the actual value
# expected = 160600016.0
# np.testing.assert_allclose(output_number, expected, rtol=1e-5)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_FGP_TV_proximal_works(self):
regulariser = self.setUpFGP_TV()
solution = regulariser.proximal(x=self.image1, tau=1)
self.assertTrue(True)

# TGV
def setUpTGV(self, max_iteration=100, alpha=1.):
return alpha * TGV(max_iteration=max_iteration)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TGV_call_works(self):
regulariser = self.setUpTGV()
output_number = regulariser(self.image1)
self.assertTrue(True)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TGV_proximal_works(self):
regulariser = self.setUpTGV()
solution = regulariser.proximal(x=self.image1, tau=1)
self.assertTrue(True)

# dTV
def setUpdTV(self, max_iteration=100, alpha=1.):
return alpha * FGP_dTV(reference=self.image2, max_iteration=max_iteration)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TGV_call_works(self):
regulariser = self.setUpTGV()
output_number = regulariser(self.image1)
self.assertTrue(True)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TGV_proximal_works(self):
regulariser = self.setUpTGV()
solution = regulariser.proximal(x=self.image1, tau=1)
self.assertTrue(True)

# TNV
def setUpTNV(self, max_iteration=100, alpha=1.):
return alpha * TNV(max_iteration=max_iteration)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TNV_call_works(self):
new_shape = [ i for i in self.image1.shape if i!=1]
if len(new_shape) == 3:
regulariser = self.setUpTNV()
output_number = regulariser(self.image1)
self.assertTrue(True)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TNV_proximal_works(self):
new_shape = [ i for i in self.image1.shape if i!=1]
if len(new_shape) == 3:
regulariser = self.setUpTNV()
solution = regulariser.proximal(x=self.image1, tau=1.)
self.assertTrue(True)

class TestPETRegularisation(unittest.TestCase, CCPiRegularisationWithSIRFTests):
skip_TNV_on_2D = True
def setUp(self):
self.image1 = pet.ImageData(os.path.join(
examples_data_path('PET'),'thorax_single_slice','emission.hv'
))
self.image2 = self.image1 * 0.5

@unittest.skipIf(skip_TNV_on_2D, "TNV not implemented for 2D")
def test_TNV_call_works(self):
super().test_TNV_call_works()

@unittest.skipIf(skip_TNV_on_2D, "TNV not implemented for 2D")
def test_TNV_proximal_works(self):
super().test_TNV_proximal_works()

class TestRegRegularisation(unittest.TestCase, CCPiRegularisationWithSIRFTests):
def setUp(self):
self.image1 = reg.ImageData(os.path.join(examples_data_path('Registration'),'test2.nii.gz'))
self.image2 = self.image1 * 0.5

class TestMRRegularisation(unittest.TestCase, CCPiRegularisationWithSIRFTests):
def setUp(self):
acq_data = mr.AcquisitionData(os.path.join(examples_data_path('MR'),'simulated_MR_2D_cartesian.h5'))
preprocessed_data = mr.preprocess_acquisition_data(acq_data)
recon = mr.FullySampledReconstructor()
recon.set_input(preprocessed_data)
recon.process()
self.image1 = recon.get_output()
self.image2 = self.image1 * 0.5
23 changes: 22 additions & 1 deletion Wrappers/Python/test/test_reconstructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def test_results_3D_astra(self):
def test_results_3D_split(self):

reconstructor = FBP(self.acq_data)
reconstructor.set_split_processing(1)
reconstructor.set_split_processing(8)

reco = reconstructor.run(verbose=0)
np.testing.assert_allclose(reco.as_array(), self.img_data.as_array(),atol=1e-3)
Expand All @@ -663,6 +663,27 @@ def test_results_3D_split(self):
np.testing.assert_allclose(reco.as_array(), reco2.as_array(), atol=1e-8)


@unittest.skipUnless(has_tigre and has_nvidia and has_ipp, "TIGRE or IPP not installed")
def test_results_3D_split_reverse(self):

acq_data = self.acq_data.copy()
acq_data.geometry.config.panel.origin = 'top-left'

reconstructor = FBP(acq_data)
reconstructor.set_split_processing(8)

expected_image = np.flip(self.img_data.as_array(),0)

reco = reconstructor.run(verbose=0)
np.testing.assert_allclose(reco.as_array(), expected_image,atol=1e-3)

reco2 = reco.copy()
reco2.fill(0)
reconstructor.run(out=reco2, verbose=0)
np.testing.assert_allclose(reco.as_array(), reco2.as_array(), atol=1e-8)



@unittest.skipUnless(has_tigre and has_nvidia and has_ipp, "TIGRE or IPP not installed")
def test_results_2D_tigre(self):

Expand Down

0 comments on commit 9cdf620

Please sign in to comment.