Skip to content

Commit

Permalink
Add function to get list of parameter instances
Browse files Browse the repository at this point in the history
  • Loading branch information
aymgal committed Jul 20, 2023
1 parent 38632ea commit f6d9ff0
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 21 deletions.
2 changes: 1 addition & 1 deletion coolest/template/classes/grid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__author__ = 'aymgal, Giorgos Vernardos'
__author__ = 'aymgal', 'gvernard'

from typing import Tuple
import numpy as np
Expand Down
58 changes: 45 additions & 13 deletions coolest/template/classes/lensing_entity_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from coolest.template.classes.base import APIBaseObject
from coolest.template.classes.lensing_entity import LensingEntity
from coolest.template.classes.profile import AnalyticalProfile
from coolest.template.classes.parameter import Parameter
from coolest.template.classes import util


Expand Down Expand Up @@ -34,28 +34,61 @@ def __init__(self, *entities: Tuple[LensingEntity]):
APIBaseObject.__init__(self)
self._create_all_ids()

def get_parameter_ids(self, with_name=None):
"""Returns the list of either all parameter IDs in the model.
or a subset of them for parameters with a specific name.
def get_parameters(self, with_name=None, with_fixed=True):
"""Returns the list of either all parameters in the model,
or only a subset of them for parameters with a specific name.
Parameters
----------
with_name : str, optional
Parameter for which we want to get all corresponding IDs, by default None
Parameter for which we want to get all corresponding IDs (default: None).
with_fixed : bool, optional
If True, includes also fixed parameters (default: True).
Returns
-------
list
List of parameter instances
"""
def _selected(param_name):
return ((with_name is None) or
(with_name is not None and param_name == with_name))
id_list = []
def _selected(param_name, param):
# below we check that is is a Parameter instance because Grid-like parameters
# do not have (yet) the possibility to be fixed (no fixed attribute).
is_fixed = False if not isinstance(param, Parameter) else param.fixed
ignored_if_fixed = not with_fixed and is_fixed
if with_name is None:
return False if ignored_if_fixed else True
elif param_name == with_name:
return False if ignored_if_fixed else True
else:
return False
param_list = []
for entity in self:
for model_type in ('light', 'mass'):
model = getattr(entity, f'{model_type}_model', None)
if model is not None:
for profile in model:
for param_name, param in profile.parameters.items():
if _selected(param_name):
id_list.append(param.id)
return id_list
if _selected(param_name, param):
param_list.append(param)
return param_list

def get_parameter_ids(self, with_name=None, with_fixed=True):
"""Returns the list of either all parameter IDs in the model,
or only a subset of them for parameters with a specific name.
Parameters
----------
with_name : str, optional
Parameter for which we want to get all corresponding IDs (default: None).
with_fixed : bool, optional
If True, includes also fixed parameters (default: True).
Returns
-------
list
List of IDs (strings)
"""
return [p.id for p in self.get_parameters(with_name=with_name, with_fixed=with_fixed)]

def get_parameter_from_id(self, param_id):
"""Returns the Parameter instance that has the given parameter ID.
Expand Down Expand Up @@ -93,7 +126,6 @@ def _create_all_ids(self):
elif entity.type == 'MassField':
profile_id = util.mass_field_profile_to_id(profile.type, j, i)
profile.id = profile_id
# if isinstance(profile, AnalyticalProfile):
for param_name, parameter in profile.parameters.items():
param_id = util.parameter_to_id(param_name, profile.id)
parameter.id = param_id
2 changes: 2 additions & 0 deletions coolest/template/classes/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def __init__(self, *args, **kwargs) -> None:

class PixelatedRegularGridParameter(PixelatedRegularGrid):
"""Typically for pixelated profiles"""
# TODO: implement .fixed attribute following the analytical Parameter interface

def __init__(self, documentation, **kwargs_grid) -> None:
self.documentation = documentation
Expand All @@ -266,6 +267,7 @@ def __init__(self, documentation, **kwargs_grid) -> None:

class IrregularGridParameter(IrregularGrid):
"""Typically for pixelated profiles"""
# TODO: implement .fixed attribute following the analytical Parameter interface

def __init__(self, documentation, **kwargs_grid) -> None:
self.documentation = documentation
Expand Down
43 changes: 36 additions & 7 deletions docs/notebooks/02-generate_template.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,11 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[<coolest.template.classes.mass_field.MassField object at 0x114708430>,\n",
" <coolest.template.classes.galaxy.Galaxy object at 0x114708550>,\n",
" <coolest.template.classes.galaxy.Galaxy object at 0x1146f02b0>,\n",
" <coolest.template.classes.galaxy.Galaxy object at 0x1146f0f40>,\n",
" <coolest.template.classes.galaxy.Galaxy object at 0x114729850>]\n"
"[<coolest.template.classes.mass_field.MassField object at 0x1304da520>,\n",
" <coolest.template.classes.galaxy.Galaxy object at 0x1304daa60>,\n",
" <coolest.template.classes.galaxy.Galaxy object at 0x13048fd30>,\n",
" <coolest.template.classes.galaxy.Galaxy object at 0x1304f1220>,\n",
" <coolest.template.classes.galaxy.Galaxy object at 0x13051a8b0>]\n"
]
}
],
Expand Down Expand Up @@ -302,6 +302,35 @@
"print(coolest_2.lensing_entities.get_parameter_ids(with_name='theta_E'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also retrieve the list of all Parameter instances"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All parameters with name 'q' that are not fixed:\n",
"[<coolest.template.classes.parameter.NonLinearParameter object at 0x1304f4940>, <coolest.template.classes.parameter.NonLinearParameter object at 0x13051a280>, <coolest.template.classes.parameter.NonLinearParameter object at 0x13048fee0>, <coolest.template.classes.parameter.NonLinearParameter object at 0x13051b5e0>]\n",
"\n"
]
}
],
"source": [
"# or you can als get Parameters objects\n",
"print(\"All parameters with name 'q' that are not fixed:\")\n",
"print(coolest_2.lensing_entities.get_parameters(with_fixed=False, with_name='q'))\n",
"print(\"\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand All @@ -314,7 +343,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -343,7 +372,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
Expand Down

0 comments on commit f6d9ff0

Please sign in to comment.