diff --git a/coolest/template/classes/grid.py b/coolest/template/classes/grid.py index bd6df41..4d571f0 100644 --- a/coolest/template/classes/grid.py +++ b/coolest/template/classes/grid.py @@ -2,6 +2,7 @@ from typing import Tuple import numpy as np +import numpy.testing as npt import warnings from coolest.template.classes.base import APIBaseObject @@ -94,7 +95,7 @@ def __init__(self, **kwargs_file) -> None: super().__init__(fits_path, **kwargs_file) self.set_grid(None, field_of_view_x, field_of_view_y, - num_pix_x, num_pix_y, + num_pix_x=num_pix_x, num_pix_y=num_pix_y, check_fits_file=kwargs_file.get('check_fits_file', True)) @property @@ -107,7 +108,8 @@ def pixel_size(self): return 0. pix_size_x = np.abs(self.field_of_view_x[0] - self.field_of_view_x[1]) / self.num_pix_x pix_size_y = np.abs(self.field_of_view_y[0] - self.field_of_view_y[1]) / self.num_pix_y - assert pix_size_x == pix_size_y, "Regular grid must have square pixels" + npt.assert_almost_equal(pix_size_x, pix_size_y, decimal=6, + err_msg="Regular grid must have square pixels") return pix_size_x def set_grid(self, fits_path, @@ -124,7 +126,8 @@ def set_grid(self, fits_path, self.field_of_view_x = field_of_view_x self.field_of_view_y = field_of_view_y if self.fits_file.exists() and check_fits_file: - self.num_pix_x, self.num_pix_y = self.read_fits() + shape = self.read_fits() + self.num_pix_x, self.num_pix_y = shape[-2], shape[-1] # if number of pixels is also given, check that it is consistent if num_pix_x != 0 and self.num_pix_x != num_pix_x: raise ValueError("Given number of pixels in x direction " @@ -145,7 +148,8 @@ def read_fits(self): """ array, header = self.fits_file.read() array_shape = array.shape - if array_shape != (header['NAXIS1'], header['NAXIS2']): + if (len(array_shape) == 2 and array_shape != (header['NAXIS1'], header['NAXIS2']) or + len(array_shape) == 3 and array_shape != (header['NAXIS1'], header['NAXIS2'], header['NAXIS3'])): warnings.warn("Image dimensions do not match the FITS header") return array_shape @@ -168,6 +172,38 @@ def get_pixels(self, directory=None): return array +class PixelatedRegularGridStack(PixelatedRegularGrid): + + def __init__(self, + fits_path: str = None, + field_of_view_x: Tuple[float] = (0, 0), + field_of_view_y: Tuple[float] = (0, 0), + num_pix_x: int = 0, + num_pix_y: int = 0, + num_stack: int = 0, + **kwargs_file) -> None: + Grid.__init__(self, fits_path, **kwargs_file) + self.set_grid(None, field_of_view_x, field_of_view_y, + num_pix_x=num_pix_x, num_pix_y=num_pix_y, num_stack=num_stack, + check_fits_file=kwargs_file.get('check_fits_file', True)) + + @property + def shape(self): + return (self.num_stack, self.num_pix_x, self.num_pix_y) + + def set_grid(self, *args, num_stack=0, **kwargs): + super().set_grid(*args, **kwargs) + if self.fits_file.exists() and kwargs.get('check_fits_file', True): + shape = self.read_fits() + self.num_stack = shape[0] + # if number of pixels is also given, check that it is consistent + if num_stack != 0 and self.num_stack != num_stack: + raise ValueError("Number of stacked pixelated grids " + "is inconsistent with the fits file") + else: + self.num_stack = num_stack + + class IrregularGrid(Grid): """Class that represents an irregular set of values and their coordinates. diff --git a/coolest/template/classes/parameter.py b/coolest/template/classes/parameter.py index e6c1709..b17c72f 100644 --- a/coolest/template/classes/parameter.py +++ b/coolest/template/classes/parameter.py @@ -4,7 +4,7 @@ from coolest.template.classes.base import APIBaseObject from coolest.template.classes.probabilities import Prior, PosteriorStatistics -from coolest.template.classes.grid import PixelatedRegularGrid, IrregularGrid +from coolest.template.classes.grid import PixelatedRegularGrid, PixelatedRegularGridStack, IrregularGrid import numpy as np @@ -214,7 +214,7 @@ def __init__(self, *args, **kwargs): class LinearParameter(Parameter): """Define a hyper-parameter of a lens model - Warning: this class may be removed in the future, as it has adds unncessary abstraction level. + Warning: this class may be removed in the future, as it adds an unnecessary abstraction level. """ def __init__(self, *args, **kwargs): @@ -243,7 +243,7 @@ def __init__(self, *args, **kwargs) -> None: class LinearParameterSet(ParameterSet): """Typically for analytical basis sets. - Warning: this class may be removed in the future, as it has adds unncessary abstraction level. + Warning: this class may be removed in the future, as it adds an unnecessary abstraction level. """ def __init__(self, *args, **kwargs) -> None: @@ -253,7 +253,7 @@ def __init__(self, *args, **kwargs) -> None: class NonLinearParameterSet(ParameterSet): """Typically for position of point sources. - Warning: this class may be removed in the future, as it has adds unncessary abstraction level.""" + Warning: this class may be removed in the future, as it adds an unnecessary abstraction level.""" def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -268,6 +268,15 @@ def __init__(self, documentation, **kwargs_grid) -> None: super().__init__(**kwargs_grid) +class PixelatedRegularGridStackParameter(PixelatedRegularGridStack): + """Typically for pixelated profiles that can be cast to a stacking of multiple grids""" + # TODO: implement .fixed attribute following the analytical Parameter interface + + def __init__(self, documentation, **kwargs_grid) -> None: + self.documentation = documentation + super().__init__(**kwargs_grid) + + class IrregularGridParameter(IrregularGrid): """Typically for pixelated profiles""" # TODO: implement .fixed attribute following the analytical Parameter interface diff --git a/coolest/template/classes/profiles/mass.py b/coolest/template/classes/profiles/mass.py index d64acd9..9fc8ef0 100644 --- a/coolest/template/classes/profiles/mass.py +++ b/coolest/template/classes/profiles/mass.py @@ -3,7 +3,8 @@ from coolest.template.classes.profile import Profile, AnalyticalProfile from coolest.template.classes.parameter import NonLinearParameter from coolest.template.classes.parameter import (DefinitionRange, - PixelatedRegularGridParameter) + PixelatedRegularGridParameter, + PixelatedRegularGridStackParameter) from coolest.template.classes.grid import PixelatedRegularGrid @@ -17,6 +18,7 @@ 'ExternalShear', 'ConvergenceSheet', 'PixelatedRegularGridPotential', + 'PixelatedRegularGridFullyDefined', ] SUPPORTED_CHOICES = __all__ @@ -308,3 +310,27 @@ def __init__(self): 'pixels': PixelatedRegularGridParameter("Pixel values") } super().__init__(parameters) + + +class PixelatedRegularGridFullyDefined(Profile): + """Full mass model (potential, first and second spatial derivatives) + defined on a grid of regular pixels. + + This profile is described by the following parameters: + + - 'pixels': 2D array of pixel values + """ + + def __init__(self): + parameters = { + 'pixels': PixelatedRegularGridParameter( + "Pixel values for the lens potential" + ), + 'pixels_derivative': PixelatedRegularGridStackParameter( + "Pixel values for the first spatial derivative along x and y axes" + ), + 'pixels_hessian': PixelatedRegularGridStackParameter( + "Pixel values for the second spatial derivative along the 'xx', 'xy' and 'xy' axes" + ), + } + super().__init__(parameters) diff --git a/coolest/template/json.py b/coolest/template/json.py index 4772da7..baf8657 100644 --- a/coolest/template/json.py +++ b/coolest/template/json.py @@ -343,17 +343,17 @@ def _update_parameters_values(self, entity_in, entity_out, model_type): self._update_std_parameter(profile_out, name, values) def _update_grid_parameter(self, profile_out, name, values): - if name != 'pixels': - raise NotImplementedError("Support for grid parameters other than " - "'pixels' is not implemented.") - if 'Regular' in profile_out.type: + if profile_out.type in ('PixelatedRegularGrid', 'PixelatedRegularGridPotential'): pixels = self._setup_grid(values, PixelatedRegularGrid) - profile_out.parameters['pixels'] = pixels - elif 'Irregular' in profile_out.type: + profile_out.parameters[name] = pixels + elif profile_out.type == 'IrregularGrid': pixels = self._setup_grid(values, IrregularGrid) - profile_out.parameters['pixels'] = pixels + profile_out.parameters[name] = pixels + elif profile_out.type == 'PixelatedRegularGridFullyDefined': + pixels = self._setup_grid(values, PixelatedRegularGridStack) + profile_out.parameters[name] = pixels else: - raise ValueError(f"Unknown grid profile ({profile_out.type})") + raise ValueError(f"Unknown grid profile ('{profile_out.type}') and/or parameter name ('{name}').") def _update_std_parameter(self, profile_out, name, values): pt_estim = PointEstimate(**values['point_estimate']) diff --git a/coolest/template/lazy.py b/coolest/template/lazy.py index b41b42e..471e43f 100644 --- a/coolest/template/lazy.py +++ b/coolest/template/lazy.py @@ -1,6 +1,6 @@ # imports all classes that are need to be filled -from coolest.template.classes.grid import PixelatedRegularGrid,IrregularGrid +from coolest.template.classes.grid import PixelatedRegularGrid, PixelatedRegularGridStack, IrregularGrid from coolest.template.classes.fits_file import FitsFile from coolest.template.classes.observation import Observation from coolest.template.classes.psf import * @@ -16,4 +16,3 @@ from coolest.template.classes.lensing_entity_list import LensingEntityList from coolest.template.classes.mass_light_model import MassModel, LightModel from coolest.template.classes.probabilities import Prior, PosteriorStatistics -