diff --git a/coolest/template/classes/grid.py b/coolest/template/classes/grid.py index 67d56d3..f22e78f 100644 --- a/coolest/template/classes/grid.py +++ b/coolest/template/classes/grid.py @@ -1,4 +1,4 @@ -__author__ = 'aymgal, Giorgos Vernardos' +__author__ = 'aymgal', 'gvernard' from typing import Tuple import numpy as np diff --git a/coolest/template/classes/lensing_entity_list.py b/coolest/template/classes/lensing_entity_list.py index 840b097..307c0a2 100644 --- a/coolest/template/classes/lensing_entity_list.py +++ b/coolest/template/classes/lensing_entity_list.py @@ -4,7 +4,7 @@ from coolest.template.classes.base import APIBaseObject from coolest.template.classes.lensing_entity import LensingEntity -from coolest.template.classes.profile import AnalyticalProfile +from coolest.template.classes.parameter import Parameter from coolest.template.classes import util @@ -34,28 +34,61 @@ def __init__(self, *entities: Tuple[LensingEntity]): APIBaseObject.__init__(self) self._create_all_ids() - def get_parameter_ids(self, with_name=None): - """Returns the list of either all parameter IDs in the model. - or a subset of them for parameters with a specific name. + def get_parameters(self, with_name=None, with_fixed=True): + """Returns the list of either all parameters in the model, + or only a subset of them for parameters with a specific name. Parameters ---------- with_name : str, optional - Parameter for which we want to get all corresponding IDs, by default None + Parameter for which we want to get all corresponding IDs (default: None). + with_fixed : bool, optional + If True, includes also fixed parameters (default: True). + + Returns + ------- + list + List of parameter instances """ - def _selected(param_name): - return ((with_name is None) or - (with_name is not None and param_name == with_name)) - id_list = [] + def _selected(param_name, param): + # below we check that is is a Parameter instance because Grid-like parameters + # do not have (yet) the possibility to be fixed (no fixed attribute). + is_fixed = False if not isinstance(param, Parameter) else param.fixed + ignored_if_fixed = not with_fixed and is_fixed + if with_name is None: + return False if ignored_if_fixed else True + elif param_name == with_name: + return False if ignored_if_fixed else True + else: + return False + param_list = [] for entity in self: for model_type in ('light', 'mass'): model = getattr(entity, f'{model_type}_model', None) if model is not None: for profile in model: for param_name, param in profile.parameters.items(): - if _selected(param_name): - id_list.append(param.id) - return id_list + if _selected(param_name, param): + param_list.append(param) + return param_list + + def get_parameter_ids(self, with_name=None, with_fixed=True): + """Returns the list of either all parameter IDs in the model, + or only a subset of them for parameters with a specific name. + + Parameters + ---------- + with_name : str, optional + Parameter for which we want to get all corresponding IDs (default: None). + with_fixed : bool, optional + If True, includes also fixed parameters (default: True). + + Returns + ------- + list + List of IDs (strings) + """ + return [p.id for p in self.get_parameters(with_name=with_name, with_fixed=with_fixed)] def get_parameter_from_id(self, param_id): """Returns the Parameter instance that has the given parameter ID. @@ -93,7 +126,6 @@ def _create_all_ids(self): elif entity.type == 'MassField': profile_id = util.mass_field_profile_to_id(profile.type, j, i) profile.id = profile_id - # if isinstance(profile, AnalyticalProfile): for param_name, parameter in profile.parameters.items(): param_id = util.parameter_to_id(param_name, profile.id) parameter.id = param_id diff --git a/coolest/template/classes/parameter.py b/coolest/template/classes/parameter.py index 89c4a2e..d4df00c 100644 --- a/coolest/template/classes/parameter.py +++ b/coolest/template/classes/parameter.py @@ -258,6 +258,7 @@ def __init__(self, *args, **kwargs) -> None: class PixelatedRegularGridParameter(PixelatedRegularGrid): """Typically for pixelated profiles""" + # TODO: implement .fixed attribute following the analytical Parameter interface def __init__(self, documentation, **kwargs_grid) -> None: self.documentation = documentation @@ -266,6 +267,7 @@ def __init__(self, documentation, **kwargs_grid) -> None: class IrregularGridParameter(IrregularGrid): """Typically for pixelated profiles""" + # TODO: implement .fixed attribute following the analytical Parameter interface def __init__(self, documentation, **kwargs_grid) -> None: self.documentation = documentation diff --git a/docs/notebooks/02-generate_template.ipynb b/docs/notebooks/02-generate_template.ipynb index 8e0af92..826f608 100644 --- a/docs/notebooks/02-generate_template.ipynb +++ b/docs/notebooks/02-generate_template.ipynb @@ -210,11 +210,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[,\n", - " ,\n", - " ,\n", - " ,\n", - " ]\n" + "[,\n", + " ,\n", + " ,\n", + " ,\n", + " ]\n" ] } ], @@ -302,6 +302,35 @@ "print(coolest_2.lensing_entities.get_parameter_ids(with_name='theta_E'))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also retrieve the list of all Parameter instances" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "All parameters with name 'q' that are not fixed:\n", + "[, , , ]\n", + "\n" + ] + } + ], + "source": [ + "# or you can als get Parameters objects\n", + "print(\"All parameters with name 'q' that are not fixed:\")\n", + "print(coolest_2.lensing_entities.get_parameters(with_fixed=False, with_name='q'))\n", + "print(\"\")" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -314,7 +343,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -343,7 +372,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [