diff --git a/examples/case_g/models.yml b/examples/case_g/models.yml index 90e1247..47e8a3e 100644 --- a/examples/case_g/models.yml +++ b/examples/case_g/models.yml @@ -1,7 +1,7 @@ - pymock: path: pymock - func: python run.py + func: pymock func_kwargs: n_sims: 100 mag_min: 3.5 - build: pip + build: venv diff --git a/floatcsep/environments.py b/floatcsep/environments.py new file mode 100644 index 0000000..c532dd6 --- /dev/null +++ b/floatcsep/environments.py @@ -0,0 +1,478 @@ +import configparser +import hashlib +import logging +import os +import shutil +import subprocess +import sys +import venv +from abc import ABC, abstractmethod + +from packaging.specifiers import SpecifierSet + +log = logging.getLogger("floatLogger") + + +class EnvironmentManager(ABC): + """ + Abstract base class for managing different types of environments. + This class defines the interface for creating, checking existence, + running commands, and installing dependencies in various environment types. + """ + + @abstractmethod + def __init__(self, base_name: str, model_directory: str): + """ + Initializes the environment manager with a base name and model directory. + + Args: + base_name (str): The base name for the environment. + model_directory (str): The directory containing the model files. + """ + self.base_name = base_name + self.model_directory = model_directory + + @abstractmethod + def create_environment(self, force=False): + """ + Creates the environment. If 'force' is True, it will remove any existing + environment with the same name before creating a new one. + + Args: + force (bool): Whether to forcefully remove an existing environment. + """ + pass + + @abstractmethod + def env_exists(self): + """ + Checks if the environment already exists. + + Returns: + bool: True if the environment exists, False otherwise. + """ + pass + + @abstractmethod + def run_command(self, command): + """ + Executes a command within the context of the environment. + + Args: + command (str): The command to be executed. + """ + pass + + @abstractmethod + def install_dependencies(self): + """ + Installs the necessary dependencies for the environment based on the + specified configuration or requirements. + """ + pass + + def generate_env_name(self) -> str: + """ + Generates a unique environment name by hashing the model directory + and appending it to the base name. + + Returns: + str: A unique name for the environment. + """ + dir_hash = hashlib.md5(self.model_directory.encode()).hexdigest()[:8] + return f"{self.base_name}_{dir_hash}" + + +class CondaEnvironmentManager(EnvironmentManager): + """ + Manages a conda (or mamba) environment, providing methods to create, check, + and manipulate conda environments specifically. + """ + + def __init__(self, base_name: str, model_directory: str): + """ + Initializes the Conda environment manager with the specified base name + and model directory. It also generates the environment name and detects + the package manager (conda or mamba) to install dependencies.. + + Args: + base_name (str): The base name, i.e., model name, for the conda environment. + model_directory (str): The directory containing the model files. + """ + self.base_name = base_name + self.model_directory = model_directory + self.env_name = self.generate_env_name() + self.package_manager = self.detect_package_manager() + + @staticmethod + def detect_package_manager(): + """ + Detects whether 'mamba' or 'conda' is available as the package manager. + + Returns: + str: The name of the detected package manager ('mamba' or 'conda'). + """ + if shutil.which("mamba"): + log.info("Mamba detected, using mamba as package manager.") + return "mamba" + log.info("Mamba not detected, using conda as package manager.") + return "conda" + + def create_environment(self, force=False): + """ + Creates a conda environment using either an environment.yml file or + the specified Python version in setup.py/setup.cfg or project/toml. + If 'force' is True, any existing environment with the same name will + be removed first. + + Args: + force (bool): Whether to forcefully remove an existing environment. + """ + if force and self.env_exists(): + log.info(f"Removing existing conda environment: {self.env_name}") + subprocess.run( + [ + self.package_manager, + "env", + "remove", + "--name", + self.env_name, + "--yes", + ] + ) + + if not self.env_exists(): + env_file = os.path.join(self.model_directory, "environment.yml") + if os.path.exists(env_file): + log.info(f"Creating sub-conda environment {self.env_name} from environment.yml") + subprocess.run( + [ + self.package_manager, + "env", + "create", + "--name", + self.env_name, + "--file", + env_file, + ] + ) + else: + python_version = self.detect_python_version() + log.info( + f"Creating sub-conda environment {self.env_name} with Python {python_version}" + ) + subprocess.run( + [ + self.package_manager, + "create", + "--name", + self.env_name, + "--yes", + f"python={python_version}", + ] + ) + log.info(f"\tSub-conda environment created: {self.env_name}") + + self.install_dependencies() + + def env_exists(self) -> bool: + """ + Checks if the conda environment exists by querying the list of + existing conda environments. + + Returns: + bool: True if the conda environment exists, False otherwise. + """ + result = subprocess.run(["conda", "env", "list"], stdout=subprocess.PIPE) + return self.env_name in result.stdout.decode() + + def detect_python_version(self) -> str: + """ + Determines the required Python version from setup files in the model directory. + It checks 'setup.py', 'pyproject.toml', and 'setup.cfg' (in that order), for + version specifications. + + Returns: + str: The detected or default Python version. + """ + setup_py = os.path.join(self.model_directory, "setup.py") + pyproject_toml = os.path.join(self.model_directory, "pyproject.toml") + setup_cfg = os.path.join(self.model_directory, "setup.cfg") + current_python_version = ( + f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ) + + def parse_version(version_str): + # Extract the first valid version number + import re + + match = re.search(r"\d+(.\d+)*", version_str) + return match.group(0) if match else current_python_version + + def is_version_compatible(requirement, current_version): + try: + specifier = SpecifierSet(requirement) + return current_version in specifier + except Exception as e: + log.error(f"Invalid specifier: {requirement}. Error: {e}") + return False + + if os.path.exists(setup_py): + with open(setup_py) as f: + for line in f: + if "python_requires" in line: + required_version = line.split("=")[1].strip() + if is_version_compatible(required_version, current_python_version): + log.info(f"Using current Python version: {current_python_version}") + return current_python_version + return parse_version(required_version) + + if os.path.exists(pyproject_toml): + with open(pyproject_toml) as f: + for line in f: + if "python" in line and "=" in line: + required_version = line.split("=")[1].strip() + if is_version_compatible(required_version, current_python_version): + log.info(f"Using current Python version: {current_python_version}") + return current_python_version + return parse_version(required_version) + + if os.path.exists(setup_cfg): + config = configparser.ConfigParser() + config.read(setup_cfg) + if "options" in config and "python_requires" in config["options"]: + required_version = config["options"]["python_requires"].strip() + if is_version_compatible(required_version, current_python_version): + log.info(f"Using current Python version: {current_python_version}") + return current_python_version + return parse_version(required_version) + + return current_python_version + + def install_dependencies(self): + """ + Installs dependencies in the conda environment using pip, based on the + model setup file + """ + log.info(f"Installing dependencies in conda environment: {self.env_name}") + cmd = [ + self.package_manager, + "run", + "-n", + self.env_name, + "pip", + "install", + "-e", + self.model_directory, + ] + subprocess.run(cmd, check=True) + + def run_command(self, command): + """ + Runs a specified command within the conda environment + Args: + command (str): The command to be executed in the conda environment. + """ + cmd = [ + "bash", + "-c", + f"{self.package_manager} run -n {self.env_name} {command}", + ] + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + ) + for line in process.stdout: + log.info(f"\t{line[:-1]}") + process.wait() + + +class VenvEnvironmentManager(EnvironmentManager): + """ + Manages a virtual environment created using Python's venv module. + Provides methods to create, check, and manipulate virtual environments. + """ + + def __init__(self, base_name: str, model_directory: str): + """ + Initializes the virtual environment manager with the specified base name + and model directory. + + Args: + base_name (str): The base name (i.e., model name) for the virtual environment. + model_directory (str): The directory containing the model files. + """ + + self.base_name = base_name + self.model_directory = model_directory + self.env_name = self.generate_env_name() + self.env_path = os.path.join(model_directory, self.env_name) + + def create_environment(self, force=False): + """ + Creates a virtual environment in the specified model directory. If 'force' + is True, any existing virtual environment will be removed before creation. + + Args: + force (bool): Whether to forcefully remove an existing virtual environment. + """ + if force and self.env_exists(): + log.info(f"Removing existing virtual environment: {self.env_name}") + shutil.rmtree(self.env_path) + + if not self.env_exists(): + log.info(f"Creating virtual environment: {self.env_name}") + venv.create(self.env_path, with_pip=True) + log.info(f"\tVirtual environment created: {self.env_name}") + self.install_dependencies() + + def env_exists(self) -> bool: + """ + Checks if the virtual environment exists by verifying the presence of its directory. + + Returns: + bool: True if the virtual environment exists, False otherwise. + """ + return os.path.isdir(self.env_path) + + def install_dependencies(self): + """ + Installs dependencies in the virtual environment using pip, based on the + model directory's configuration. + """ + log.info(f"Installing dependencies in virtual environment: {self.env_name}") + pip_executable = os.path.join(self.env_path, "bin", "pip") + cmd = f"{pip_executable} install -e {os.path.abspath(self.model_directory)}" + self.run_command(cmd) + + def run_command(self, command): + """ + Executes a specified command in the virtual environment and logs the output. + + Args: + command (str): The command to be executed in the virtual environment. + """ + activate_script = os.path.join(self.env_path, "bin", "activate") + + virtualenv = os.environ.copy() + virtualenv.pop("PYTHONPATH", None) + virtualenv["VIRTUAL_ENV"] = self.env_path + virtualenv["PATH"] = ( + os.path.join(self.env_path, "bin") + os.pathsep + virtualenv.get("PATH", "") + ) + + full_command = f"bash -c 'source {activate_script}' && {command}" + + process = subprocess.Popen( + full_command, + shell=True, + env=virtualenv, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + ) + for line in process.stdout: + stripped_line = line.strip() + log.info(stripped_line) + process.wait() + + +class DockerEnvironmentManager(EnvironmentManager): + """ + Manages a Docker environment, providing methods to create, check, + and manipulate Docker containers for the environment. + """ + + def __init__(self, base_name: str, model_directory: str): + self.base_name = base_name + self.model_directory = model_directory + + def create_environment(self, force=False): + pass + + def env_exists(self): + pass + + def run_command(self, command): + pass + + def install_dependencies(self): + pass + + +class EnvironmentFactory: + """ + Factory class for creating instances of environment managers based on the specified type. + """ + + @staticmethod + def get_env( + build: str = None, model_name: str = "model", model_path: str = None + ) -> EnvironmentManager: + """ + Returns an instance of an environment manager based on the specified build type. + It checks the current environment type and can return a conda, venv, or Docker + environment manager. + + Args: + build (str): The desired type of environment ('conda', 'venv', or 'docker'). + model_name (str): The name of the model for which the environment is being created. + model_path (str): The path to the model directory. + + Returns: + EnvironmentManager: An instance of the appropriate environment manager. + + Raises: + Exception: If an invalid environment type is specified. + """ + run_env = EnvironmentFactory.check_environment_type() + if run_env != build and build and build != "docker": + log.warning( + f"Selected build environment ({build}) for this model is different than that of" + f" the experiment run. Consider selecting the same environment." + ) + if build == "conda" or (not build and run_env == "conda"): + return CondaEnvironmentManager( + base_name=f"{model_name}", + model_directory=os.path.abspath(model_path), + ) + elif build == "venv" or (not build and run_env == "venv"): + return VenvEnvironmentManager( + base_name=f"{model_name}", + model_directory=os.path.abspath(model_path), + ) + elif build == "docker": + return DockerEnvironmentManager( + base_name=f"{model_name}", + model_directory=os.path.abspath(model_path), + ) + else: + raise Exception( + "Wrong environment selection. Please choose between " + '"conda", "venv" or "docker".' + ) + + @staticmethod + def check_environment_type(): + if "VIRTUAL_ENV" in os.environ: + return "venv" + try: + subprocess.run( + ["conda", "info"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + return "conda" + except FileNotFoundError: + pass + return None + + +if __name__ == "__main__": + + env = EnvironmentFactory.get_env( + "conda", model_path="../examples/case_h/models/pymock_poisson" + ) + env.create_environment(force=True) diff --git a/floatcsep/model.py b/floatcsep/model.py index 3d81bb8..93bb310 100644 --- a/floatcsep/model.py +++ b/floatcsep/model.py @@ -1,7 +1,6 @@ import json import logging import os -import subprocess from abc import ABC, abstractmethod from datetime import datetime from typing import List, Callable, Union, Mapping, Sequence @@ -13,6 +12,7 @@ from csep.utils.time_utils import decimal_year from floatcsep.accessors import from_zenodo, from_git +from floatcsep.environments import EnvironmentFactory from floatcsep.readers import ForecastParsers, HDF5Serializer from floatcsep.registry import ModelTree from floatcsep.utils import timewindow2str, str2timewindow @@ -124,9 +124,7 @@ def get_source( elif giturl: log.info(f"Retrieving model {self.name} from git url: " f"{giturl}") try: - from_git( - giturl, self.dir if self.path.fmt else self.path("path"), **kwargs - ) + from_git(giturl, self.dir if self.path.fmt else self.path("path"), **kwargs) except (git.NoSuchPathError, git.CommandError) as msg: raise git.NoSuchPathError(f"git url was not found {msg}") else: @@ -177,9 +175,7 @@ def iter_attr(val): return _get_value(val) list_walk = [ - (i, j) - for i, j in sorted(self.__dict__.items()) - if not i.startswith("_") and j + (i, j) for i, j in sorted(self.__dict__.items()) if not i.startswith("_") and j ] dict_walk = {i: j for i, j in list_walk} @@ -223,9 +219,7 @@ class TimeIndependentModel(Model): store_db (bool): flag to indicate whether to store the model in a database. """ - def __init__( - self, name: str, model_path: str, forecast_unit=1, store_db=False, **kwargs - ): + def __init__(self, name: str, model_path: str, forecast_unit=1, store_db=False, **kwargs): super().__init__(name, model_path, **kwargs) self.forecast_unit = forecast_unit self.store_db = store_db @@ -289,9 +283,7 @@ def stage(self, timewindows: Union[str, List[datetime]] = None) -> None: def get_forecast( self, tstring: Union[str, list] = None, region=None - ) -> Union[ - GriddedForecast, CatalogForecast, List[GriddedForecast], List[CatalogForecast] - ]: + ) -> Union[GriddedForecast, CatalogForecast, List[GriddedForecast], List[CatalogForecast]]: """ Wrapper that just returns a forecast when requested. """ @@ -333,9 +325,7 @@ def create_forecast(self, tstring: str, **kwargs) -> None: start_date, end_date = str2timewindow(tstring) self.forecast_from_file(start_date, end_date, **kwargs) - def forecast_from_file( - self, start_date: datetime, end_date: datetime, **kwargs - ) -> None: + def forecast_from_file(self, start_date: datetime, end_date: datetime, **kwargs) -> None: """ Generates a forecast from a file, by parsing and scaling it to. @@ -403,42 +393,10 @@ def __init__( self.build = kwargs.get("build", "docker") self.run_prefix = "" - def build_model(self): - - if self.build == "pip" or self.build == "venv": - venv = os.path.join(self.path("path"), self.__dict__.get("venv", "venv")) - venvact = os.path.join(venv, "bin", "activate") - - if not os.path.exists(venv): - log.info(f"Building model {self.name} using pip") - subprocess.run(["python", "-m", "venv", venv]) - log.info(f"\tVirtual environment created in {venv}") - build_cmd = ( - f"source {venvact} && " - f"pip install --upgrade pip && " - f'pip install {self.path("path")}' - ) - - cmd = ["bash", "-c", build_cmd] - - log.info("\tInstalling dependencies") - - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=True, - ) - for line in process.stdout: - log.info(f"\t{line[:-1]}") - process.wait() - log.info("\tEnvironment ready") - log.warning( - "\tNested environments is not fully supported. " - "Consider using docker instead" - ) - - self.run_prefix = f'cd {self.path("path")} && source {venvact} && ' + if self.func: + self.environment = EnvironmentFactory.get_env( + self.build, self.name, self.path.abs(self.model_path) + ) def stage(self, timewindows=None) -> None: """ @@ -450,7 +408,10 @@ def stage(self, timewindows=None) -> None: - Run model quality assurance (unit tests, runnable from floatcsep) """ self.get_source(self.zenodo_id, self.giturl, branch=self.repo_hash) - self.build_model() + + if hasattr(self, "environment"): + self.environment.create_environment() + self.path.build_tree( timewindows=timewindows, model_class="td", @@ -461,9 +422,7 @@ def stage(self, timewindows=None) -> None: def get_forecast( self, tstring: Union[str, list] = None, region=None - ) -> Union[ - GriddedForecast, CatalogForecast, List[GriddedForecast], List[CatalogForecast] - ]: + ) -> Union[GriddedForecast, CatalogForecast, List[GriddedForecast], List[CatalogForecast]]: """Wrapper that just returns a forecast, hiding the access method under the hood""" if isinstance(tstring, str): @@ -511,9 +470,7 @@ def create_forecast(self, tstring: str, **kwargs) -> None: else: log.info(f"Forecast of {tstring} of model {self.name} already " f"exists") - def forecast_from_func( - self, start_date: datetime, end_date: datetime, **kwargs - ) -> None: + def forecast_from_func(self, start_date: datetime, end_date: datetime, **kwargs) -> None: self.prepare_args(start_date, end_date, **kwargs) log.info( @@ -562,18 +519,7 @@ def replace_arg(arg, val, fp): def run_model(self): - if self.build == "pip" or self.build == "venv": - run_func = f'{self.func} {self.path("args_file")}' - cmd = ["bash", "-c", f"{self.run_prefix} {run_func}"] - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=True, - ) - for line in process.stdout: - log.info(f"\t{line[:-1]}") - process.wait() + self.environment.run_command(f'{self.func} {self.path("args_file")}') class ModelFactory: diff --git a/floatcsep/registry.py b/floatcsep/registry.py index 8ece9b2..b36c119 100644 --- a/floatcsep/registry.py +++ b/floatcsep/registry.py @@ -118,8 +118,7 @@ def build_tree( # set forecast names fc_files = { - win: join(dirtree["forecasts"], f"{prefix}_{win}.csv") - for win in windows + win: join(dirtree["forecasts"], f"{prefix}_{win}.csv") for win in windows } fc_exists = { diff --git a/pyproject.toml b/pyproject.toml index dc80b62..5eb4733 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,4 +6,9 @@ build-backend = "setuptools.build_meta" addopts = "--cov=floatcsep" testpaths = [ "tests", -] \ No newline at end of file +] + +[tool.black] +line-length = 96 +skip-string-normalization = false +target-version = ["py39", "py310", "py311"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d1cbbb8..81d20f8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,8 @@ flake8 gitpython h5py matplotlib +packaging +pandas pycsep pyshp pyyaml diff --git a/requirements_dev.txt b/requirements_dev.txt index 9cc3cc8..258b89e 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,5 +1,6 @@ numpy cartopy +black dateparser docker flake8 @@ -10,9 +11,9 @@ matplotlib mercantile mypy obspy +packaging pandas pillow -pyblack pycsep pydocstringformatter pyproj diff --git a/setup.cfg b/setup.cfg index 427f38d..82af787 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,6 +27,8 @@ install_requires = gitpython h5py matplotlib + packaging + pandas pycsep pyshp pyyaml @@ -55,6 +57,7 @@ dev = mercantile mypy obspy + packaging pandas pillow pycsep @@ -63,7 +66,6 @@ dev = pyshp pytest pytest-cov - pytest-asyncio pyyaml requests scipy diff --git a/tests/artifacts/models/td_model/setup.cfg b/tests/artifacts/models/td_model/setup.cfg new file mode 100644 index 0000000..e69de29 diff --git a/tests/qa/test_data.py b/tests/qa/test_data.py index b043987..96dd1d2 100644 --- a/tests/qa/test_data.py +++ b/tests/qa/test_data.py @@ -1,5 +1,7 @@ from floatcsep.cmd import main +from floatcsep.experiment import Experiment import unittest +from unittest.mock import patch import os @@ -33,77 +35,58 @@ def get_eval_dist(self): pass +@patch.object(Experiment, "generate_report") +@patch.object(Experiment, "plot_forecasts") +@patch.object(Experiment, "plot_catalog") class RunExamples(DataTest): - def test_case_a(self): + def test_case_a(self, *args): cfg = self.get_runpath('a') self.run_evaluation(cfg) self.assertEqual(1, 1) - def test_case_b(self): + def test_case_b(self, *args): cfg = self.get_runpath('b') self.run_evaluation(cfg) self.assertEqual(1, 1) - def test_case_c(self): + def test_case_c(self, *args): cfg = self.get_runpath('c') self.run_evaluation(cfg) self.assertEqual(1, 1) - def test_case_d(self): + def test_case_d(self, *args): cfg = self.get_runpath('d') self.run_evaluation(cfg) self.assertEqual(1, 1) - def test_case_e(self): + def test_case_e(self, *args): cfg = self.get_runpath('e') self.run_evaluation(cfg) self.assertEqual(1, 1) - def test_case_f(self): + def test_case_f(self, *args): cfg = self.get_runpath('f') self.run_evaluation(cfg) self.assertEqual(1, 1) - def test_case_g(self): + def test_case_g(self, *args): cfg = self.get_runpath('g') self.run_evaluation(cfg) self.assertEqual(1, 1) +@patch.object(Experiment, "generate_report") +@patch.object(Experiment, "plot_forecasts") +@patch.object(Experiment, "plot_catalog") class ReproduceExamples(DataTest): - def test_case_a(self): - cfg = self.get_rerunpath('a') - self.repr_evaluation(cfg) - self.assertEqual(1, 1) - - def test_case_b(self): - cfg = self.get_rerunpath('b') - self.repr_evaluation(cfg) - self.assertEqual(1, 1) - - def test_case_c(self): + def test_case_c(self, *args): cfg = self.get_rerunpath('c') self.repr_evaluation(cfg) self.assertEqual(1, 1) - def test_case_d(self): - cfg = self.get_rerunpath('d') - self.repr_evaluation(cfg) - self.assertEqual(1, 1) - - def test_case_e(self): - cfg = self.get_rerunpath('e') - self.repr_evaluation(cfg) - self.assertEqual(1, 1) - - def test_case_f(self): + def test_case_f(self, *args): cfg = self.get_rerunpath('f') self.repr_evaluation(cfg) self.assertEqual(1, 1) - - def test_case_g(self): - cfg = self.get_rerunpath('g') - self.repr_evaluation(cfg) - self.assertEqual(1, 1) \ No newline at end of file diff --git a/tests/unit/test_accessors.py b/tests/unit/test_accessors.py index a2ab143..260009b 100644 --- a/tests/unit/test_accessors.py +++ b/tests/unit/test_accessors.py @@ -1,20 +1,20 @@ import os.path import vcr from datetime import datetime -from floatcsep.accessors import query_gcmt, _query_gcmt, from_zenodo, \ - from_git, _check_hash +from floatcsep.accessors import query_gcmt, _query_gcmt, from_zenodo, from_git, _check_hash import unittest from unittest import mock root_dir = os.path.dirname(os.path.abspath(__file__)) + def gcmt_dir(): - data_dir = os.path.join(root_dir, '../artifacts', 'gcmt') + data_dir = os.path.join(root_dir, "../artifacts", "gcmt") return data_dir def zenodo_dir(): - data_dir = os.path.join(root_dir, '../artifacts', 'zenodo') + data_dir = os.path.join(root_dir, "../artifacts", "zenodo") return data_dir @@ -23,32 +23,31 @@ class TestCatalogGetter(unittest.TestCase): @classmethod def setUpClass(cls) -> None: os.makedirs(gcmt_dir(), exist_ok=True) - cls._fname = os.path.join(gcmt_dir(), 'test_cat') + cls._fname = os.path.join(gcmt_dir(), "test_cat") def test_gcmt_search(self): - tape_file = os.path.join(gcmt_dir(), 'vcr_search.yaml') + tape_file = os.path.join(gcmt_dir(), "vcr_search.yaml") with vcr.use_cassette(tape_file): # Maule, Chile - eventlist = \ - _query_gcmt(start_time=datetime(2010, 2, 26), - end_time=datetime(2010, 3, 2), - min_magnitude=6) + eventlist = _query_gcmt( + start_time=datetime(2010, 2, 26), end_time=datetime(2010, 3, 2), min_magnitude=6 + ) event = eventlist[0] - assert event[0] == '2844986' + assert event[0] == "2844986" def test_gcmt_summary(self): - tape_file = os.path.join(gcmt_dir(), 'vcr_summary.yaml') + tape_file = os.path.join(gcmt_dir(), "vcr_summary.yaml") with vcr.use_cassette(tape_file): - eventlist = \ - _query_gcmt(start_time=datetime(2010, 2, 26), - end_time=datetime(2010, 3, 2), - min_magnitude=7) + eventlist = _query_gcmt( + start_time=datetime(2010, 2, 26), end_time=datetime(2010, 3, 2), min_magnitude=7 + ) event = eventlist[0] cmp = "('2844986', 1267252514000, -35.98, -73.15, 23.2, 8.8)" assert str(event) == cmp - assert event[0] == '2844986' - assert datetime.fromtimestamp( - event[1] / 1000.) == datetime.fromtimestamp(1267252514) + assert event[0] == "2844986" + assert datetime.fromtimestamp(event[1] / 1000.0) == datetime.fromtimestamp( + 1267252514 + ) assert event[2] == -35.98 assert event[3] == -73.15 assert event[4] == 23.2 @@ -57,54 +56,52 @@ def test_gcmt_summary(self): def test_catalog_query_plot(self): start_datetime = datetime(2020, 1, 1) end_datetime = datetime(2020, 3, 1) - catalog = query_gcmt(start_time=start_datetime, - end_time=end_datetime, - min_magnitude=5.95) - catalog.plot(set_global=True, plot_args={'filename': self._fname, - 'basemap': 'stock_img'}) - assert os.path.isfile(self._fname + '.png') - assert os.path.isfile(self._fname + '.pdf') + catalog = query_gcmt( + start_time=start_datetime, end_time=end_datetime, min_magnitude=5.95 + ) + catalog.plot( + set_global=True, plot_args={"filename": self._fname, "basemap": "stock_img"} + ) + assert os.path.isfile(self._fname + ".png") + assert os.path.isfile(self._fname + ".pdf") @classmethod def tearDownClass(cls) -> None: try: - os.remove(os.path.join(gcmt_dir(), cls._fname + '.pdf')) - os.remove(os.path.join(gcmt_dir(), cls._fname + '.png')) + os.remove(os.path.join(gcmt_dir(), cls._fname + ".pdf")) + os.remove(os.path.join(gcmt_dir(), cls._fname + ".png")) except OSError: pass + class TestZenodoGetter(unittest.TestCase): @classmethod def setUpClass(cls) -> None: os.makedirs(zenodo_dir(), exist_ok=True) - cls._txt = os.path.join(zenodo_dir(), 'dummy.txt') - cls._tar = os.path.join(zenodo_dir(), 'dummy.tar') + cls._txt = os.path.join(zenodo_dir(), "dummy.txt") + cls._tar = os.path.join(zenodo_dir(), "dummy.tar") def test_zenodo_query(self): from_zenodo(4739912, zenodo_dir()) assert os.path.isfile(self._txt) assert os.path.isfile(self._tar) - with open(self._txt, 'r') as dummy: - assert dummy.readline() == 'test' - _check_hash(self._tar, 'md5:17f80d606ff085751998ac4050cc614c') + with open(self._txt, "r") as dummy: + assert dummy.readline() == "test" + _check_hash(self._tar, "md5:17f80d606ff085751998ac4050cc614c") @classmethod def tearDownClass(cls) -> None: - os.remove(os.path.join(zenodo_dir(), 'dummy.txt')) - os.remove(os.path.join(zenodo_dir(), 'dummy.tar')) + os.remove(os.path.join(zenodo_dir(), "dummy.txt")) + os.remove(os.path.join(zenodo_dir(), "dummy.tar")) os.rmdir(zenodo_dir()) class TestGitter(unittest.TestCase): - @mock.patch('floatcsep.accessors.git.Repo') - @mock.patch('git.Git') + @mock.patch("floatcsep.accessors.git.Repo") + @mock.patch("git.Git") def runTest(self, mock_git, mock_repo): p = mock.PropertyMock(return_value=False) type(mock_repo.clone_from.return_value).bare = p - from_git( - '/tmp/testrepo', - 'git@github.com:github/testrepo.git', - 'master' - ) - mock_git.checkout.called_once_with('master') + from_git("/tmp/testrepo", "git@github.com:github/testrepo.git", "master") + mock_git.checkout.called_once_with("master") diff --git a/tests/unit/test_environments.py b/tests/unit/test_environments.py new file mode 100644 index 0000000..cef9511 --- /dev/null +++ b/tests/unit/test_environments.py @@ -0,0 +1,352 @@ +import venv +import unittest +import subprocess +import os +from unittest.mock import patch, MagicMock, call, mock_open +import shutil +import hashlib +import logging +from floatcsep.environments import ( + CondaEnvironmentManager, + EnvironmentFactory, + VenvEnvironmentManager, + DockerEnvironmentManager, +) + + +class TestCondaEnvironmentManager(unittest.TestCase): + + @classmethod + def setUpClass(cls): + if not shutil.which("conda"): + raise unittest.SkipTest("Conda is not available in the environment.") + + def setUp(self): + self.manager = CondaEnvironmentManager( + base_name="test_env", model_directory="/tmp/test_model" + ) + os.makedirs("/tmp/test_model", exist_ok=True) + with open("/tmp/test_model/environment.yml", "w") as f: + f.write("name: test_env\ndependencies:\n - python=3.8\n - numpy") + with open("/tmp/test_model/setup.py", "w") as f: + f.write("from setuptools import setup\nsetup(name='test_model', version='0.1')") + + def tearDown(self): + if self.manager.env_exists(): + subprocess.run( + ["conda", "env", "remove", "--name", self.manager.env_name, "--yes"], + check=True, + ) + if os.path.exists("/tmp/test_model"): + shutil.rmtree("/tmp/test_model") + + @patch("subprocess.run") + @patch("shutil.which", return_value="conda") + def test_generate_env_name(self, mock_which, mock_run): + manager = CondaEnvironmentManager("test_base", "/path/to/model") + expected_name = "test_base_" + hashlib.md5("/path/to/model".encode()).hexdigest()[:8] + print(expected_name) + self.assertEqual(manager.generate_env_name(), expected_name) + + @patch("subprocess.run") + def test_env_exists(self, mock_run): + hashed = hashlib.md5("/path/to/model".encode()).hexdigest()[:8] + mock_run.return_value.stdout = f"test_base_{hashed}\n".encode() + + manager = CondaEnvironmentManager("test_base", "/path/to/model") + self.assertTrue(manager.env_exists()) + + @patch("subprocess.run") + @patch("os.path.exists", return_value=True) + def test_create_environment(self, mock_exists, mock_run): + manager = CondaEnvironmentManager("test_base", "/path/to/model") + manager.create_environment(force=False) + package_manager = manager.detect_package_manager() + expected_calls = [ + call(["conda", "env", "list"], stdout=-1), + call().stdout.decode(), + call().stdout.decode().__contains__(manager.env_name), + call( + [ + package_manager, + "env", + "create", + "--name", + manager.env_name, + "--file", + "/path/to/model/environment.yml", + ] + ), + call( + [ + package_manager, + "run", + "-n", + manager.env_name, + "pip", + "install", + "-e", + "/path/to/model", + ], + check=True, + ), + ] + + self.assertEqual(mock_run.call_count, 3) + mock_run.assert_has_calls(expected_calls, any_order=False) + + @patch("subprocess.run") + def test_create_environment_force(self, mock_run): + manager = CondaEnvironmentManager("test_base", "/path/to/model") + manager.env_exists = MagicMock(return_value=True) + manager.create_environment(force=True) + self.assertEqual(mock_run.call_count, 2) # One for remove, one for create + + @patch("subprocess.run") + @patch.object(CondaEnvironmentManager, "detect_package_manager", return_value="conda") + def test_install_dependencies(self, mock_detect_package_manager, mock_run): + manager = CondaEnvironmentManager("test_base", "/path/to/model") + manager.install_dependencies() + mock_run.assert_called_once_with( + [ + "conda", + "run", + "-n", + manager.env_name, + "pip", + "install", + "-e", + "/path/to/model", + ], + check=True, + ) + + @patch("shutil.which", return_value="conda") + @patch("os.path.exists", side_effect=[False, False, True]) + @patch( + "builtins.open", + new_callable=mock_open, + read_data="[metadata]\nname = test\n\n[options]\ninstall_requires =\n numpy\npython_requires = >=3.9,<3.12\n", + ) + def test_detect_python_version_setup_cfg(self, mock_open, mock_exists, mock_which): + manager = CondaEnvironmentManager("test_base", "../artifacts/models/td_model") + python_version = manager.detect_python_version() + + # Extract major and minor version parts + major_minor_version = ".".join(python_version.split(".")[:2]) + + self.assertIn( + major_minor_version, ["3.9", "3.10", "3.11"] + ) # Check if it falls within the specified range + + def test_create_and_delete_environment(self): + # Create the environment + self.manager.create_environment(force=True) + + # Check if the environment was created + result = subprocess.run(["conda", "env", "list"], stdout=subprocess.PIPE, check=True) + self.assertIn(self.manager.env_name, result.stdout.decode()) + + # Check if numpy is installed + result = subprocess.run( + [ + "conda", + "run", + "-n", + self.manager.env_name, + "python", + "-c", + "import numpy", + ], + check=True, + ) + self.assertEqual(result.returncode, 0) + + # Delete the environment + self.manager.create_environment( + force=True + ) # This should remove and recreate the environment + + # Check if the environment was recreated + result = subprocess.run(["conda", "env", "list"], stdout=subprocess.PIPE, check=True) + self.assertIn(self.manager.env_name, result.stdout.decode()) + + +class TestEnvironmentFactory(unittest.TestCase): + + @patch("os.path.abspath", return_value="/absolute/path/to/model") + @patch.object(EnvironmentFactory, "check_environment_type", return_value="conda") + def test_get_env_conda(self, mock_check_env, mock_abspath): + env_manager = EnvironmentFactory.get_env( + build="conda", model_name="test_model", model_path="/path/to/model" + ) + self.assertIsInstance(env_manager, CondaEnvironmentManager) + self.assertEqual(env_manager.base_name, "test_model") + self.assertEqual(env_manager.model_directory, "/absolute/path/to/model") + + @patch("os.path.abspath", return_value="/absolute/path/to/model") + @patch.object(EnvironmentFactory, "check_environment_type", return_value="venv") + def test_get_env_venv(self, mock_check_env, mock_abspath): + env_manager = EnvironmentFactory.get_env( + build="venv", model_name="test_model", model_path="/path/to/model" + ) + self.assertIsInstance(env_manager, VenvEnvironmentManager) + self.assertEqual(env_manager.base_name, "test_model") + self.assertEqual(env_manager.model_directory, "/absolute/path/to/model") + + @patch("os.path.abspath", return_value="/absolute/path/to/model") + @patch.object(EnvironmentFactory, "check_environment_type", return_value=None) + def test_get_env_docker(self, mock_check_env, mock_abspath): + env_manager = EnvironmentFactory.get_env( + build="docker", model_name="test_model", model_path="/path/to/model" + ) + self.assertIsInstance(env_manager, DockerEnvironmentManager) + self.assertEqual(env_manager.base_name, "test_model") + self.assertEqual(env_manager.model_directory, "/absolute/path/to/model") + + @patch("os.path.abspath", return_value="/absolute/path/to/model") + @patch.object(EnvironmentFactory, "check_environment_type", return_value="conda") + def test_get_env_default_conda(self, mock_check_env, mock_abspath): + env_manager = EnvironmentFactory.get_env( + build=None, model_name="test_model", model_path="/path/to/model" + ) + self.assertIsInstance(env_manager, CondaEnvironmentManager) + self.assertEqual(env_manager.base_name, "test_model") + self.assertEqual(env_manager.model_directory, "/absolute/path/to/model") + + @patch("os.path.abspath", return_value="/absolute/path/to/model") + @patch.object(EnvironmentFactory, "check_environment_type", return_value="venv") + def test_get_env_default_venv(self, mock_check_env, mock_abspath): + env_manager = EnvironmentFactory.get_env( + build=None, model_name="test_model", model_path="/path/to/model" + ) + self.assertIsInstance(env_manager, VenvEnvironmentManager) + self.assertEqual(env_manager.base_name, "test_model") + self.assertEqual(env_manager.model_directory, "/absolute/path/to/model") + + @patch("os.path.abspath", return_value="/absolute/path/to/model") + @patch.object(EnvironmentFactory, "check_environment_type", return_value=None) + def test_get_env_invalid(self, mock_check_env, mock_abspath): + with self.assertRaises(Exception) as context: + EnvironmentFactory.get_env( + build="invalid", model_name="test_model", model_path="/path/to/model" + ) + self.assertTrue("Wrong environment selection" in str(context.exception)) + + @patch("os.path.abspath", return_value="/absolute/path/to/model") + @patch.object(EnvironmentFactory, "check_environment_type", return_value="venv") + @patch("logging.Logger.warning") + def test_get_env_warning(self, mock_log_warning, mock_check_env, mock_abspath): + EnvironmentFactory.get_env( + build="conda", model_name="test_model", model_path="/path/to/model" + ) + mock_log_warning.assert_called_once_with( + f"Selected build environment (conda) for this model is different than that of" + f" the experiment run. Consider selecting the same environment." + ) + + +class TestVenvEnvironmentManager(unittest.TestCase): + + @classmethod + def setUpClass(cls): + # Check if venv is available (Python standard library) + if not hasattr(venv, "create"): + raise unittest.SkipTest("Venv is not available in the environment.") + + def setUp(self): + self.model_directory = "/tmp/test_model" + self.manager = VenvEnvironmentManager( + base_name="test_env", model_directory=self.model_directory + ) + os.makedirs(self.model_directory, exist_ok=True) + with open(os.path.join(self.model_directory, "setup.py"), "w") as f: + f.write("from setuptools import setup\nsetup(name='test_model', version='0.1')") + logging.disable(logging.CRITICAL) + + def tearDown(self): + if self.manager.env_exists(): + shutil.rmtree(self.manager.env_path) + if os.path.exists(self.model_directory): + shutil.rmtree(self.model_directory) + + def test_create_and_delete_environment(self): + # Create the environment + self.manager.create_environment(force=True) + + # Check if the environment was created + self.assertTrue(self.manager.env_exists()) + + # Check if pip is available in the environment + pip_executable = os.path.join(self.manager.env_path, "bin", "pip") + result = subprocess.run( + [pip_executable, "list"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + self.assertEqual(result.returncode, 0) # pip should run without errors + + # Delete the environment + self.manager.create_environment( + force=True + ) # This should remove and recreate the environment + + # Check if the environment was recreated + self.assertTrue(self.manager.env_exists()) + + def test_init(self): + self.assertEqual(self.manager.base_name, "test_env") + self.assertEqual(self.manager.model_directory, self.model_directory) + self.assertTrue(self.manager.env_name.startswith("test_env_")) + + def test_env_exists(self): + self.assertFalse(self.manager.env_exists()) + self.manager.create_environment(force=True) + self.assertTrue(self.manager.env_exists()) + + def test_create_environment(self): + self.manager.create_environment(force=True) + self.assertTrue(self.manager.env_exists()) + + def test_create_environment_force(self): + self.manager.create_environment(force=True) + env_path_before = self.manager.env_path + self.manager.create_environment(force=True) + self.assertTrue(self.manager.env_exists()) + self.assertEqual(env_path_before, self.manager.env_path) # Ensure it's a new path + + def test_install_dependencies(self): + self.manager.create_environment(force=True) + pip_executable = os.path.join(self.manager.env_path, "bin", "pip") + result = subprocess.run( + [pip_executable, "install", "-e", self.model_directory], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + self.assertEqual(result.returncode, 0) # pip should run without errors + + @patch("subprocess.Popen") + def test_run_command(self, mock_popen): + # Arrange + mock_process = MagicMock() + mock_process.stdout = iter(["Output line 1\n", "Output line 2\n"]) + mock_process.wait.return_value = None + mock_popen.return_value = mock_process + + command = "echo test_command" + + # Act + self.manager.run_command(command) + + output_cmd = f"bash -c 'source {os.path.join(self.manager.env_path, 'bin', 'activate')}' && {command}" + # Assert + mock_popen.assert_called_once_with( + output_cmd, + shell=True, + env=unittest.mock.ANY, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_evaluation.py b/tests/unit/test_evaluation.py index 11aee79..4b26298 100644 --- a/tests/unit/test_evaluation.py +++ b/tests/unit/test_evaluation.py @@ -12,37 +12,36 @@ def setUpClass(cls) -> None: def mock_eval(): return - setattr(cls, 'mock_eval', mock_eval) + setattr(cls, "mock_eval", mock_eval) @staticmethod def init_noreg(name, func, **kwargs): - """ Instantiates a model without using the @register deco, - but mocks Model.Registry() attrs"""\ - + """Instantiates a model without using the @register deco, + but mocks Model.Registry() attrs""" # deprecated # evaluation = Evaluation.__new__(Evaluation) # Evaluation.__init__.__wrapped__(self=evaluation, # name=name, # func=func, # **kwargs) - evaluation = Evaluation(name=name, func=func, - **kwargs) + evaluation = Evaluation(name=name, func=func, **kwargs) return evaluation def test_init(self): - name = 'N_test' - eval_ = self.init_noreg(name=name, - func=self.mock_eval) + name = "N_test" + eval_ = self.init_noreg(name=name, func=self.mock_eval) self.assertIs(None, eval_.type) - dict_ = {'name': 'N_test', - 'func': self.mock_eval, - 'func_kwargs': {}, - 'ref_model': None, - 'plot_func': None, - 'plot_args': None, - 'plot_kwargs': None, - 'markdown': '', - '_type': None} + dict_ = { + "name": "N_test", + "func": self.mock_eval, + "func_kwargs": {}, + "ref_model": None, + "plot_func": None, + "plot_args": None, + "plot_kwargs": None, + "markdown": "", + "_type": None, + } self.assertEqual(dict_, eval_.__dict__) def test_discrete_args(self): @@ -56,32 +55,30 @@ def test_prepare_catalog(self): def read_cat(_): cat = Mock() - cat.name = 'csep' + cat.name = "csep" return cat - with patch('csep.core.catalogs.CSEPCatalog.load_json', read_cat): - region = 'CSEPRegion' - forecast = MagicMock(name='forecast', region=region) + with patch("csep.core.catalogs.CSEPCatalog.load_json", read_cat): + region = "CSEPRegion" + forecast = MagicMock(name="forecast", region=region) - catt = Evaluation.get_catalog('path_to_cat', forecast) - self.assertEqual('csep', catt.name) + catt = Evaluation.get_catalog("path_to_cat", forecast) + self.assertEqual("csep", catt.name) self.assertEqual(region, catt.region) - region2 = 'definitelyNotCSEPregion' - forecast2 = Mock(name='forecast', region=region2) - cats = Evaluation.get_catalog(['path1', 'path2'], - [forecast, forecast2]) + region2 = "definitelyNotCSEPregion" + forecast2 = Mock(name="forecast", region=region2) + cats = Evaluation.get_catalog(["path1", "path2"], [forecast, forecast2]) self.assertIsInstance(cats, list) - self.assertEqual(cats[0].name, 'csep') - self.assertEqual(cats[0].region, 'CSEPRegion') - self.assertEqual(cats[1].region, 'definitelyNotCSEPregion') + self.assertEqual(cats[0].name, "csep") + self.assertEqual(cats[0].region, "CSEPRegion") + self.assertEqual(cats[1].region, "definitelyNotCSEPregion") with self.assertRaises(AttributeError): - Evaluation.get_catalog('path1', [forecast, forecast2]) + Evaluation.get_catalog("path1", [forecast, forecast2]) with self.assertRaises(IndexError): - Evaluation.get_catalog(['path1', 'path2'], - forecast) + Evaluation.get_catalog(["path1", "path2"], forecast) assert True def test_write_result(self): diff --git a/tests/unit/test_experiment.py b/tests/unit/test_experiment.py index 2f417a5..2a1ebc0 100644 --- a/tests/unit/test_experiment.py +++ b/tests/unit/test_experiment.py @@ -9,19 +9,18 @@ from csep.core.catalogs import CSEPCatalog _dir = os.path.dirname(__file__) -_model_cfg = os.path.normpath(os.path.join(_dir, '../artifacts', 'models', - 'model_cfg.yml')) -_region = os.path.normpath(os.path.join(_dir, '../artifacts', 'regions', - 'mock_region')) -_time_config = {'start_date': datetime(2021, 1, 1), - 'end_date': datetime(2022, 1, 1)} -_region_config = {'region': _region, - 'mag_max': 10.0, - 'mag_min': 1.0, - 'mag_bin': 0.1, - 'depth_min': 0, - 'depth_max': 1} -_cat = os.path.normpath(os.path.join(_dir, '../artifacts', 'catalog.json')) +_model_cfg = os.path.normpath(os.path.join(_dir, "../artifacts", "models", "model_cfg.yml")) +_region = os.path.normpath(os.path.join(_dir, "../artifacts", "regions", "mock_region")) +_time_config = {"start_date": datetime(2021, 1, 1), "end_date": datetime(2022, 1, 1)} +_region_config = { + "region": _region, + "mag_max": 10.0, + "mag_min": 1.0, + "mag_bin": 0.1, + "depth_min": 0, + "depth_max": 1, +} +_cat = os.path.normpath(os.path.join(_dir, "../artifacts", "catalog.json")) class TestExperiment(TestCase): @@ -46,68 +45,72 @@ def assertEqualExperiment(self, exp_a, exp_b): def init_no_wrap(name, path, **kwargs): model = Experiment.__new__(Experiment) - Experiment.__init__.__wrapped__(self=model, name=name, - path=path, **kwargs) + Experiment.__init__.__wrapped__(self=model, name=name, path=path, **kwargs) return Experiment def test_init(self): - exp_a = Experiment(**_time_config, **_region_config, - catalog=_cat) - exp_b = Experiment(time_config=_time_config, - region_config=_region_config, - catalog=_cat) + exp_a = Experiment(**_time_config, **_region_config, catalog=_cat) + exp_b = Experiment(time_config=_time_config, region_config=_region_config, catalog=_cat) self.assertEqualExperiment(exp_a, exp_b) def test_to_dict(self): - time_config = {'start_date': datetime(2020, 1, 1), - 'end_date': datetime(2021, 1, 1), - 'horizon': '6 month', - 'growth': 'cumulative'} - - region_config = {'region': 'california_relm_region', - 'mag_max': 9.0, - 'mag_min': 3.0, - 'mag_bin': 0.1, - 'depth_min': -2, - 'depth_max': 70} - - exp_a = Experiment(name='test', **time_config, **region_config, - catalog=_cat) - dict_ = {'name': 'test', - 'path': os.getcwd(), - 'rundir': 'results', - 'time_config': - {'exp_class': 'ti', - 'start_date': datetime(2020, 1, 1), - 'end_date': datetime(2021, 1, 1), - 'horizon': '6-months', - 'growth': 'cumulative'}, - 'region_config': { - 'region': 'california_relm_region', - 'mag_max': 9.0, - 'mag_min': 3.0, - 'mag_bin': 0.1, - 'depth_min': -2, - 'depth_max': 70 - }, - 'catalog': os.path.relpath(_cat, os.getcwd()) - } + time_config = { + "start_date": datetime(2020, 1, 1), + "end_date": datetime(2021, 1, 1), + "horizon": "6 month", + "growth": "cumulative", + } + + region_config = { + "region": "california_relm_region", + "mag_max": 9.0, + "mag_min": 3.0, + "mag_bin": 0.1, + "depth_min": -2, + "depth_max": 70, + } + + exp_a = Experiment(name="test", **time_config, **region_config, catalog=_cat) + dict_ = { + "name": "test", + "path": os.getcwd(), + "rundir": "results", + "time_config": { + "exp_class": "ti", + "start_date": datetime(2020, 1, 1), + "end_date": datetime(2021, 1, 1), + "horizon": "6-months", + "growth": "cumulative", + }, + "region_config": { + "region": "california_relm_region", + "mag_max": 9.0, + "mag_min": 3.0, + "mag_bin": 0.1, + "depth_min": -2, + "depth_max": 70, + }, + "catalog": os.path.relpath(_cat, os.getcwd()), + } self.assertEqual(dict_, exp_a.as_dict()) def test_to_yml(self): - time_config = {'start_date': datetime(2021, 1, 1), - 'end_date': datetime(2022, 1, 1), - 'intervals': 12} - - region_config = {'region': 'california_relm_region', - 'mag_max': 9.0, - 'mag_min': 3.0, - 'mag_bin': 0.1, - 'depth_min': -2, - 'depth_max': 70} - - exp_a = Experiment(**time_config, **region_config, - catalog=_cat) + time_config = { + "start_date": datetime(2021, 1, 1), + "end_date": datetime(2022, 1, 1), + "intervals": 12, + } + + region_config = { + "region": "california_relm_region", + "mag_max": 9.0, + "mag_min": 3.0, + "mag_bin": 0.1, + "depth_min": -2, + "depth_max": 70, + } + + exp_a = Experiment(**time_config, **region_config, catalog=_cat) file_ = tempfile.mkstemp()[1] exp_a.to_yml(file_) exp_b = Experiment.from_yml(file_) @@ -120,49 +123,49 @@ def test_to_yml(self): self.assertEqualExperiment(exp_a, exp_c) def test_set_models(self): - exp = Experiment(**_time_config, **_region_config, - model_config=_model_cfg, - catalog=_cat) + exp = Experiment( + **_time_config, **_region_config, model_config=_model_cfg, catalog=_cat + ) names = [i.name for i in exp.models] - self.assertEqual(['mock', 'qtree@team10', 'qtree@team25'], names) + self.assertEqual(["mock", "qtree@team10", "qtree@team25"], names) m1_path = os.path.normpath( - os.path.join(_dir, '../artifacts', 'models', 'qtree', - 'TEAM=N10L11.csv')) + os.path.join(_dir, "../artifacts", "models", "qtree", "TEAM=N10L11.csv") + ) def test_stage_models(self): - exp = Experiment(**_time_config, **_region_config, - model_config=_model_cfg, - catalog=_cat) + exp = Experiment( + **_time_config, **_region_config, model_config=_model_cfg, catalog=_cat + ) exp.stage_models() - dbpath = os.path.relpath( - os.path.join(_dir, '../artifacts', 'models', 'model.hdf5')) + dbpath = os.path.relpath(os.path.join(_dir, "../artifacts", "models", "model.hdf5")) self.assertEqual(exp.models[0].path.database, dbpath) def test_set_tests(self): test_cfg = os.path.normpath( - os.path.join(_dir, '../artifacts', 'evaluations', - 'tests_cfg.yml')) - exp = Experiment(**_time_config, **_region_config, - test_config=test_cfg, - catalog=_cat) + os.path.join(_dir, "../artifacts", "evaluations", "tests_cfg.yml") + ) + exp = Experiment(**_time_config, **_region_config, test_config=test_cfg, catalog=_cat) funcs = [i.func for i in exp.tests] - funcs_expected = [poisson_evaluations.number_test, - poisson_evaluations.spatial_test, - poisson_evaluations.paired_t_test] + funcs_expected = [ + poisson_evaluations.number_test, + poisson_evaluations.spatial_test, + poisson_evaluations.paired_t_test, + ] for i, j in zip(funcs, funcs_expected): self.assertIs(i, j) def test_prepare_subcatalog(self): time_config = {**_time_config} - exp = Experiment(**time_config, **_region_config, - catalog=_cat) - tstring = '2020-08-01_2021-01-02' + exp = Experiment(**time_config, **_region_config, catalog=_cat) + tstring = "2020-08-01_2021-01-02" with tempfile.NamedTemporaryFile() as file_: + def filetree(*args): return file_.name + exp.path = filetree # with patch.object(exp, 'filetree', filetree): # print(file_.name) @@ -172,6 +175,6 @@ def filetree(*args): @classmethod def tearDownClass(cls) -> None: - path_ = os.path.join(_dir, '../artifacts', 'models', 'model.hdf5') + path_ = os.path.join(_dir, "../artifacts", "models", "model.hdf5") if os.path.isfile(path_): os.remove(path_) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 965804f..c57dadf 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -9,6 +9,7 @@ import csep.core.regions import numpy.testing +from floatcsep.environments import EnvironmentManager from floatcsep.model import TimeIndependentModel, TimeDependentModel from floatcsep.utils import str2timewindow @@ -229,14 +230,14 @@ def init_model(name, path, **kwargs): return model - def test_from_git(self): + @patch.object(EnvironmentManager, "create_environment") + def test_from_git(self, mock_create_environment): """clones model from git, checks with test artifacts""" name = "mock_git" _dir = "git_template" path_ = os.path.join(tempfile.tempdir, _dir) giturl = ( - "https://git.gfz-potsdam.de/csep-group/" - "rise_italy_experiment/models/template.git" + "https://git.gfz-potsdam.de/csep-group/" "rise_italy_experiment/models/template.git" ) model_a = self.init_model(name=name, path=path_, giturl=giturl) model_a.stage() @@ -335,42 +336,6 @@ def test_get_forecast_multiple(self, mock_load_forecast): self.assertEqual(result[0], mock_forecast1) self.assertEqual(result[1], mock_forecast2) - @patch("subprocess.run") # Mock subprocess.run - @patch("subprocess.Popen") # Mock subprocess.Popen - @patch("os.path.exists") # Mock os.path.exists - def test_build_model_creates_venv(self, mock_exists, mock_popen, mock_run): - # Arrange - model_path = "../artifacts/models/td_model" - model = self.init_model(name="TestModel", path=model_path, build="venv") - mock_exists.return_value = False # Simulate that the venv does not exist - - # Act - model.build_model() - - # Assert - mock_run.assert_called_once_with( - ["python", "-m", "venv", model.path("path") + "/venv"] - ) - mock_popen.assert_called_once() # Ensure Popen was called to install dependencies - self.assertIn(f"cd {os.path.abspath(model_path)} && source", model.run_prefix) - - @patch("subprocess.run") - @patch("subprocess.Popen") - @patch("os.path.exists") - def test_build_model_when_venv_exists(self, mock_exists, mock_popen, mock_run): - # Arrange - model_path = "../artifacts/models/td_model" - model = self.init_model(name="TestModel", path=model_path, build="venv") - mock_exists.return_value = True # Simulate that the venv already exists - - # Act - model.build_model() - - # Assert - mock_run.assert_not_called() - mock_popen.assert_not_called() - self.assertIn(f"cd {os.path.abspath(model_path)} && source", model.run_prefix) - def test_argprep(self): model_path = os.path.join(self._dir, "td_model") with open(os.path.join(model_path, "input", "args.txt"), "w") as args: @@ -396,12 +361,8 @@ def test_argprep(self): @patch("floatcsep.model.open", new_callable=mock_open, read_data='{"key": "value"}') @patch("json.dump") def test_argprep_json(self, mock_json_dump, mock_file): - model = self.init_model( - name="TestModel", path=os.path.join(self._dir, "td_model") - ) - model.path = MagicMock( - return_value="path/to/model/args.json" - ) # Mock path method + model = self.init_model(name="TestModel", path=os.path.join(self._dir, "td_model")) + model.path = MagicMock(return_value="path/to/model/args.json") start = MagicMock() end = MagicMock() start.isoformat.return_value = "2023-01-01" diff --git a/tests/unit/test_readers.py b/tests/unit/test_readers.py index c4c6a41..1ed0648 100644 --- a/tests/unit/test_readers.py +++ b/tests/unit/test_readers.py @@ -14,30 +14,27 @@ class TestForecastParsers(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls._path = os.path.dirname(__file__) - cls._dir = os.path.join(cls._path, '../artifacts', 'models') + cls._dir = os.path.join(cls._path, "../artifacts", "models") + + @classmethod + def tearDownClass(cls) -> None: + fname = os.path.join(cls._dir, 'model.hdf5') + if os.path.isfile(fname): + os.remove(fname) def test_parse_csv(self): - fname = os.path.join(self._dir, 'model.csv') + fname = os.path.join(self._dir, "model.csv") numpy.seterr(all="ignore") rates, region, mags = readers.ForecastParsers.csv(fname) - rts = numpy.array([[1., 0.1], - [1., 0.1], - [1., 0.1], - [1., 0.1]]) - orgs = numpy.array([[0., 0.], - [0.1, 0], - [0., 0.1], - [0.1, 0.1]]) - poly_2 = numpy.array([[0., 0.1], - [0., 0.2], - [0.1, 0.2], - [0.1, 0.1]]) + rts = numpy.array([[1.0, 0.1], [1.0, 0.1], [1.0, 0.1], [1.0, 0.1]]) + orgs = numpy.array([[0.0, 0.0], [0.1, 0], [0.0, 0.1], [0.1, 0.1]]) + poly_2 = numpy.array([[0.0, 0.1], [0.0, 0.2], [0.1, 0.2], [0.1, 0.1]]) numpy.testing.assert_allclose(rts, rates) numpy.testing.assert_allclose(orgs, region.origins()) numpy.testing.assert_almost_equal(0.1, region.dh) - numpy.testing.assert_allclose([5., 5.1], mags) + numpy.testing.assert_allclose([5.0, 5.1], mags) numpy.testing.assert_allclose(poly_2, region.polygons[2].points) def test_parse_dat(self): @@ -50,41 +47,41 @@ def test_parse_dat(self): numpy.testing.assert_allclose(forecast.data, rates) def test_parse_csv_qtree(self): - fname = os.path.join(self._dir, 'qtree', 'TEAM=N10L11.csv') + fname = os.path.join(self._dir, "qtree", "TEAM=N10L11.csv") numpy.seterr(all="ignore") rates, region, mags = readers.ForecastParsers.csv(fname) - poly = numpy.array([[-180., 66.51326], - [-180., 79.171335], - [-135., 79.171335], - [-135., 66.51326]]) + poly = numpy.array( + [[-180.0, 66.51326], [-180.0, 79.171335], [-135.0, 79.171335], [-135.0, 66.51326]] + ) numpy.testing.assert_allclose(115.96694121688556, rates.sum()) - numpy.testing.assert_allclose([-177.1875, 51.179343], - region.origins()[123]) + numpy.testing.assert_allclose([-177.1875, 51.179343], region.origins()[123]) self.assertEqual(8089, rates.shape[0]) numpy.testing.assert_allclose(poly, region.polygons[2].points) rates2, region2, mags2 = readers.ForecastParsers.quadtree(fname) numpy.testing.assert_allclose(rates, rates2) - numpy.testing.assert_allclose([i.points for i in region.polygons], - [i.points for i in region2.polygons]) + numpy.testing.assert_allclose( + [i.points for i in region.polygons], [i.points for i in region2.polygons] + ) numpy.testing.assert_allclose(mags, mags2) def test_parse_xml(self): - fname = os.path.join(self._path, '../../examples', 'case_e', - 'models', - 'gulia-wiemer.ALM.italy.10yr.2010-01-01.xml') + fname = os.path.join( + self._path, + "../../examples", + "case_e", + "models", + "gulia-wiemer.ALM.italy.10yr.2010-01-01.xml", + ) numpy.seterr(all="ignore") rates, region, mags = readers.ForecastParsers.xml(fname) orgs = numpy.array([12.6, 38.3]) - poly = numpy.array([[12.6, 38.3], - [12.6, 38.4], - [12.7, 38.4], - [12.7, 38.3]]) + poly = numpy.array([[12.6, 38.3], [12.6, 38.4], [12.7, 38.4], [12.7, 38.3]]) mags_ = numpy.arange(5, 9.05, 0.1) numpy.testing.assert_allclose(16.185424321406536, rates.sum()) @@ -96,40 +93,34 @@ def test_parse_xml(self): def test_serialize_hdf5(self): numpy.seterr(all="ignore") - fname = os.path.join(self._dir, 'model.csv') + fname = os.path.join(self._dir, "model.csv") rates, region, mags = readers.ForecastParsers.csv(fname) - fname_db = os.path.join(self._dir, 'model.hdf5') - readers.HDF5Serializer.grid2hdf5(rates, region, mags, - hdf5_filename=fname_db) + fname_db = os.path.join(self._dir, "model.hdf5") + readers.HDF5Serializer.grid2hdf5(rates, region, mags, hdf5_filename=fname_db) self.assertTrue(os.path.isfile(fname_db)) size = os.path.getsize(fname_db) - self.assertEqual(4640, size) + self.assertLessEqual(4500, size) + self.assertGreaterEqual(5000, size) def test_parse_hdf5(self): - fname = os.path.join(self._dir, 'model_h5.hdf5') + fname = os.path.join(self._dir, "model_h5.hdf5") rates, region, mags = readers.ForecastParsers.hdf5(fname) - orgs = numpy.array([[0., 0.], - [0.1, 0], - [0., 0.1], - [0.1, 0.1]]) - poly_3 = numpy.array([[0.1, 0.1], - [0.1, 0.2], - [0.2, 0.2], - [0.2, 0.1]]) + orgs = numpy.array([[0.0, 0.0], [0.1, 0], [0.0, 0.1], [0.1, 0.1]]) + poly_3 = numpy.array([[0.1, 0.1], [0.1, 0.2], [0.2, 0.2], [0.2, 0.1]]) numpy.testing.assert_allclose(4.4, rates.sum()) numpy.testing.assert_allclose(orgs, region.origins()) numpy.testing.assert_almost_equal(0.1, region.dh) - numpy.testing.assert_allclose([5., 5.1], mags) + numpy.testing.assert_allclose([5.0, 5.1], mags) numpy.testing.assert_allclose(poly_3, region.polygons[3].points) def test_checkformat_xml(self): def save(xml_list): - name_ = os.path.join(tempfile.tempdir, 'tmpxml.xml') - with open(name_, 'w') as file_: + name_ = os.path.join(tempfile.tempdir, "tmpxml.xml") + with open(name_, "w") as file_: for i in xml_list: - file_.write(i + '\n') + file_.write(i + "\n") return name_ forecast_xml = [ @@ -148,34 +139,34 @@ def save(xml_list): filename = save(forecast_xml) try: - readers.check_format(filename, fmt='xml') + readers.check_format(filename, fmt="xml") except (IndentationError, IndexError, KeyError): - self.fail('Format check failed') + self.fail("Format check failed") xml_fail = copy.deepcopy(forecast_xml) xml_fail[3] = "" xml_fail[-3] = "" filename = save(xml_fail) with pytest.raises(LookupError): - readers.check_format(filename, fmt='xml') + readers.check_format(filename, fmt="xml") xml_fail = copy.deepcopy(forecast_xml) xml_fail[4] = "" filename = save(xml_fail) with pytest.raises(KeyError): - readers.check_format(filename, fmt='xml') + readers.check_format(filename, fmt="xml") xml_fail = copy.deepcopy(forecast_xml) xml_fail[5] = "1.6773966e-008" filename = save(xml_fail) with pytest.raises(LookupError): - readers.check_format(filename, fmt='xml') + readers.check_format(filename, fmt="xml") xml_fail = copy.deepcopy(forecast_xml) xml_fail[5] = "1.6773966e-008" filename = save(xml_fail) with pytest.raises(KeyError): - readers.check_format(filename, fmt='xml') + readers.check_format(filename, fmt="xml") xml_fail = copy.deepcopy(forecast_xml) xml_fail[2] = "" diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index f70a696..2e08832 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -10,8 +10,13 @@ import floatcsep import floatcsep.accessors import floatcsep.extras -from floatcsep.utils import parse_timedelta_string, timewindows_ti, \ - read_time_cfg, read_region_cfg, parse_csep_func +from floatcsep.utils import ( + parse_timedelta_string, + timewindows_ti, + read_time_cfg, + read_region_cfg, + parse_csep_func, +) root_dir = os.path.dirname(os.path.abspath(__file__)) @@ -19,73 +24,82 @@ class CsepFunctionTest(unittest.TestCase): def test_parse_csep_func(self): - self.assertIsInstance(parse_csep_func('load_gridded_forecast'), - csep.load_gridded_forecast.__class__) - self.assertIsInstance(parse_csep_func('join_struct_arrays'), - csep.utils.join_struct_arrays.__class__) - self.assertIsInstance(parse_csep_func('plot_poisson_consistency_test'), - csep.utils.plots.plot_poisson_consistency_test.__class__) - self.assertIsInstance(parse_csep_func('italy_csep_region'), - csep.core.regions.italy_csep_region.__class__) - self.assertIsInstance(parse_csep_func('plot_forecast_lowres'), - floatcsep.utils.plot_forecast_lowres.__class__) - self.assertIsInstance(parse_csep_func('from_zenodo'), - floatcsep.accessors.from_zenodo.__class__) - self.assertIsInstance(parse_csep_func('from_zenodo'), - floatcsep.extras.vector_poisson_t_w_test.__class__) - self.assertRaises(AttributeError, parse_csep_func, 'panic_button') + self.assertIsInstance( + parse_csep_func("load_gridded_forecast"), csep.load_gridded_forecast.__class__ + ) + self.assertIsInstance( + parse_csep_func("join_struct_arrays"), csep.utils.join_struct_arrays.__class__ + ) + self.assertIsInstance( + parse_csep_func("plot_poisson_consistency_test"), + csep.utils.plots.plot_poisson_consistency_test.__class__, + ) + self.assertIsInstance( + parse_csep_func("italy_csep_region"), csep.core.regions.italy_csep_region.__class__ + ) + self.assertIsInstance( + parse_csep_func("plot_forecast_lowres"), + floatcsep.utils.plot_forecast_lowres.__class__, + ) + self.assertIsInstance( + parse_csep_func("from_zenodo"), floatcsep.accessors.from_zenodo.__class__ + ) + self.assertIsInstance( + parse_csep_func("from_zenodo"), floatcsep.extras.vector_poisson_t_w_test.__class__ + ) + self.assertRaises(AttributeError, parse_csep_func, "panic_button") class TimeUtilsTest(unittest.TestCase): def test_parse_time_window(self): - dt = '1Year' - self.assertEqual(parse_timedelta_string(dt), '1-years') - dt = '7-Days' - self.assertEqual(parse_timedelta_string(dt), '7-days') - dt = '1- mOnThS' - self.assertEqual(parse_timedelta_string(dt), '1-months') - dt = '20 days' - self.assertEqual(parse_timedelta_string(dt), '20-days') - dt = '1decade' + dt = "1Year" + self.assertEqual(parse_timedelta_string(dt), "1-years") + dt = "7-Days" + self.assertEqual(parse_timedelta_string(dt), "7-days") + dt = "1- mOnThS" + self.assertEqual(parse_timedelta_string(dt), "1-months") + dt = "20 days" + self.assertEqual(parse_timedelta_string(dt), "20-days") + dt = "1decade" self.assertRaises(ValueError, parse_timedelta_string, dt) def test_timewindows_ti(self): start = datetime(2014, 1, 1) end = datetime(2022, 1, 1) - self.assertEqual(timewindows_ti(start_date=start, - end_date=end), [(start, end)]) - - t1 = [(datetime(2014, 1, 1), datetime(2018, 1, 1)), - (datetime(2018, 1, 1), datetime(2022, 1, 1))] - self.assertEqual(timewindows_ti(start_date=start, - end_date=end, - intervals=2), t1) - self.assertEqual(timewindows_ti(start_date=start, - end_date=end, - horizon='4-years'), t1) - self.assertEqual(timewindows_ti(start_date=start, - intervals=2, - horizon='4-years'), t1) - - t2 = [(datetime(2014, 1, 1, 0, 0), - datetime(2015, 2, 22, 10, 17, 8, 571428)), - (datetime(2015, 2, 22, 10, 17, 8, 571428), - datetime(2016, 4, 14, 20, 34, 17, 142857)), - (datetime(2016, 4, 14, 20, 34, 17, 142857), - datetime(2017, 6, 6, 6, 51, 25, 714285)), - (datetime(2017, 6, 6, 6, 51, 25, 714285), - datetime(2018, 7, 28, 17, 8, 34, 285714)), - (datetime(2018, 7, 28, 17, 8, 34, 285714), - datetime(2019, 9, 19, 3, 25, 42, 857142)), - (datetime(2019, 9, 19, 3, 25, 42, 857142), - datetime(2020, 11, 9, 13, 42, 51, 428571)), - (datetime(2020, 11, 9, 13, 42, 51, 428571), - datetime(2022, 1, 1, 0, 0))] - self.assertEqual(timewindows_ti(start_date=start, - end_date=end, - intervals=7), t2) + self.assertEqual(timewindows_ti(start_date=start, end_date=end), [(start, end)]) + + t1 = [ + (datetime(2014, 1, 1), datetime(2018, 1, 1)), + (datetime(2018, 1, 1), datetime(2022, 1, 1)), + ] + self.assertEqual(timewindows_ti(start_date=start, end_date=end, intervals=2), t1) + self.assertEqual(timewindows_ti(start_date=start, end_date=end, horizon="4-years"), t1) + self.assertEqual(timewindows_ti(start_date=start, intervals=2, horizon="4-years"), t1) + + t2 = [ + (datetime(2014, 1, 1, 0, 0), datetime(2015, 2, 22, 10, 17, 8, 571428)), + ( + datetime(2015, 2, 22, 10, 17, 8, 571428), + datetime(2016, 4, 14, 20, 34, 17, 142857), + ), + ( + datetime(2016, 4, 14, 20, 34, 17, 142857), + datetime(2017, 6, 6, 6, 51, 25, 714285), + ), + (datetime(2017, 6, 6, 6, 51, 25, 714285), datetime(2018, 7, 28, 17, 8, 34, 285714)), + ( + datetime(2018, 7, 28, 17, 8, 34, 285714), + datetime(2019, 9, 19, 3, 25, 42, 857142), + ), + ( + datetime(2019, 9, 19, 3, 25, 42, 857142), + datetime(2020, 11, 9, 13, 42, 51, 428571), + ), + (datetime(2020, 11, 9, 13, 42, 51, 428571), datetime(2022, 1, 1, 0, 0)), + ] + self.assertEqual(timewindows_ti(start_date=start, end_date=end, intervals=7), t2) def test_timewindows_ti_td(self): pass @@ -94,20 +108,15 @@ def test_read_time_config(self): start = datetime(2014, 1, 1) end = datetime(2022, 1, 1) intervals = 2 - config = {'start_date': start, - 'end_date': end, - 'intervals': intervals} - self.assertEqual(read_time_cfg(None, **config), - read_time_cfg(config)) - - short_config = {'start_date': start, - 'end_date': end} - time_config = {'intervals': 2} - full_config = {'start_date': start, - 'end_date': end, - 'intervals': 2} - self.assertEqual(read_time_cfg(time_config, **short_config), - read_time_cfg(None, **full_config)) + config = {"start_date": start, "end_date": end, "intervals": intervals} + self.assertEqual(read_time_cfg(None, **config), read_time_cfg(config)) + + short_config = {"start_date": start, "end_date": end} + time_config = {"intervals": 2} + full_config = {"start_date": start, "end_date": end, "intervals": 2} + self.assertEqual( + read_time_cfg(time_config, **short_config), read_time_cfg(None, **full_config) + ) class RegionUtilsTest(unittest.TestCase): @@ -117,50 +126,51 @@ def test_magnitudes_depth(self): mag_min = 1 mag_max = 1.2 mag_bin = 0.1 - depth_max = 1. - depth_min = 0. + depth_max = 1.0 + depth_min = 0.0 - config = {'mag_min': mag_min, - 'mag_max': mag_max, - 'mag_bin': mag_bin, - 'depth_min': depth_min, - 'depth_max': depth_max} + config = { + "mag_min": mag_min, + "mag_max": mag_max, + "mag_bin": mag_bin, + "depth_min": depth_min, + "depth_max": depth_max, + } region_config = read_region_cfg(config) self.assertEqual(8, len(region_config)) - numpy.testing.assert_equal(magnitudes, - region_config['magnitudes']) - numpy.testing.assert_equal(numpy.array([depth_min, depth_max]), - region_config['depths']) + numpy.testing.assert_equal(magnitudes, region_config["magnitudes"]) + numpy.testing.assert_equal(numpy.array([depth_min, depth_max]), region_config["depths"]) def test_region(self): - region_origins = numpy.array([[0, 0], - [0.1, 0], - [0.1, 0.1], - [0, 0.1]]) - region_path = os.path.join(os.path.dirname(__file__), '../artifacts', - 'regions', 'mock_region') - config = {'region': region_path, - 'mag_min': 1, - 'mag_max': 1.2, - 'mag_bin': 0.1, - 'depth_min': 0, - 'depth_max': 1} + region_origins = numpy.array([[0, 0], [0.1, 0], [0.1, 0.1], [0, 0.1]]) + region_path = os.path.join( + os.path.dirname(__file__), "../artifacts", "regions", "mock_region" + ) + config = { + "region": region_path, + "mag_min": 1, + "mag_max": 1.2, + "mag_bin": 0.1, + "depth_min": 0, + "depth_max": 1, + } region_config = read_region_cfg(config) self.assertEqual(9, len(region_config)) - numpy.testing.assert_equal(region_origins, - region_config['region'].origins()) - - config = {'region': 'italy_csep_region', - 'mag_min': 1, - 'mag_max': 1.2, - 'mag_bin': 0.1, - 'depth_min': 0, - 'depth_max': 1} - region_path = os.path.join(os.path.dirname(__file__), '../artifacts', - 'regions', 'italy_midpoints') + numpy.testing.assert_equal(region_origins, region_config["region"].origins()) + + config = { + "region": "italy_csep_region", + "mag_min": 1, + "mag_max": 1.2, + "mag_bin": 0.1, + "depth_min": 0, + "depth_max": 1, + } + region_path = os.path.join( + os.path.dirname(__file__), "../artifacts", "regions", "italy_midpoints" + ) midpoints = numpy.genfromtxt(region_path) region_config = read_region_cfg(config) - numpy.testing.assert_almost_equal(midpoints, - region_config['region'].midpoints()) + numpy.testing.assert_almost_equal(midpoints, region_config["region"].midpoints())