diff --git a/CHANGELOG.md b/CHANGELOG.md index 74db4d5b0d..c0cc69188f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,11 @@ + * x.x.x - fix bug in IndicatorBox proximal_conjugate + - allow CCPi Regulariser functions for not CIL object - Add norm for CompositionOperator. - Refactor SIRT algorithm to make it more computationally and memory efficient - Optimisation in L2NormSquared + - Fix for show_geometry bug for 2D data * 23.0.1 - Fix bug with NikonReader requiring ROI to be set in constructor. diff --git a/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/regularisers.py b/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/regularisers.py index 789d95da15..b2a4d50d0e 100644 --- a/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/regularisers.py +++ b/Wrappers/Python/cil/plugins/ccpi_regularisation/functions/regularisers.py @@ -28,6 +28,7 @@ from cil.framework import DataOrder +from cil.framework import DataContainer from cil.optimisation.functions import Function import numpy as np import warnings @@ -269,7 +270,7 @@ def __rmul__(self, scalar): self.alpha *= scalar return self def check_input(self, input): - if input.geometry.length > 3: + if len(input.shape) > 3: raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length)) class TGV(RegulariserFunction): @@ -350,9 +351,9 @@ def __rmul__(self, scalar): # f = alpha * f def check_input(self, input): - if len(input.dimension_labels) == 2: + if len(input.shape) == 2: self.LipshitzConstant = 12 - elif len(input.dimension_labels) == 3: + elif len(input.shape) == 3: self.LipshitzConstant = 16 # Vaggelis to confirm else: raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length)) @@ -430,7 +431,7 @@ def __rmul__(self, scalar): return self def check_input(self, input): - if input.geometry.length > 3: + if len(input.shape) > 3: raise ValueError('{} cannot work on more than 3D. Got {}'.format(self.__class__.__name__, input.geometry.length)) class TNV(RegulariserFunction): @@ -455,9 +456,9 @@ def __call__(self,x): return np.nan def proximal_numpy(self, in_arr, tau): - if in_arr.ndim != 3: - # https://github.com/vais-ral/CCPi-Regularisation-Toolkit/blob/413c6001003c6f1272aeb43152654baaf0c8a423/src/Python/src/cpu_regularisers.pyx#L584-L588 - raise ValueError('Only 3D data is supported. Passed data has {} dimensions'.format(in_arr.ndim)) + # remove any dimension of size 1 + in_arr = np.squeeze(in_arr) + res = regularisers.TNV(in_arr, self.alpha * tau, self.max_iteration, @@ -480,8 +481,14 @@ def __rmul__(self, scalar): def check_input(self, input): '''TNV requires 2D+channel data with the first dimension as the channel dimension''' - DataOrder.check_order_for_engine('cil', input.geometry) - if ( input.geometry.channels == 1 ) or ( not input.geometry.length == 3) : - raise ValueError('TNV requires 2D+channel data. Got {}'.format(input.geometry.dimension_labels)) + if isinstance(input, DataContainer): + DataOrder.check_order_for_engine('cil', input.geometry) + if ( input.geometry.channels == 1 ) or ( not input.geometry.ndim == 3) : + raise ValueError('TNV requires 2D+channel data. Got {}'.format(input.geometry.dimension_labels)) + else: + # if it is not a CIL DataContainer we assume that the data is passed in the correct order + # discard any dimension of size 1 + if sum(1 for i in input.shape if i!=1) != 3: + raise ValueError('TNV requires 3D data (with channel as first axis). Got {}'.format(input.shape)) diff --git a/Wrappers/Python/cil/recon/FBP.py b/Wrappers/Python/cil/recon/FBP.py index ea02d648ed..90fd51bd87 100644 --- a/Wrappers/Python/cil/recon/FBP.py +++ b/Wrappers/Python/cil/recon/FBP.py @@ -514,11 +514,12 @@ def _setup_PO_for_chunks(self, num_slices): self.data_slice = ag_slice.allocate() self.operator = self._PO_class(ig_slice,ag_slice) - def _process_chunk(self, i, step, out): + def _process_chunk(self, i, step): self.data_slice.fill(np.squeeze(self.input.array[:,i:i+step,:])) if not self.filter_inplace: self._pre_filtering(self.data_slice) - out.array[i:i+step,:,:] = self.operator.adjoint(self.data_slice).array[:] + + return self.operator.adjoint(self.data_slice).array def run(self, out=None, verbose=1): @@ -569,15 +570,32 @@ def run(self, out=None, verbose=1): #process dataset by requested chunk size self._setup_PO_for_chunks(self.slices_per_chunk) for i in range(0, tot_slices-remainder, self.slices_per_chunk): - self._process_chunk(i, self.slices_per_chunk, ret) + + if 'bottom' in self.acquisition_geometry.config.panel.origin: + start = i + end = i + self.slices_per_chunk + else: + start = tot_slices -i - self.slices_per_chunk + end = tot_slices - i + + ret.array[start:end,:,:] = self._process_chunk(i, self.slices_per_chunk) + if verbose: pbar.update(1) #process excess rows if remainder: - i = tot_slices-remainder self._setup_PO_for_chunks(remainder) - self._process_chunk(i, remainder, ret) + + if 'bottom' in self.acquisition_geometry.config.panel.origin: + start = tot_slices-remainder + end = tot_slices + else: + start = 0 + end = remainder + + ret.array[start:end,:,:] = self._process_chunk(i, remainder) + if verbose: pbar.update(1) diff --git a/Wrappers/Python/cil/utilities/display.py b/Wrappers/Python/cil/utilities/display.py index 42b21ed24a..ff6f9832f7 100644 --- a/Wrappers/Python/cil/utilities/display.py +++ b/Wrappers/Python/cil/utilities/display.py @@ -845,12 +845,12 @@ def display_detector(self): #mark data origin if 'right' in self.acquisition_geometry.config.panel.origin: - if 'bottom' in self.acquisition_geometry.config.panel.origin: + if self.ndim==2 or 'bottom' in self.acquisition_geometry.config.panel.origin: pix0 = det[0] else: pix0 = det[3] else: - if 'bottom' in self.acquisition_geometry.config.panel.origin: + if self.ndim==2 or 'bottom' in self.acquisition_geometry.config.panel.origin: pix0 = det[1] else: pix0 = det[2] diff --git a/Wrappers/Python/test/test_SIRF.py b/Wrappers/Python/test/test_SIRF.py index eaece897e4..9d8af9bb3e 100644 --- a/Wrappers/Python/test/test_SIRF.py +++ b/Wrappers/Python/test/test_SIRF.py @@ -28,18 +28,31 @@ from cil.optimisation.functions import TotalVariation, L2NormSquared, KullbackLeibler from cil.optimisation.algorithms import FISTA +import os +from cil.plugins.ccpi_regularisation.functions import FGP_TV, TGV, TNV, FGP_dTV +from cil.utilities.display import show2D + from testclass import CCPiTestClass +from utils import has_nvidia, has_ccpi_regularisation, initialise_tests initialise_tests() try: import sirf.STIR as pet import sirf.Gadgetron as mr + import sirf.Reg as reg from sirf.Utilities import examples_data_path + has_sirf = True except ImportError as ie: has_sirf = False +if has_ccpi_regularisation: + from ccpi.filters import regularisers + from cil.plugins.ccpi_regularisation.functions import FGP_TV, TGV, FGP_dTV, TNV + + + class KullbackLeiblerSIRF(object): def setUp(self): @@ -295,7 +308,9 @@ def test_BlockDataContainer_with_SIRF_DataContainer_divide(self): bdc = BlockDataContainer(image1, image2) bdc1 = bdc.divide(1.) - self.assertBlockDataContainerEqual(bdc , bdc1) + # self.assertBlockDataContainerEqual(bdc , bdc1) + np.testing.assert_allclose(bdc.get_item(0).as_array(), bdc1.get_item(0).as_array()) + np.testing.assert_allclose(bdc.get_item(1).as_array(), bdc1.get_item(1).as_array()) @unittest.skipUnless(has_sirf, "Has SIRF") @@ -316,7 +331,9 @@ def test_BlockDataContainer_with_SIRF_DataContainer_multiply(self): bdc = BlockDataContainer(image1, image2) bdc1 = bdc.multiply(1.) - self.assertBlockDataContainerEqual(bdc , bdc1) + # self.assertBlockDataContainerEqual(bdc , bdc1) + np.testing.assert_allclose(bdc.get_item(0).as_array(), bdc1.get_item(0).as_array()) + np.testing.assert_allclose(bdc.get_item(1).as_array(), bdc1.get_item(1).as_array()) @unittest.skipUnless(has_sirf, "Has SIRF") @@ -340,7 +357,9 @@ def test_BlockDataContainer_with_SIRF_DataContainer_add(self): bdc = BlockDataContainer(image1, image2) - self.assertBlockDataContainerEqual(bdc , bdc1) + np.testing.assert_allclose(bdc.get_item(0).as_array(), bdc1.get_item(0).as_array()) + np.testing.assert_allclose(bdc.get_item(1).as_array(), bdc1.get_item(1).as_array()) + # self.assertBlockDataContainerEqual(bdc , bdc1) @unittest.skipUnless(has_sirf, "Has SIRF") @@ -360,10 +379,111 @@ def test_BlockDataContainer_with_SIRF_DataContainer_subtract(self): bdc = BlockDataContainer(image1, image2) - self.assertBlockDataContainerEqual(bdc , bdc1) - - + # self.assertBlockDataContainerEqual(bdc , bdc1) + np.testing.assert_allclose(bdc.get_item(0).as_array(), bdc1.get_item(0).as_array()) + np.testing.assert_allclose(bdc.get_item(1).as_array(), bdc1.get_item(1).as_array()) +class CCPiRegularisationWithSIRFTests(): + + def setUpFGP_TV(self, max_iteration=100, alpha=1.): + return alpha*FGP_TV(max_iteration=max_iteration) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_FGP_TV_call_works(self): + regulariser = self.setUpFGP_TV() + output_number = regulariser(self.image1) + self.assertTrue(True) + # TODO: test the actual value + # expected = 160600016.0 + # np.testing.assert_allclose(output_number, expected, rtol=1e-5) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_FGP_TV_proximal_works(self): + regulariser = self.setUpFGP_TV() + solution = regulariser.proximal(x=self.image1, tau=1) + self.assertTrue(True) + + # TGV + def setUpTGV(self, max_iteration=100, alpha=1.): + return alpha * TGV(max_iteration=max_iteration) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TGV_call_works(self): + regulariser = self.setUpTGV() + output_number = regulariser(self.image1) + self.assertTrue(True) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TGV_proximal_works(self): + regulariser = self.setUpTGV() + solution = regulariser.proximal(x=self.image1, tau=1) + self.assertTrue(True) + + # dTV + def setUpdTV(self, max_iteration=100, alpha=1.): + return alpha * FGP_dTV(reference=self.image2, max_iteration=max_iteration) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TGV_call_works(self): + regulariser = self.setUpTGV() + output_number = regulariser(self.image1) + self.assertTrue(True) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TGV_proximal_works(self): + regulariser = self.setUpTGV() + solution = regulariser.proximal(x=self.image1, tau=1) + self.assertTrue(True) + + # TNV + def setUpTNV(self, max_iteration=100, alpha=1.): + return alpha * TNV(max_iteration=max_iteration) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TNV_call_works(self): + new_shape = [ i for i in self.image1.shape if i!=1] + if len(new_shape) == 3: + regulariser = self.setUpTNV() + output_number = regulariser(self.image1) + self.assertTrue(True) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TNV_proximal_works(self): + new_shape = [ i for i in self.image1.shape if i!=1] + if len(new_shape) == 3: + regulariser = self.setUpTNV() + solution = regulariser.proximal(x=self.image1, tau=1.) + self.assertTrue(True) + +class TestPETRegularisation(unittest.TestCase, CCPiRegularisationWithSIRFTests): + skip_TNV_on_2D = True + def setUp(self): + self.image1 = pet.ImageData(os.path.join( + examples_data_path('PET'),'thorax_single_slice','emission.hv' + )) + self.image2 = self.image1 * 0.5 + + @unittest.skipIf(skip_TNV_on_2D, "TNV not implemented for 2D") + def test_TNV_call_works(self): + super().test_TNV_call_works() + + @unittest.skipIf(skip_TNV_on_2D, "TNV not implemented for 2D") + def test_TNV_proximal_works(self): + super().test_TNV_proximal_works() + +class TestRegRegularisation(unittest.TestCase, CCPiRegularisationWithSIRFTests): + def setUp(self): + self.image1 = reg.ImageData(os.path.join(examples_data_path('Registration'),'test2.nii.gz')) + self.image2 = self.image1 * 0.5 +class TestMRRegularisation(unittest.TestCase, CCPiRegularisationWithSIRFTests): + def setUp(self): + acq_data = mr.AcquisitionData(os.path.join(examples_data_path('MR'),'simulated_MR_2D_cartesian.h5')) + preprocessed_data = mr.preprocess_acquisition_data(acq_data) + recon = mr.FullySampledReconstructor() + recon.set_input(preprocessed_data) + recon.process() + self.image1 = recon.get_output() + self.image2 = self.image1 * 0.5 \ No newline at end of file diff --git a/Wrappers/Python/test/test_reconstructors.py b/Wrappers/Python/test/test_reconstructors.py index d8b0320fab..9d38279007 100644 --- a/Wrappers/Python/test/test_reconstructors.py +++ b/Wrappers/Python/test/test_reconstructors.py @@ -652,7 +652,7 @@ def test_results_3D_astra(self): def test_results_3D_split(self): reconstructor = FBP(self.acq_data) - reconstructor.set_split_processing(1) + reconstructor.set_split_processing(8) reco = reconstructor.run(verbose=0) np.testing.assert_allclose(reco.as_array(), self.img_data.as_array(),atol=1e-3) @@ -663,6 +663,27 @@ def test_results_3D_split(self): np.testing.assert_allclose(reco.as_array(), reco2.as_array(), atol=1e-8) + @unittest.skipUnless(has_tigre and has_nvidia and has_ipp, "TIGRE or IPP not installed") + def test_results_3D_split_reverse(self): + + acq_data = self.acq_data.copy() + acq_data.geometry.config.panel.origin = 'top-left' + + reconstructor = FBP(acq_data) + reconstructor.set_split_processing(8) + + expected_image = np.flip(self.img_data.as_array(),0) + + reco = reconstructor.run(verbose=0) + np.testing.assert_allclose(reco.as_array(), expected_image,atol=1e-3) + + reco2 = reco.copy() + reco2.fill(0) + reconstructor.run(out=reco2, verbose=0) + np.testing.assert_allclose(reco.as_array(), reco2.as_array(), atol=1e-8) + + + @unittest.skipUnless(has_tigre and has_nvidia and has_ipp, "TIGRE or IPP not installed") def test_results_2D_tigre(self):