Skip to content

Commit

Permalink
Add new template mass profile: fully defined pixelated potential and …
Browse files Browse the repository at this point in the history
…its derivatives
  • Loading branch information
aymgal committed Jun 18, 2024
1 parent cd40139 commit 08e5b26
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 19 deletions.
44 changes: 40 additions & 4 deletions coolest/template/classes/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Tuple
import numpy as np
import numpy.testing as npt
import warnings

from coolest.template.classes.base import APIBaseObject
Expand Down Expand Up @@ -94,7 +95,7 @@ def __init__(self,
**kwargs_file) -> None:
super().__init__(fits_path, **kwargs_file)
self.set_grid(None, field_of_view_x, field_of_view_y,
num_pix_x, num_pix_y,
num_pix_x=num_pix_x, num_pix_y=num_pix_y,
check_fits_file=kwargs_file.get('check_fits_file', True))

@property
Expand All @@ -107,7 +108,8 @@ def pixel_size(self):
return 0.
pix_size_x = np.abs(self.field_of_view_x[0] - self.field_of_view_x[1]) / self.num_pix_x
pix_size_y = np.abs(self.field_of_view_y[0] - self.field_of_view_y[1]) / self.num_pix_y
assert pix_size_x == pix_size_y, "Regular grid must have square pixels"
npt.assert_almost_equal(pix_size_x, pix_size_y, decimal=6,
err_msg="Regular grid must have square pixels")
return pix_size_x

def set_grid(self, fits_path,
Expand All @@ -124,7 +126,8 @@ def set_grid(self, fits_path,
self.field_of_view_x = field_of_view_x
self.field_of_view_y = field_of_view_y
if self.fits_file.exists() and check_fits_file:
self.num_pix_x, self.num_pix_y = self.read_fits()
shape = self.read_fits()
self.num_pix_x, self.num_pix_y = shape[-2], shape[-1]
# if number of pixels is also given, check that it is consistent
if num_pix_x != 0 and self.num_pix_x != num_pix_x:
raise ValueError("Given number of pixels in x direction "
Expand All @@ -145,7 +148,8 @@ def read_fits(self):
"""
array, header = self.fits_file.read()
array_shape = array.shape
if array_shape != (header['NAXIS1'], header['NAXIS2']):
if (len(array_shape) == 2 and array_shape != (header['NAXIS1'], header['NAXIS2']) or
len(array_shape) == 3 and array_shape != (header['NAXIS1'], header['NAXIS2'], header['NAXIS3'])):
warnings.warn("Image dimensions do not match the FITS header")
return array_shape

Expand All @@ -168,6 +172,38 @@ def get_pixels(self, directory=None):
return array


class PixelatedRegularGridStack(PixelatedRegularGrid):

def __init__(self,
fits_path: str = None,
field_of_view_x: Tuple[float] = (0, 0),
field_of_view_y: Tuple[float] = (0, 0),
num_pix_x: int = 0,
num_pix_y: int = 0,
num_stack: int = 0,
**kwargs_file) -> None:
Grid.__init__(self, fits_path, **kwargs_file)
self.set_grid(None, field_of_view_x, field_of_view_y,
num_pix_x=num_pix_x, num_pix_y=num_pix_y, num_stack=num_stack,
check_fits_file=kwargs_file.get('check_fits_file', True))

@property
def shape(self):
return (self.num_stack, self.num_pix_x, self.num_pix_y)

def set_grid(self, *args, num_stack=0, **kwargs):
super().set_grid(*args, **kwargs)
if self.fits_file.exists() and kwargs.get('check_fits_file', True):
shape = self.read_fits()
self.num_stack = shape[0]
# if number of pixels is also given, check that it is consistent
if num_stack != 0 and self.num_stack != num_stack:
raise ValueError("Number of stacked pixelated grids "
"is inconsistent with the fits file")
else:
self.num_stack = num_stack


class IrregularGrid(Grid):
"""Class that represents an irregular set of values and their coordinates.
Expand Down
17 changes: 13 additions & 4 deletions coolest/template/classes/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from coolest.template.classes.base import APIBaseObject
from coolest.template.classes.probabilities import Prior, PosteriorStatistics
from coolest.template.classes.grid import PixelatedRegularGrid, IrregularGrid
from coolest.template.classes.grid import PixelatedRegularGrid, PixelatedRegularGridStack, IrregularGrid

import numpy as np

Expand Down Expand Up @@ -214,7 +214,7 @@ def __init__(self, *args, **kwargs):
class LinearParameter(Parameter):
"""Define a hyper-parameter of a lens model
Warning: this class may be removed in the future, as it has adds unncessary abstraction level.
Warning: this class may be removed in the future, as it adds an unnecessary abstraction level.
"""

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -243,7 +243,7 @@ def __init__(self, *args, **kwargs) -> None:
class LinearParameterSet(ParameterSet):
"""Typically for analytical basis sets.
Warning: this class may be removed in the future, as it has adds unncessary abstraction level.
Warning: this class may be removed in the future, as it adds an unnecessary abstraction level.
"""

def __init__(self, *args, **kwargs) -> None:
Expand All @@ -253,7 +253,7 @@ def __init__(self, *args, **kwargs) -> None:
class NonLinearParameterSet(ParameterSet):
"""Typically for position of point sources.
Warning: this class may be removed in the future, as it has adds unncessary abstraction level."""
Warning: this class may be removed in the future, as it adds an unnecessary abstraction level."""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
Expand All @@ -268,6 +268,15 @@ def __init__(self, documentation, **kwargs_grid) -> None:
super().__init__(**kwargs_grid)


class PixelatedRegularGridStackParameter(PixelatedRegularGridStack):
"""Typically for pixelated profiles that can be cast to a stacking of multiple grids"""
# TODO: implement .fixed attribute following the analytical Parameter interface

def __init__(self, documentation, **kwargs_grid) -> None:
self.documentation = documentation
super().__init__(**kwargs_grid)


class IrregularGridParameter(IrregularGrid):
"""Typically for pixelated profiles"""
# TODO: implement .fixed attribute following the analytical Parameter interface
Expand Down
28 changes: 27 additions & 1 deletion coolest/template/classes/profiles/mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from coolest.template.classes.profile import Profile, AnalyticalProfile
from coolest.template.classes.parameter import NonLinearParameter
from coolest.template.classes.parameter import (DefinitionRange,
PixelatedRegularGridParameter)
PixelatedRegularGridParameter,
PixelatedRegularGridStackParameter)
from coolest.template.classes.grid import PixelatedRegularGrid


Expand All @@ -17,6 +18,7 @@
'ExternalShear',
'ConvergenceSheet',
'PixelatedRegularGridPotential',
'PixelatedRegularGridFullyDefined',
]
SUPPORTED_CHOICES = __all__

Expand Down Expand Up @@ -308,3 +310,27 @@ def __init__(self):
'pixels': PixelatedRegularGridParameter("Pixel values")
}
super().__init__(parameters)


class PixelatedRegularGridFullyDefined(Profile):
"""Full mass model (potential, first and second spatial derivatives)
defined on a grid of regular pixels.
This profile is described by the following parameters:
- 'pixels': 2D array of pixel values
"""

def __init__(self):
parameters = {
'pixels': PixelatedRegularGridParameter(
"Pixel values for the lens potential"
),
'pixels_derivative': PixelatedRegularGridStackParameter(
"Pixel values for the first spatial derivative along x and y axes"
),
'pixels_hessian': PixelatedRegularGridStackParameter(
"Pixel values for the second spatial derivative along the 'xx', 'xy' and 'xy' axes"
),
}
super().__init__(parameters)
16 changes: 8 additions & 8 deletions coolest/template/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,17 +343,17 @@ def _update_parameters_values(self, entity_in, entity_out, model_type):
self._update_std_parameter(profile_out, name, values)

def _update_grid_parameter(self, profile_out, name, values):
if name != 'pixels':
raise NotImplementedError("Support for grid parameters other than "
"'pixels' is not implemented.")
if 'Regular' in profile_out.type:
if profile_out.type in ('PixelatedRegularGrid', 'PixelatedRegularGridPotential'):
pixels = self._setup_grid(values, PixelatedRegularGrid)
profile_out.parameters['pixels'] = pixels
elif 'Irregular' in profile_out.type:
profile_out.parameters[name] = pixels
elif profile_out.type == 'IrregularGrid':
pixels = self._setup_grid(values, IrregularGrid)
profile_out.parameters['pixels'] = pixels
profile_out.parameters[name] = pixels
elif profile_out.type == 'PixelatedRegularGridFullyDefined':
pixels = self._setup_grid(values, PixelatedRegularGridStack)
profile_out.parameters[name] = pixels
else:
raise ValueError(f"Unknown grid profile ({profile_out.type})")
raise ValueError(f"Unknown grid profile ('{profile_out.type}') and/or parameter name ('{name}').")

def _update_std_parameter(self, profile_out, name, values):
pt_estim = PointEstimate(**values['point_estimate'])
Expand Down
3 changes: 1 addition & 2 deletions coolest/template/lazy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# imports all classes that are need to be filled

from coolest.template.classes.grid import PixelatedRegularGrid,IrregularGrid
from coolest.template.classes.grid import PixelatedRegularGrid, PixelatedRegularGridStack, IrregularGrid
from coolest.template.classes.fits_file import FitsFile
from coolest.template.classes.observation import Observation
from coolest.template.classes.psf import *
Expand All @@ -16,4 +16,3 @@
from coolest.template.classes.lensing_entity_list import LensingEntityList
from coolest.template.classes.mass_light_model import MassModel, LightModel
from coolest.template.classes.probabilities import Prior, PosteriorStatistics

0 comments on commit 08e5b26

Please sign in to comment.