From 976dd269e66e1a5f992946d37c6c816eeadcdadb Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Wed, 13 Mar 2024 16:29:20 +0000 Subject: [PATCH] more OOP --- Wrappers/Python/cil/utilities/dataexample.py | 536 ++++++++----------- Wrappers/Python/test/test_dataexample.py | 2 +- 2 files changed, 213 insertions(+), 325 deletions(-) diff --git a/Wrappers/Python/cil/utilities/dataexample.py b/Wrappers/Python/cil/utilities/dataexample.py index 18f3374a97..30f0b7c85f 100644 --- a/Wrappers/Python/cil/utilities/dataexample.py +++ b/Wrappers/Python/cil/utilities/dataexample.py @@ -16,7 +16,7 @@ # Authors: # CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt -from cil.framework import ImageData, ImageGeometry, DataContainer +from cil.framework import ImageGeometry import numpy import numpy as np from PIL import Image @@ -26,301 +26,38 @@ from zipfile import ZipFile from urllib.request import urlopen from io import BytesIO -from scipy.io import loadmat from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader +from abc import ABC, abstractmethod -class DATA(object): - @classmethod - def dfile(cls): - return None - -class CILDATA(DATA): - data_dir = os.path.abspath(os.path.join(sys.prefix, 'share','cil')) - @classmethod - def get(cls, size=None, scale=(0,1), **kwargs): - ddir = kwargs.get('data_dir', CILDATA.data_dir) - loader = TestData(data_dir=ddir) - return loader.load(cls.dfile(), size, scale, **kwargs) - -class REMOTEDATA(DATA): - - FOLDER = '' - URL = '' - FILE_SIZE = '' - - @classmethod - def get(cls, data_dir): - return None - - @classmethod - def _download_and_extract_from_url(cls, data_dir): - with urlopen(cls.URL) as response: - with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: - zipfile.extractall(path = data_dir) - - @classmethod - def download_data(cls, data_dir): - ''' - Download a dataset from a remote repository - - Parameters - ---------- - data_dir: str, optional - The path to the data directory where the downloaded data should be stored - - ''' - if os.path.isdir(os.path.join(data_dir, cls.FOLDER)): - print("Dataset already exists in " + data_dir) - else: - if input("Are you sure you want to download " + cls.FILE_SIZE + " dataset from " + cls.URL + " ? (y/n)") == "y": - print('Downloading dataset from ' + cls.URL) - cls._download_and_extract_from_url(os.path.join(data_dir,cls.FOLDER)) - print('Download complete') - else: - print('Download cancelled') - -class BOAT(CILDATA): - @classmethod - def dfile(cls): - return TestData.BOAT -class CAMERA(CILDATA): - @classmethod - def dfile(cls): - return TestData.CAMERA -class PEPPERS(CILDATA): - @classmethod - def dfile(cls): - return TestData.PEPPERS -class RESOLUTION_CHART(CILDATA): - @classmethod - def dfile(cls): - return TestData.RESOLUTION_CHART -class SIMPLE_PHANTOM_2D(CILDATA): - @classmethod - def dfile(cls): - return TestData.SIMPLE_PHANTOM_2D -class SHAPES(CILDATA): - @classmethod - def dfile(cls): - return TestData.SHAPES -class RAINBOW(CILDATA): - @classmethod - def dfile(cls): - return TestData.RAINBOW -class SYNCHROTRON_PARALLEL_BEAM_DATA(CILDATA): - @classmethod - def get(cls, **kwargs): - ''' - A DLS dataset - - Parameters - ---------- - data_dir: str, optional - The path to the data directory - - Returns - ------- - AcquisitionData - The DLS dataset - ''' - - ddir = kwargs.get('data_dir', CILDATA.data_dir) - loader = NEXUSDataReader() - loader.set_up(file_name=os.path.join(os.path.abspath(ddir), '24737_fd_normalised.nxs')) - return loader.read() -class SIMULATED_PARALLEL_BEAM_DATA(CILDATA): - @classmethod - def get(cls, **kwargs): - ''' - A simulated parallel-beam dataset generated from SIMULATED_SPHERE_VOLUME - - Parameters - ---------- - data_dir: str, optional - The path to the data directory - - Returns - ------- - AcquisitionData - The simulated spheres dataset - ''' - - ddir = kwargs.get('data_dir', CILDATA.data_dir) - loader = NEXUSDataReader() - loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_parallel_beam.nxs')) - return loader.read() -class SIMULATED_CONE_BEAM_DATA(CILDATA): - @classmethod - def get(cls, **kwargs): - ''' - A cone-beam dataset generated from SIMULATED_SPHERE_VOLUME - - Parameters - ---------- - data_dir: str, optional - The path to the data directory - - Returns - ------- - AcquisitionData - The simulated spheres dataset - ''' - - ddir = kwargs.get('data_dir', CILDATA.data_dir) - loader = NEXUSDataReader() - loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_cone_beam.nxs')) - return loader.read() -class SIMULATED_SPHERE_VOLUME(CILDATA): - @classmethod - def get(cls, **kwargs): - ''' - A simulated volume of spheres - - Parameters - ---------- - data_dir: str, optional - The path to the data directory - - Returns - ------- - ImageData - The simulated spheres volume - ''' - ddir = kwargs.get('data_dir', CILDATA.data_dir) - loader = NEXUSDataReader() - loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_volume.nxs')) - return loader.read() - -class WALNUT(REMOTEDATA): - ''' - A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 - ''' - FOLDER = 'walnut' - URL = 'https://zenodo.org/record/4822516/files/walnut.zip' - FILE_SIZE = '6.4 GB' - - @classmethod - def get(cls, data_dir): - ''' - A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 - This function returns the raw projection data from the .txrm file - - Parameters - ---------- - data_dir: str - The path to the directory where the dataset is stored. Data can be downloaded with dataexample.WALNUT.download_data(data_dir) - - Returns - ------- - ImageData - The walnut dataset - ''' - filepath = os.path.join(data_dir, cls.FOLDER, 'valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') - try: - loader = ZEISSDataReader(file_name=filepath) - return loader.read() - except(FileNotFoundError): - raise(FileNotFoundError("Dataset .txrm file not found in specifed data_dir: {} \n \ - Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) +DEFAULT_DATA_DIR = os.path.abspath(os.path.join(sys.prefix, 'share', 'cil')) -class USB(REMOTEDATA): - ''' - A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 - ''' - FOLDER = 'USB' - URL = 'https://zenodo.org/record/4822516/files/usb.zip' - FILE_SIZE = '3.2 GB' +class TestData: + '''Provides 6 datasets: - @classmethod - def get(cls, data_dir): - ''' - A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 - This function returns the raw projection data from the .txrm file - - Parameters - ---------- - data_dir: str - The path to the directory where the dataset is stored. Data can be downloaded with dataexample.WALNUT.download_data(data_dir) - - Returns - ------- - ImageData - The usb dataset - ''' - filepath = os.path.join(data_dir, cls.FOLDER, 'gruppe 4','gruppe 4_2014-03-20_1404_12','tomo-A','gruppe 4_tomo-A.txrm') - try: - loader = ZEISSDataReader(file_name=filepath) - return loader.read() - except(FileNotFoundError): - raise(FileNotFoundError("Dataset .txrm file not found in: {} \n \ - Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) - -class KORN(REMOTEDATA): - ''' - A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 + BOAT: 'boat.tiff' + CAMERA: 'camera.png' + PEPPERS: 'peppers.tiff' + RESOLUTION_CHART: 'resolution_chart.tiff' + SIMPLE_PHANTOM_2D: 'hotdog' + SHAPES: 'shapes.png' + RAINBOW: 'rainbow.png' ''' - FOLDER = 'korn' - URL = 'https://zenodo.org/record/6874123/files/korn.zip' - FILE_SIZE = '2.9 GB' - - @classmethod - def get(cls, data_dir): - ''' - A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 - This function returns the raw projection data from the .xtekct file - - Parameters - ---------- - data_dir: str - The path to the directory where the dataset is stored. Data can be downloaded with dataexample.KORN.download_data(data_dir) - - Returns - ------- - ImageData - The korn dataset - ''' - filepath = os.path.join(data_dir, cls.FOLDER, 'Korn i kasse','47209 testscan korn01_recon.xtekct') - try: - loader = NikonDataReader(file_name=filepath) - return loader.read() - except(FileNotFoundError): - raise(FileNotFoundError("Dataset .xtekct file not found in: {} \n \ - Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) - - -class SANDSTONE(REMOTEDATA): - ''' - A synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435 - A small subset of the data containing selected projections and 4 slices of the reconstruction - ''' - FOLDER = 'sandstone' - URL = 'https://zenodo.org/records/4912435/files/small.zip' - FILE_SIZE = '227 MB' - -class TestData(object): - '''Class to return test data - - provides 6 dataset: BOAT = 'boat.tiff' CAMERA = 'camera.png' PEPPERS = 'peppers.tiff' RESOLUTION_CHART = 'resolution_chart.tiff' SIMPLE_PHANTOM_2D = 'hotdog' - SHAPES = 'shapes.png' + SHAPES = 'shapes.png' RAINBOW = 'rainbow.png' - ''' - BOAT = 'boat.tiff' - CAMERA = 'camera.png' - PEPPERS = 'peppers.tiff' - RESOLUTION_CHART = 'resolution_chart.tiff' - SIMPLE_PHANTOM_2D = 'hotdog' - SHAPES = 'shapes.png' - RAINBOW = 'rainbow.png' + + @classmethod + def _datasets(cls): + return {cls.BOAT, cls.CAMERA, cls.PEPPERS, cls.RESOLUTION_CHART, cls.SIMPLE_PHANTOM_2D, cls.SHAPES, cls.RAINBOW} def __init__(self, data_dir): self.data_dir = data_dir - def load(self, which, size=None, scale=(0,1), **kwargs): + def load(self, which, size=None, scale=None): ''' Return a test data of the requested image @@ -338,52 +75,28 @@ def load(self, which, size=None, scale=(0,1), **kwargs): ImageData The simulated spheres volume ''' - if which not in [TestData.BOAT, TestData.CAMERA, - TestData.PEPPERS, TestData.RESOLUTION_CHART, - TestData.SIMPLE_PHANTOM_2D, TestData.SHAPES, - TestData.RAINBOW]: - raise ValueError('Unknown TestData {}.'.format(which)) + if scale is None: + scale = 0, 1 + if which not in self._datasets(): + raise KeyError(f"Unknown TestData: {which}") if which == TestData.SIMPLE_PHANTOM_2D: - if size is None: - N = 512 - M = 512 - else: - N = size[0] - M = size[1] - + N, M = 512, 512 if size is None else size[0], size[1] sdata = numpy.zeros((N, M)) sdata[int(round(N/4)):int(round(3*N/4)), int(round(M/4)):int(round(3*M/4))] = 0.5 sdata[int(round(N/8)):int(round(7*N/8)), int(round(3*M/8)):int(round(5*M/8))] = 1 ig = ImageGeometry(voxel_num_x = M, voxel_num_y = N, dimension_labels=[ImageGeometry.HORIZONTAL_Y, ImageGeometry.HORIZONTAL_X]) data = ig.allocate() data.fill(sdata) - elif which == TestData.SHAPES: - with Image.open(os.path.join(self.data_dir, which)) as f: - - if size is None: - N = 200 - M = 300 - else: - N = size[0] - M = size[1] - + N, M = 200, 300 if size is None else size[0], size[1] ig = ImageGeometry(voxel_num_x = M, voxel_num_y = N, dimension_labels=[ImageGeometry.HORIZONTAL_Y, ImageGeometry.HORIZONTAL_X]) data = ig.allocate() tmp = numpy.array(f.convert('L').resize((M,N))) data.fill(tmp/numpy.max(tmp)) - else: with Image.open(os.path.join(self.data_dir, which)) as tmp: - - if size is None: - N = tmp.size[1] - M = tmp.size[0] - else: - N = size[0] - M = size[1] - + N, M = tmp.size[1], tmp.size[0] if size is None else size[0], size[1] bands = tmp.getbands() if len(bands) > 1: if len(bands) == 4: @@ -414,26 +127,22 @@ def load(self, which, size=None, scale=(0,1), **kwargs): # print ("data.geometry", data.geometry) return data - @staticmethod - def random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): + @classmethod + def random_noise(cls, image, **kwargs): '''Function to add noise to input image :param image: input dataset, DataContainer of numpy.ndarray - :param mode: type of noise - :param seed: seed for random number generator - :param clip: should clip the data. + :param **kwargs: Passed to `scikit_random_noise` See https://github.com/scikit-image/scikit-image/blob/master/skimage/util/noise.py - ''' if hasattr(image, 'as_array'): - arr = TestData.scikit_random_noise(image.as_array(), mode=mode, seed=seed, clip=clip, - **kwargs) + arr = cls.scikit_random_noise(image.as_array(), **kwargs) out = image.copy() out.fill(arr) return out elif issubclass(type(image), numpy.ndarray): - return TestData.scikit_random_noise(image, mode=mode, seed=seed, clip=clip, - **kwargs) + return cls.scikit_random_noise(image, **kwargs) + raise TypeError(type(image)) @staticmethod def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): @@ -538,7 +247,6 @@ def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - """ mode = mode.lower() @@ -548,7 +256,7 @@ def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): else: low_clip = 0. - image = numpy.asarray(image, dtype=(np.float64)) + image = numpy.asarray(image, dtype=np.float64) if seed is not None: np.random.seed(seed=seed) @@ -646,3 +354,183 @@ def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): out = np.clip(out, low_clip, 1.0) return out + +class _CIL_DATA(ABC): + dfile: str + @classmethod + def get(cls, data_dir=DEFAULT_DATA_DIR, **loader_kwargs): + loader = TestData(data_dir) + return loader.load(cls.dfile, **loader_kwargs) + +class _REMOTE_DATA(ABC): + FOLDER: str + URL: str + FILE_SIZE: str + + @staticmethod + def _prompt(msg): + while (res := input(f"{msg} [y/n]").lower()) not in "yn": + pass + return res == "y" + + @classmethod + def _download_and_extract_from_url(cls, data_dir): + with urlopen(cls.URL) as response: + with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: + zipfile.extractall(path=data_dir) + + @classmethod + def download_data(cls, data_dir): + ''' + Download a dataset from a remote repository + + Parameters + ---------- + data_dir: str, optional + The path to the data directory where the downloaded data should be stored + ''' + if os.path.isdir(os.path.join(data_dir, cls.FOLDER)): + print(f"Dataset already exists in {data_dir}") + else: + if cls._prompt(f"Are you sure you want to download {cls.FILE_SIZE} dataset from {cls.URL}?"): + print(f"Downloading dataset from {cls.URL}") + cls._download_and_extract_from_url(os.path.join(data_dir,cls.FOLDER)) + print('Download complete') + else: + print('Download cancelled') + +class BOAT(_CIL_DATA): + dfile = TestData.BOAT +class CAMERA(_CIL_DATA): + dfile = TestData.CAMERA +class PEPPERS(_CIL_DATA): + dfile = TestData.PEPPERS +class RESOLUTION_CHART(_CIL_DATA): + dfile = TestData.RESOLUTION_CHART +class SIMPLE_PHANTOM_2D(_CIL_DATA): + dfile = TestData.SIMPLE_PHANTOM_2D +class SHAPES(_CIL_DATA): + dfile = TestData.SHAPES +class RAINBOW(_CIL_DATA): + dfile = TestData.RAINBOW +class _NEXUS_CIL_DATA(_CIL_DATA): + @classmethod + def get(cls, data_dir=DEFAULT_DATA_DIR): + ''' + Parameters + ---------- + data_dir: str, optional + The path to the data directory + + Returns + ------- + AcquisitionData + ''' + loader = NEXUSDataReader() + loader.set_up(file_name=os.path.join(data_dir, cls.dfile)) + return loader.read() +class SYNCHROTRON_PARALLEL_BEAM_DATA(_NEXUS_CIL_DATA): + '''A DLS dataset''' + dfile = '24737_fd_normalised.nxs' +class SIMULATED_PARALLEL_BEAM_DATA(_NEXUS_CIL_DATA): + '''A simulated parallel-beam dataset generated from SIMULATED_SPHERE_VOLUME''' + dfile = 'sim_parallel_beam.nxs' +class SIMULATED_CONE_BEAM_DATA(_NEXUS_CIL_DATA): + '''A cone-beam dataset generated from SIMULATED_SPHERE_VOLUME''' + dfile = 'sim_cone_beam.nxs' +class SIMULATED_SPHERE_VOLUME(_NEXUS_CIL_DATA): + '''A simulated volume of spheres''' + dfile = 'sim_volume.nxs' + +class WALNUT(_REMOTE_DATA): + '''A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516''' + FOLDER = 'walnut' + URL = 'https://zenodo.org/record/4822516/files/walnut.zip' + FILE_SIZE = '6.4 GB' + + @classmethod + def get(cls, data_dir): + ''' + This function returns the raw projection data from the .txrm file + + Parameters + ---------- + data_dir: str + The path to the directory where the dataset is stored. Data can be downloaded with dataexample.WALNUT.download_data(data_dir) + + Returns + ------- + ImageData + The walnut dataset + ''' + filepath = os.path.join(data_dir, cls.FOLDER, 'valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') + try: + loader = ZEISSDataReader(file_name=filepath) + return loader.read() + except FileNotFoundError as exc: + raise ValueError(f"Specify a different data_dir or download data with `{cls.__name__}.download_data({data_dir})`") from exc + +class USB(_REMOTE_DATA): + '''A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516''' + FOLDER = 'USB' + URL = 'https://zenodo.org/record/4822516/files/usb.zip' + FILE_SIZE = '3.2 GB' + + @classmethod + def get(cls, data_dir): + ''' + This function returns the raw projection data from the .txrm file + + Parameters + ---------- + data_dir: str + The path to the directory where the dataset is stored. Data can be downloaded with dataexample.USB.download_data(data_dir) + + Returns + ------- + ImageData + The usb dataset + ''' + filepath = os.path.join(data_dir, cls.FOLDER, 'gruppe 4','gruppe 4_2014-03-20_1404_12','tomo-A','gruppe 4_tomo-A.txrm') + try: + loader = ZEISSDataReader(file_name=filepath) + return loader.read() + except FileNotFoundError as exc: + raise ValueError(f"Specify a different data_dir or download data with `{cls.__name__}.download_data({data_dir})`") from exc + +class KORN(_REMOTE_DATA): + '''A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123''' + FOLDER = 'korn' + URL = 'https://zenodo.org/record/6874123/files/korn.zip' + FILE_SIZE = '2.9 GB' + + @classmethod + def get(cls, data_dir): + ''' + This function returns the raw projection data from the .xtekct file + + Parameters + ---------- + data_dir: str + The path to the directory where the dataset is stored. Data can be downloaded with dataexample.KORN.download_data(data_dir) + + Returns + ------- + ImageData + The korn dataset + ''' + filepath = os.path.join(data_dir, cls.FOLDER, 'Korn i kasse','47209 testscan korn01_recon.xtekct') + try: + loader = NikonDataReader(file_name=filepath) + return loader.read() + except FileNotFoundError as exc: + raise ValueError(f"Specify a different data_dir or download data with `{cls.__name__}.download_data({data_dir})`") from exc + +class SANDSTONE(_REMOTE_DATA): + ''' + A synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435 + A small subset of the data containing selected projections and 4 slices of the reconstruction + ''' + FOLDER = 'sandstone' + URL = 'https://zenodo.org/records/4912435/files/small.zip' + FILE_SIZE = '227 MB' diff --git a/Wrappers/Python/test/test_dataexample.py b/Wrappers/Python/test/test_dataexample.py index 3eaa003133..47f7e72da0 100644 --- a/Wrappers/Python/test/test_dataexample.py +++ b/Wrappers/Python/test/test_dataexample.py @@ -186,7 +186,7 @@ def mock_urlopen(self, mock_urlopen): @patch('cil.utilities.dataexample.urlopen') def test_unzip_remote_data(self, mock_urlopen): self.mock_urlopen(mock_urlopen) - dataexample.REMOTEDATA._download_and_extract_from_url('.') + dataexample._REMOTE_DATA._download_and_extract_from_url('.') self.assertTrue(os.path.isfile(self.tmp_file)) @patch('cil.utilities.dataexample.input', return_value='n')