Skip to content

Commit

Permalink
Add preliminary support for evaluating composable model functions ove…
Browse files Browse the repository at this point in the history
…r posterior samples
  • Loading branch information
aymgal committed Mar 11, 2024
1 parent f4e620b commit 26001e8
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 37 deletions.
148 changes: 111 additions & 37 deletions coolest/api/composable_models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
__author__ = 'aymgal'


import os
import numpy as np
import math
import logging
from scipy import signal
import pandas as pd
from functools import partial

from coolest.api import util


# logging settings
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().setLevel(logging.WARNING)


class BaseComposableModel(object):
Expand Down Expand Up @@ -39,8 +41,13 @@ class BaseComposableModel(object):
ValueError
No valid entity found or no profiles found.
"""

_chain_key = "chain_file_name"
_supported_eval_modes = ('point', 'posterior')

def __init__(self, model_type, coolest_object, coolest_directory=None,
def __init__(self, model_type,
coolest_object, coolest_directory=None,
load_posterior_samples=False,
entity_selection=None, profile_selection=None):
if entity_selection is None:
# finds the first entity that has a 'model_type' profile
Expand All @@ -58,24 +65,32 @@ def __init__(self, model_type, coolest_object, coolest_directory=None,
if entity_selection is None:
raise ValueError("No lensing entity with light profiles have been found")
else:
logging.info(f"Found valid profile for lensing entity (index {i}) for model type '{model_type}'")
logging.warning(f"Found valid profile for lensing entity (index {i}) for model type '{model_type}'")
if profile_selection is None:
profile_selection = 'all'
entities = coolest_object.lensing_entities
self.directory = coolest_directory
self.profile_list, self.param_list, self.info_list \
= self.select_profiles(model_type, entities,
entity_selection, profile_selection,
coolest_directory)
self._load_samples, self._csv_path = False, None
if load_posterior_samples:
metadata = coolest_object.meta
if self._chain_key not in metadata:
logging.warning(f"Metadata key '{self._chain_key}' is missing "
f"from COOLEST template, hence no posterior samples "
f"will be loaded.")
else:
self._load_samples = True
self._csv_path = os.path.join(self.directory, metadata[self._chain_key])
self.profile_list, self.param_list, self.post_param_list, self.info_list \
= self.get_profiles_and_params(model_type, entities,
entity_selection, profile_selection)
self.num_profiles = len(self.profile_list)
if self.num_profiles == 0:
raise ValueError("No profile has been selected!")

def select_profiles(self, model_type, entities,
entity_selection, profile_selection,
coolest_directory):
def get_profiles_and_params(self, model_type, entities,
entity_selection, profile_selection):
profile_list = []
param_list = []
param_list, post_param_list = [], []
info_list = []
for i, entity in enumerate(entities):
if self._selected(i, entity_selection):
Expand All @@ -84,17 +99,22 @@ def select_profiles(self, model_type, entities,
for j, profile in enumerate(getattr(entity, model_type)):
if self._selected(j, profile_selection):
if 'Grid' in profile.type:
if coolest_directory is None:
if self.directory is None:
raise ValueError("The directory in which the COOLEST file is located "
"must be provided for loading FITS files")
params, fixed = self._get_grid_params(profile, coolest_directory)
profile_list.append(self._get_api_profile(model_type, profile, *fixed))
param_list.append(params)
"must be provided for loading FITS files.")
params, fixed_params = self._get_grid_params(profile, self.directory)
profile_list.append(self._get_api_profile(model_type, profile, *fixed_params))
post_params = None # TODO: support samples for grid parameters
else:
params, post_params = self._get_regular_params(
profile, samples_file_path=self._csv_path
)
profile_list.append(self._get_api_profile(model_type, profile))
param_list.append(self._get_point_estimates(profile))
param_list.append(params)
post_param_list.append(post_params)
info_list.append((entity.name, entity.redshift))
return profile_list, param_list, info_list
post_param_list = self._reorganize_post_list(post_param_list)
return profile_list, param_list, post_param_list, info_list

def estimate_center(self):
# TODO: improve this (for now simply considers the first profile that has a center)
Expand All @@ -107,7 +127,7 @@ def estimate_center(self):
raise ValueError("Could not estimate a center from the composed model")

@staticmethod
def _get_api_profile(model_type, profile_in, *args):
def _get_api_profile(model_type, profile_in, *extra_profile_args):
"""
Takes as input a light profile from the template submodule
and instantites the corresponding profile from the API submodule
Expand All @@ -118,34 +138,64 @@ def _get_api_profile(model_type, profile_in, *args):
elif model_type == 'mass_model':
from coolest.api.profiles import mass
ProfileClass = getattr(mass, profile_in.type)
return ProfileClass(*args)
return ProfileClass(*extra_profile_args)

@staticmethod
def _get_point_estimates(profile_in):
parameters = {}
def _get_regular_params(profile_in, samples_file_path=None):
parameters = {} # best-fit values
samples = {} if samples_file_path else None # posterior samples
for name, param in profile_in.parameters.items():
parameters[name] = param.point_estimate.value
return parameters
if samples is not None:
# read just the column corresponding to the parameter ID
column = pd.read_csv(
samples_file_path,
usecols=[param.id],
delimiter=',',
)
samples[name] = np.array(column[param.id])
return parameters, samples

@staticmethod
def _get_grid_params(profile_in, fits_dir):
param_in = profile_in.parameters['pixels']
if profile_in.type == 'PixelatedRegularGrid':
data = profile_in.parameters['pixels'].get_pixels(directory=fits_dir)
data = param_in.get_pixels(directory=fits_dir)
parameters = {'pixels': data}
fov_x = profile_in.parameters['pixels'].field_of_view_x
fov_y = profile_in.parameters['pixels'].field_of_view_y
npix_x = profile_in.parameters['pixels'].num_pix_x
npix_y = profile_in.parameters['pixels'].num_pix_y
fov_x = param_in.field_of_view_x
fov_y = param_in.field_of_view_y
npix_x = param_in.num_pix_x
npix_y = param_in.num_pix_y
fixed_parameters = (fov_x, fov_y, npix_x, npix_y)

elif profile_in.type == 'IrregularGrid':
x, y, z = profile_in.parameters['pixels'].get_xyz(directory=fits_dir)
x, y, z = param_in.get_xyz(directory=fits_dir)
parameters = {'x': x, 'y': y, 'z': z}
fov_x = profile_in.parameters['pixels'].field_of_view_x
fov_y = profile_in.parameters['pixels'].field_of_view_y
npix = profile_in.parameters['pixels'].num_pix
fov_x = param_in.field_of_view_x
fov_y = param_in.field_of_view_y
npix = param_in.num_pix
fixed_parameters = (fov_x, fov_y, npix)
return parameters, fixed_parameters

@staticmethod
def _reorganize_post_list(param_list_of_samples):
"""
Takes as input the samples grouped at the leaves of the nested container structure,
and returns a list of items each organized as self.param_list
"""
num_profiles = len(param_list_of_samples)
profile_0 = param_list_of_samples[0]
if profile_0 is None: # happens when no samples have been loaded
return None
num_samples = len(profile_0[list(profile_0.keys())[0]])
samples_of_param_list = [
[{} for _ in range(num_profiles)] for _ in range(num_samples)
]
for i in range(num_samples):
for k in range(num_profiles):
for key in param_list_of_samples[k].keys():
samples_of_param_list[i][k][key] = param_list_of_samples[k][key][i]
return samples_of_param_list

@staticmethod
def _selected(index, selection):
Expand All @@ -157,6 +207,17 @@ def _selected(index, selection):
return True
return False

def _check_eval_mode(self, mode):
if mode not in self._supported_eval_modes:
raise NotImplementedError(
f"Only evaluation modes "
f"{self._supported_eval_modes} are supported "
f"(received '{mode}')."
)
if mode == 'posterior' and not self._load_samples:
raise ValueError(f"Selected evaluation mode '{mode}' "
f"but samples have not been loaded.")


class ComposableLightModel(BaseComposableModel):
"""Given a COOLEST object, evaluates a selection of entity and their light profiles.
Expand Down Expand Up @@ -238,18 +299,31 @@ class ComposableMassModel(BaseComposableModel):
No valid entity found or no profiles found.
"""

def __init__(self, coolest_object, coolest_directory=None, **kwargs_selection):
def __init__(self, coolest_object, coolest_directory=None,
load_posterior_samples=False,
**kwargs_selection):
super().__init__('mass_model', coolest_object,
coolest_directory=coolest_directory,
load_posterior_samples=load_posterior_samples,
**kwargs_selection)

def evaluate_potential(self, x, y):
def evaluate_potential(self, x, y, mode='point_estimate'):
"""Evaluates the lensing potential field at given coordinates"""
self._check_eval_mode(mode)
if mode == 'point':
return self._eval_pot_point(x, y, self.param_list)
return self._eval_pot_posterior(x, y, self.post_param_list)

def _eval_pot_point(self, x, y, param_list):
psi = np.zeros_like(x)
for k, (profile, params) in enumerate(zip(self.profile_list, self.param_list)):
psi += profile.potential(x, y, **params)
for k, profile in enumerate(self.profile_list):
psi += profile.potential(x, y, **param_list[k])
return psi

def _eval_pot_posterior(self, x, y, post_param_list):
mapped = map(partial(self._eval_pot_point, x, y), post_param_list)
return np.array(list(mapped))

def evaluate_deflection(self, x, y):
"""Evaluates the lensing deflection field at given coordinates"""
alpha_x, alpha_y = np.zeros_like(x), np.zeros_like(x)
Expand Down
2 changes: 2 additions & 0 deletions coolest/api/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
# logging settings
logging.getLogger().setLevel(logging.INFO)

# TODO: separate ParametersPlotter from ModelPlotter to avoid dependencies on getdist


class ModelPlotter(object):
"""Create pyplot panels from a lens model stored in the COOLEST format.
Expand Down

0 comments on commit 26001e8

Please sign in to comment.