Skip to content

Commit

Permalink
Load probability weights from csv file in ComposableModel
Browse files Browse the repository at this point in the history
  • Loading branch information
aymgal committed Mar 12, 2024
1 parent 72d56f1 commit 9e2667b
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions coolest/api/composable_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,14 @@ def __init__(self, model_type,
else:
self._posterior_bool = True
self._csv_path = os.path.join(self.directory, metadata[self._chain_key])
self.profile_list, self.param_list, self.post_param_list, self.info_list \
= self.get_profiles_and_params(model_type, entities,
entity_selection, profile_selection)
self.setup_profiles_and_params(model_type, entities,
entity_selection, profile_selection)
self.num_profiles = len(self.profile_list)
if self.num_profiles == 0:
raise ValueError("No profile has been selected!")

def get_profiles_and_params(self, model_type, entities,
entity_selection, profile_selection):
def setup_profiles_and_params(self, model_type, entities,
entity_selection, profile_selection):
profile_list = []
param_list, post_param_list = [], []
info_list = []
Expand All @@ -117,8 +116,13 @@ def get_profiles_and_params(self, model_type, entities,
param_list.append(params)
post_param_list.append(post_params)
info_list.append((entity.name, entity.redshift))
post_param_list = self._reorganize_post_list(post_param_list)
return profile_list, param_list, post_param_list, info_list
self.profile_list = profile_list
self.param_list = param_list
self.info_list = info_list
if self._posterior_bool is True:
post_param_list, post_weights = self._finalize_post_samples(post_param_list, self._csv_path)
self.post_param_list = post_param_list
self.post_weights = np.array(post_weights)

def estimate_center(self):
# TODO: improve this (for now simply considers the first profile that has a center)
Expand Down Expand Up @@ -158,7 +162,7 @@ def _get_regular_params(profile_in, samples_file_path=None):
delimiter=',',
)
# TODO: take into account probability weights from nested sampling runs!
samples[name] = np.array(column[param.id])
samples[name] = list(column[param.id])
return parameters, samples

@staticmethod
Expand All @@ -183,15 +187,13 @@ def _get_grid_params(profile_in, fits_dir):
return parameters, fixed_parameters

@staticmethod
def _reorganize_post_list(param_list_of_samples):
def _finalize_post_samples(param_list_of_samples, samples_file_path):
"""
Takes as input the samples grouped at the leaves of the nested container structure,
and returns a list of items each organized as self.param_list
"""
num_profiles = len(param_list_of_samples)
profile_0 = param_list_of_samples[0]
if profile_0 is None: # happens when no samples have been loaded
return None
num_samples = len(profile_0[list(profile_0.keys())[0]])
samples_of_param_list = [
[{} for _ in range(num_profiles)] for _ in range(num_samples)
Expand All @@ -200,7 +202,16 @@ def _reorganize_post_list(param_list_of_samples):
for k in range(num_profiles):
for key in param_list_of_samples[k].keys():
samples_of_param_list[i][k][key] = param_list_of_samples[k][key][i]
return samples_of_param_list
# also load and return the probability weights
# read just the column corresponding to the parameter ID
weights_key = 'probability_weights'
column = pd.read_csv(
samples_file_path,
usecols=[weights_key],
delimiter=',',
)
weights_list = list(column[weights_key])
return samples_of_param_list, weights_list

@staticmethod
def _selected(index, selection):
Expand Down Expand Up @@ -314,18 +325,18 @@ def evaluate_potential(self, x, y, mode='point', num_samples=None):
self._check_eval_mode(mode)
if mode == 'point' or self._posterior_bool is False:
return self._eval_pot_point(x, y, self.param_list)
return self._eval_pot_posterior(x, y, self.post_param_list, num_samples)
elif mode == 'posterior':
return self._eval_pot_posterior(x, y, self.post_param_list)

def _eval_pot_point(self, x, y, param_list):
psi = np.zeros_like(x)
for k, profile in enumerate(self.profile_list):
psi += profile.potential(x, y, **param_list[k])
return psi

def _eval_pot_posterior(self, x, y, post_param_list, num_max):
map_list = post_param_list if num_max is None else post_param_list[-num_max:]
def _eval_pot_posterior(self, x, y, param_list):
# map the point function at each sample
mapped = map(partial(self._eval_pot_point, x, y), map_list)
mapped = map(partial(self._eval_pot_point, x, y), param_list)
return np.array(list(mapped))

def evaluate_deflection(self, x, y):
Expand Down

0 comments on commit 9e2667b

Please sign in to comment.