diff --git a/examples/case_g/models.yml b/examples/case_g/models.yml
index 90e1247..47e8a3e 100644
--- a/examples/case_g/models.yml
+++ b/examples/case_g/models.yml
@@ -1,7 +1,7 @@
- pymock:
path: pymock
- func: python run.py
+ func: pymock
func_kwargs:
n_sims: 100
mag_min: 3.5
- build: pip
+ build: venv
diff --git a/floatcsep/environments.py b/floatcsep/environments.py
new file mode 100644
index 0000000..c532dd6
--- /dev/null
+++ b/floatcsep/environments.py
@@ -0,0 +1,478 @@
+import configparser
+import hashlib
+import logging
+import os
+import shutil
+import subprocess
+import sys
+import venv
+from abc import ABC, abstractmethod
+
+from packaging.specifiers import SpecifierSet
+
+log = logging.getLogger("floatLogger")
+
+
+class EnvironmentManager(ABC):
+ """
+ Abstract base class for managing different types of environments.
+ This class defines the interface for creating, checking existence,
+ running commands, and installing dependencies in various environment types.
+ """
+
+ @abstractmethod
+ def __init__(self, base_name: str, model_directory: str):
+ """
+ Initializes the environment manager with a base name and model directory.
+
+ Args:
+ base_name (str): The base name for the environment.
+ model_directory (str): The directory containing the model files.
+ """
+ self.base_name = base_name
+ self.model_directory = model_directory
+
+ @abstractmethod
+ def create_environment(self, force=False):
+ """
+ Creates the environment. If 'force' is True, it will remove any existing
+ environment with the same name before creating a new one.
+
+ Args:
+ force (bool): Whether to forcefully remove an existing environment.
+ """
+ pass
+
+ @abstractmethod
+ def env_exists(self):
+ """
+ Checks if the environment already exists.
+
+ Returns:
+ bool: True if the environment exists, False otherwise.
+ """
+ pass
+
+ @abstractmethod
+ def run_command(self, command):
+ """
+ Executes a command within the context of the environment.
+
+ Args:
+ command (str): The command to be executed.
+ """
+ pass
+
+ @abstractmethod
+ def install_dependencies(self):
+ """
+ Installs the necessary dependencies for the environment based on the
+ specified configuration or requirements.
+ """
+ pass
+
+ def generate_env_name(self) -> str:
+ """
+ Generates a unique environment name by hashing the model directory
+ and appending it to the base name.
+
+ Returns:
+ str: A unique name for the environment.
+ """
+ dir_hash = hashlib.md5(self.model_directory.encode()).hexdigest()[:8]
+ return f"{self.base_name}_{dir_hash}"
+
+
+class CondaEnvironmentManager(EnvironmentManager):
+ """
+ Manages a conda (or mamba) environment, providing methods to create, check,
+ and manipulate conda environments specifically.
+ """
+
+ def __init__(self, base_name: str, model_directory: str):
+ """
+ Initializes the Conda environment manager with the specified base name
+ and model directory. It also generates the environment name and detects
+ the package manager (conda or mamba) to install dependencies..
+
+ Args:
+ base_name (str): The base name, i.e., model name, for the conda environment.
+ model_directory (str): The directory containing the model files.
+ """
+ self.base_name = base_name
+ self.model_directory = model_directory
+ self.env_name = self.generate_env_name()
+ self.package_manager = self.detect_package_manager()
+
+ @staticmethod
+ def detect_package_manager():
+ """
+ Detects whether 'mamba' or 'conda' is available as the package manager.
+
+ Returns:
+ str: The name of the detected package manager ('mamba' or 'conda').
+ """
+ if shutil.which("mamba"):
+ log.info("Mamba detected, using mamba as package manager.")
+ return "mamba"
+ log.info("Mamba not detected, using conda as package manager.")
+ return "conda"
+
+ def create_environment(self, force=False):
+ """
+ Creates a conda environment using either an environment.yml file or
+ the specified Python version in setup.py/setup.cfg or project/toml.
+ If 'force' is True, any existing environment with the same name will
+ be removed first.
+
+ Args:
+ force (bool): Whether to forcefully remove an existing environment.
+ """
+ if force and self.env_exists():
+ log.info(f"Removing existing conda environment: {self.env_name}")
+ subprocess.run(
+ [
+ self.package_manager,
+ "env",
+ "remove",
+ "--name",
+ self.env_name,
+ "--yes",
+ ]
+ )
+
+ if not self.env_exists():
+ env_file = os.path.join(self.model_directory, "environment.yml")
+ if os.path.exists(env_file):
+ log.info(f"Creating sub-conda environment {self.env_name} from environment.yml")
+ subprocess.run(
+ [
+ self.package_manager,
+ "env",
+ "create",
+ "--name",
+ self.env_name,
+ "--file",
+ env_file,
+ ]
+ )
+ else:
+ python_version = self.detect_python_version()
+ log.info(
+ f"Creating sub-conda environment {self.env_name} with Python {python_version}"
+ )
+ subprocess.run(
+ [
+ self.package_manager,
+ "create",
+ "--name",
+ self.env_name,
+ "--yes",
+ f"python={python_version}",
+ ]
+ )
+ log.info(f"\tSub-conda environment created: {self.env_name}")
+
+ self.install_dependencies()
+
+ def env_exists(self) -> bool:
+ """
+ Checks if the conda environment exists by querying the list of
+ existing conda environments.
+
+ Returns:
+ bool: True if the conda environment exists, False otherwise.
+ """
+ result = subprocess.run(["conda", "env", "list"], stdout=subprocess.PIPE)
+ return self.env_name in result.stdout.decode()
+
+ def detect_python_version(self) -> str:
+ """
+ Determines the required Python version from setup files in the model directory.
+ It checks 'setup.py', 'pyproject.toml', and 'setup.cfg' (in that order), for
+ version specifications.
+
+ Returns:
+ str: The detected or default Python version.
+ """
+ setup_py = os.path.join(self.model_directory, "setup.py")
+ pyproject_toml = os.path.join(self.model_directory, "pyproject.toml")
+ setup_cfg = os.path.join(self.model_directory, "setup.cfg")
+ current_python_version = (
+ f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
+ )
+
+ def parse_version(version_str):
+ # Extract the first valid version number
+ import re
+
+ match = re.search(r"\d+(.\d+)*", version_str)
+ return match.group(0) if match else current_python_version
+
+ def is_version_compatible(requirement, current_version):
+ try:
+ specifier = SpecifierSet(requirement)
+ return current_version in specifier
+ except Exception as e:
+ log.error(f"Invalid specifier: {requirement}. Error: {e}")
+ return False
+
+ if os.path.exists(setup_py):
+ with open(setup_py) as f:
+ for line in f:
+ if "python_requires" in line:
+ required_version = line.split("=")[1].strip()
+ if is_version_compatible(required_version, current_python_version):
+ log.info(f"Using current Python version: {current_python_version}")
+ return current_python_version
+ return parse_version(required_version)
+
+ if os.path.exists(pyproject_toml):
+ with open(pyproject_toml) as f:
+ for line in f:
+ if "python" in line and "=" in line:
+ required_version = line.split("=")[1].strip()
+ if is_version_compatible(required_version, current_python_version):
+ log.info(f"Using current Python version: {current_python_version}")
+ return current_python_version
+ return parse_version(required_version)
+
+ if os.path.exists(setup_cfg):
+ config = configparser.ConfigParser()
+ config.read(setup_cfg)
+ if "options" in config and "python_requires" in config["options"]:
+ required_version = config["options"]["python_requires"].strip()
+ if is_version_compatible(required_version, current_python_version):
+ log.info(f"Using current Python version: {current_python_version}")
+ return current_python_version
+ return parse_version(required_version)
+
+ return current_python_version
+
+ def install_dependencies(self):
+ """
+ Installs dependencies in the conda environment using pip, based on the
+ model setup file
+ """
+ log.info(f"Installing dependencies in conda environment: {self.env_name}")
+ cmd = [
+ self.package_manager,
+ "run",
+ "-n",
+ self.env_name,
+ "pip",
+ "install",
+ "-e",
+ self.model_directory,
+ ]
+ subprocess.run(cmd, check=True)
+
+ def run_command(self, command):
+ """
+ Runs a specified command within the conda environment
+ Args:
+ command (str): The command to be executed in the conda environment.
+ """
+ cmd = [
+ "bash",
+ "-c",
+ f"{self.package_manager} run -n {self.env_name} {command}",
+ ]
+ process = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ universal_newlines=True,
+ )
+ for line in process.stdout:
+ log.info(f"\t{line[:-1]}")
+ process.wait()
+
+
+class VenvEnvironmentManager(EnvironmentManager):
+ """
+ Manages a virtual environment created using Python's venv module.
+ Provides methods to create, check, and manipulate virtual environments.
+ """
+
+ def __init__(self, base_name: str, model_directory: str):
+ """
+ Initializes the virtual environment manager with the specified base name
+ and model directory.
+
+ Args:
+ base_name (str): The base name (i.e., model name) for the virtual environment.
+ model_directory (str): The directory containing the model files.
+ """
+
+ self.base_name = base_name
+ self.model_directory = model_directory
+ self.env_name = self.generate_env_name()
+ self.env_path = os.path.join(model_directory, self.env_name)
+
+ def create_environment(self, force=False):
+ """
+ Creates a virtual environment in the specified model directory. If 'force'
+ is True, any existing virtual environment will be removed before creation.
+
+ Args:
+ force (bool): Whether to forcefully remove an existing virtual environment.
+ """
+ if force and self.env_exists():
+ log.info(f"Removing existing virtual environment: {self.env_name}")
+ shutil.rmtree(self.env_path)
+
+ if not self.env_exists():
+ log.info(f"Creating virtual environment: {self.env_name}")
+ venv.create(self.env_path, with_pip=True)
+ log.info(f"\tVirtual environment created: {self.env_name}")
+ self.install_dependencies()
+
+ def env_exists(self) -> bool:
+ """
+ Checks if the virtual environment exists by verifying the presence of its directory.
+
+ Returns:
+ bool: True if the virtual environment exists, False otherwise.
+ """
+ return os.path.isdir(self.env_path)
+
+ def install_dependencies(self):
+ """
+ Installs dependencies in the virtual environment using pip, based on the
+ model directory's configuration.
+ """
+ log.info(f"Installing dependencies in virtual environment: {self.env_name}")
+ pip_executable = os.path.join(self.env_path, "bin", "pip")
+ cmd = f"{pip_executable} install -e {os.path.abspath(self.model_directory)}"
+ self.run_command(cmd)
+
+ def run_command(self, command):
+ """
+ Executes a specified command in the virtual environment and logs the output.
+
+ Args:
+ command (str): The command to be executed in the virtual environment.
+ """
+ activate_script = os.path.join(self.env_path, "bin", "activate")
+
+ virtualenv = os.environ.copy()
+ virtualenv.pop("PYTHONPATH", None)
+ virtualenv["VIRTUAL_ENV"] = self.env_path
+ virtualenv["PATH"] = (
+ os.path.join(self.env_path, "bin") + os.pathsep + virtualenv.get("PATH", "")
+ )
+
+ full_command = f"bash -c 'source {activate_script}' && {command}"
+
+ process = subprocess.Popen(
+ full_command,
+ shell=True,
+ env=virtualenv,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ universal_newlines=True,
+ )
+ for line in process.stdout:
+ stripped_line = line.strip()
+ log.info(stripped_line)
+ process.wait()
+
+
+class DockerEnvironmentManager(EnvironmentManager):
+ """
+ Manages a Docker environment, providing methods to create, check,
+ and manipulate Docker containers for the environment.
+ """
+
+ def __init__(self, base_name: str, model_directory: str):
+ self.base_name = base_name
+ self.model_directory = model_directory
+
+ def create_environment(self, force=False):
+ pass
+
+ def env_exists(self):
+ pass
+
+ def run_command(self, command):
+ pass
+
+ def install_dependencies(self):
+ pass
+
+
+class EnvironmentFactory:
+ """
+ Factory class for creating instances of environment managers based on the specified type.
+ """
+
+ @staticmethod
+ def get_env(
+ build: str = None, model_name: str = "model", model_path: str = None
+ ) -> EnvironmentManager:
+ """
+ Returns an instance of an environment manager based on the specified build type.
+ It checks the current environment type and can return a conda, venv, or Docker
+ environment manager.
+
+ Args:
+ build (str): The desired type of environment ('conda', 'venv', or 'docker').
+ model_name (str): The name of the model for which the environment is being created.
+ model_path (str): The path to the model directory.
+
+ Returns:
+ EnvironmentManager: An instance of the appropriate environment manager.
+
+ Raises:
+ Exception: If an invalid environment type is specified.
+ """
+ run_env = EnvironmentFactory.check_environment_type()
+ if run_env != build and build and build != "docker":
+ log.warning(
+ f"Selected build environment ({build}) for this model is different than that of"
+ f" the experiment run. Consider selecting the same environment."
+ )
+ if build == "conda" or (not build and run_env == "conda"):
+ return CondaEnvironmentManager(
+ base_name=f"{model_name}",
+ model_directory=os.path.abspath(model_path),
+ )
+ elif build == "venv" or (not build and run_env == "venv"):
+ return VenvEnvironmentManager(
+ base_name=f"{model_name}",
+ model_directory=os.path.abspath(model_path),
+ )
+ elif build == "docker":
+ return DockerEnvironmentManager(
+ base_name=f"{model_name}",
+ model_directory=os.path.abspath(model_path),
+ )
+ else:
+ raise Exception(
+ "Wrong environment selection. Please choose between "
+ '"conda", "venv" or "docker".'
+ )
+
+ @staticmethod
+ def check_environment_type():
+ if "VIRTUAL_ENV" in os.environ:
+ return "venv"
+ try:
+ subprocess.run(
+ ["conda", "info"],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ return "conda"
+ except FileNotFoundError:
+ pass
+ return None
+
+
+if __name__ == "__main__":
+
+ env = EnvironmentFactory.get_env(
+ "conda", model_path="../examples/case_h/models/pymock_poisson"
+ )
+ env.create_environment(force=True)
diff --git a/floatcsep/model.py b/floatcsep/model.py
index 3d81bb8..93bb310 100644
--- a/floatcsep/model.py
+++ b/floatcsep/model.py
@@ -1,7 +1,6 @@
import json
import logging
import os
-import subprocess
from abc import ABC, abstractmethod
from datetime import datetime
from typing import List, Callable, Union, Mapping, Sequence
@@ -13,6 +12,7 @@
from csep.utils.time_utils import decimal_year
from floatcsep.accessors import from_zenodo, from_git
+from floatcsep.environments import EnvironmentFactory
from floatcsep.readers import ForecastParsers, HDF5Serializer
from floatcsep.registry import ModelTree
from floatcsep.utils import timewindow2str, str2timewindow
@@ -124,9 +124,7 @@ def get_source(
elif giturl:
log.info(f"Retrieving model {self.name} from git url: " f"{giturl}")
try:
- from_git(
- giturl, self.dir if self.path.fmt else self.path("path"), **kwargs
- )
+ from_git(giturl, self.dir if self.path.fmt else self.path("path"), **kwargs)
except (git.NoSuchPathError, git.CommandError) as msg:
raise git.NoSuchPathError(f"git url was not found {msg}")
else:
@@ -177,9 +175,7 @@ def iter_attr(val):
return _get_value(val)
list_walk = [
- (i, j)
- for i, j in sorted(self.__dict__.items())
- if not i.startswith("_") and j
+ (i, j) for i, j in sorted(self.__dict__.items()) if not i.startswith("_") and j
]
dict_walk = {i: j for i, j in list_walk}
@@ -223,9 +219,7 @@ class TimeIndependentModel(Model):
store_db (bool): flag to indicate whether to store the model in a database.
"""
- def __init__(
- self, name: str, model_path: str, forecast_unit=1, store_db=False, **kwargs
- ):
+ def __init__(self, name: str, model_path: str, forecast_unit=1, store_db=False, **kwargs):
super().__init__(name, model_path, **kwargs)
self.forecast_unit = forecast_unit
self.store_db = store_db
@@ -289,9 +283,7 @@ def stage(self, timewindows: Union[str, List[datetime]] = None) -> None:
def get_forecast(
self, tstring: Union[str, list] = None, region=None
- ) -> Union[
- GriddedForecast, CatalogForecast, List[GriddedForecast], List[CatalogForecast]
- ]:
+ ) -> Union[GriddedForecast, CatalogForecast, List[GriddedForecast], List[CatalogForecast]]:
"""
Wrapper that just returns a forecast when requested.
"""
@@ -333,9 +325,7 @@ def create_forecast(self, tstring: str, **kwargs) -> None:
start_date, end_date = str2timewindow(tstring)
self.forecast_from_file(start_date, end_date, **kwargs)
- def forecast_from_file(
- self, start_date: datetime, end_date: datetime, **kwargs
- ) -> None:
+ def forecast_from_file(self, start_date: datetime, end_date: datetime, **kwargs) -> None:
"""
Generates a forecast from a file, by parsing and scaling it to.
@@ -403,42 +393,10 @@ def __init__(
self.build = kwargs.get("build", "docker")
self.run_prefix = ""
- def build_model(self):
-
- if self.build == "pip" or self.build == "venv":
- venv = os.path.join(self.path("path"), self.__dict__.get("venv", "venv"))
- venvact = os.path.join(venv, "bin", "activate")
-
- if not os.path.exists(venv):
- log.info(f"Building model {self.name} using pip")
- subprocess.run(["python", "-m", "venv", venv])
- log.info(f"\tVirtual environment created in {venv}")
- build_cmd = (
- f"source {venvact} && "
- f"pip install --upgrade pip && "
- f'pip install {self.path("path")}'
- )
-
- cmd = ["bash", "-c", build_cmd]
-
- log.info("\tInstalling dependencies")
-
- process = subprocess.Popen(
- cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- universal_newlines=True,
- )
- for line in process.stdout:
- log.info(f"\t{line[:-1]}")
- process.wait()
- log.info("\tEnvironment ready")
- log.warning(
- "\tNested environments is not fully supported. "
- "Consider using docker instead"
- )
-
- self.run_prefix = f'cd {self.path("path")} && source {venvact} && '
+ if self.func:
+ self.environment = EnvironmentFactory.get_env(
+ self.build, self.name, self.path.abs(self.model_path)
+ )
def stage(self, timewindows=None) -> None:
"""
@@ -450,7 +408,10 @@ def stage(self, timewindows=None) -> None:
- Run model quality assurance (unit tests, runnable from floatcsep)
"""
self.get_source(self.zenodo_id, self.giturl, branch=self.repo_hash)
- self.build_model()
+
+ if hasattr(self, "environment"):
+ self.environment.create_environment()
+
self.path.build_tree(
timewindows=timewindows,
model_class="td",
@@ -461,9 +422,7 @@ def stage(self, timewindows=None) -> None:
def get_forecast(
self, tstring: Union[str, list] = None, region=None
- ) -> Union[
- GriddedForecast, CatalogForecast, List[GriddedForecast], List[CatalogForecast]
- ]:
+ ) -> Union[GriddedForecast, CatalogForecast, List[GriddedForecast], List[CatalogForecast]]:
"""Wrapper that just returns a forecast, hiding the access method under the hood"""
if isinstance(tstring, str):
@@ -511,9 +470,7 @@ def create_forecast(self, tstring: str, **kwargs) -> None:
else:
log.info(f"Forecast of {tstring} of model {self.name} already " f"exists")
- def forecast_from_func(
- self, start_date: datetime, end_date: datetime, **kwargs
- ) -> None:
+ def forecast_from_func(self, start_date: datetime, end_date: datetime, **kwargs) -> None:
self.prepare_args(start_date, end_date, **kwargs)
log.info(
@@ -562,18 +519,7 @@ def replace_arg(arg, val, fp):
def run_model(self):
- if self.build == "pip" or self.build == "venv":
- run_func = f'{self.func} {self.path("args_file")}'
- cmd = ["bash", "-c", f"{self.run_prefix} {run_func}"]
- process = subprocess.Popen(
- cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- universal_newlines=True,
- )
- for line in process.stdout:
- log.info(f"\t{line[:-1]}")
- process.wait()
+ self.environment.run_command(f'{self.func} {self.path("args_file")}')
class ModelFactory:
diff --git a/floatcsep/registry.py b/floatcsep/registry.py
index 8ece9b2..b36c119 100644
--- a/floatcsep/registry.py
+++ b/floatcsep/registry.py
@@ -118,8 +118,7 @@ def build_tree(
# set forecast names
fc_files = {
- win: join(dirtree["forecasts"], f"{prefix}_{win}.csv")
- for win in windows
+ win: join(dirtree["forecasts"], f"{prefix}_{win}.csv") for win in windows
}
fc_exists = {
diff --git a/pyproject.toml b/pyproject.toml
index dc80b62..5eb4733 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,4 +6,9 @@ build-backend = "setuptools.build_meta"
addopts = "--cov=floatcsep"
testpaths = [
"tests",
-]
\ No newline at end of file
+]
+
+[tool.black]
+line-length = 96
+skip-string-normalization = false
+target-version = ["py39", "py310", "py311"]
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index d1cbbb8..81d20f8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,8 @@ flake8
gitpython
h5py
matplotlib
+packaging
+pandas
pycsep
pyshp
pyyaml
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 9cc3cc8..258b89e 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -1,5 +1,6 @@
numpy
cartopy
+black
dateparser
docker
flake8
@@ -10,9 +11,9 @@ matplotlib
mercantile
mypy
obspy
+packaging
pandas
pillow
-pyblack
pycsep
pydocstringformatter
pyproj
diff --git a/setup.cfg b/setup.cfg
index 427f38d..82af787 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -27,6 +27,8 @@ install_requires =
gitpython
h5py
matplotlib
+ packaging
+ pandas
pycsep
pyshp
pyyaml
@@ -55,6 +57,7 @@ dev =
mercantile
mypy
obspy
+ packaging
pandas
pillow
pycsep
@@ -63,7 +66,6 @@ dev =
pyshp
pytest
pytest-cov
- pytest-asyncio
pyyaml
requests
scipy
diff --git a/tests/artifacts/models/td_model/setup.cfg b/tests/artifacts/models/td_model/setup.cfg
new file mode 100644
index 0000000..e69de29
diff --git a/tests/qa/test_data.py b/tests/qa/test_data.py
index b043987..96dd1d2 100644
--- a/tests/qa/test_data.py
+++ b/tests/qa/test_data.py
@@ -1,5 +1,7 @@
from floatcsep.cmd import main
+from floatcsep.experiment import Experiment
import unittest
+from unittest.mock import patch
import os
@@ -33,77 +35,58 @@ def get_eval_dist(self):
pass
+@patch.object(Experiment, "generate_report")
+@patch.object(Experiment, "plot_forecasts")
+@patch.object(Experiment, "plot_catalog")
class RunExamples(DataTest):
- def test_case_a(self):
+ def test_case_a(self, *args):
cfg = self.get_runpath('a')
self.run_evaluation(cfg)
self.assertEqual(1, 1)
- def test_case_b(self):
+ def test_case_b(self, *args):
cfg = self.get_runpath('b')
self.run_evaluation(cfg)
self.assertEqual(1, 1)
- def test_case_c(self):
+ def test_case_c(self, *args):
cfg = self.get_runpath('c')
self.run_evaluation(cfg)
self.assertEqual(1, 1)
- def test_case_d(self):
+ def test_case_d(self, *args):
cfg = self.get_runpath('d')
self.run_evaluation(cfg)
self.assertEqual(1, 1)
- def test_case_e(self):
+ def test_case_e(self, *args):
cfg = self.get_runpath('e')
self.run_evaluation(cfg)
self.assertEqual(1, 1)
- def test_case_f(self):
+ def test_case_f(self, *args):
cfg = self.get_runpath('f')
self.run_evaluation(cfg)
self.assertEqual(1, 1)
- def test_case_g(self):
+ def test_case_g(self, *args):
cfg = self.get_runpath('g')
self.run_evaluation(cfg)
self.assertEqual(1, 1)
+@patch.object(Experiment, "generate_report")
+@patch.object(Experiment, "plot_forecasts")
+@patch.object(Experiment, "plot_catalog")
class ReproduceExamples(DataTest):
- def test_case_a(self):
- cfg = self.get_rerunpath('a')
- self.repr_evaluation(cfg)
- self.assertEqual(1, 1)
-
- def test_case_b(self):
- cfg = self.get_rerunpath('b')
- self.repr_evaluation(cfg)
- self.assertEqual(1, 1)
-
- def test_case_c(self):
+ def test_case_c(self, *args):
cfg = self.get_rerunpath('c')
self.repr_evaluation(cfg)
self.assertEqual(1, 1)
- def test_case_d(self):
- cfg = self.get_rerunpath('d')
- self.repr_evaluation(cfg)
- self.assertEqual(1, 1)
-
- def test_case_e(self):
- cfg = self.get_rerunpath('e')
- self.repr_evaluation(cfg)
- self.assertEqual(1, 1)
-
- def test_case_f(self):
+ def test_case_f(self, *args):
cfg = self.get_rerunpath('f')
self.repr_evaluation(cfg)
self.assertEqual(1, 1)
-
- def test_case_g(self):
- cfg = self.get_rerunpath('g')
- self.repr_evaluation(cfg)
- self.assertEqual(1, 1)
\ No newline at end of file
diff --git a/tests/unit/test_accessors.py b/tests/unit/test_accessors.py
index a2ab143..260009b 100644
--- a/tests/unit/test_accessors.py
+++ b/tests/unit/test_accessors.py
@@ -1,20 +1,20 @@
import os.path
import vcr
from datetime import datetime
-from floatcsep.accessors import query_gcmt, _query_gcmt, from_zenodo, \
- from_git, _check_hash
+from floatcsep.accessors import query_gcmt, _query_gcmt, from_zenodo, from_git, _check_hash
import unittest
from unittest import mock
root_dir = os.path.dirname(os.path.abspath(__file__))
+
def gcmt_dir():
- data_dir = os.path.join(root_dir, '../artifacts', 'gcmt')
+ data_dir = os.path.join(root_dir, "../artifacts", "gcmt")
return data_dir
def zenodo_dir():
- data_dir = os.path.join(root_dir, '../artifacts', 'zenodo')
+ data_dir = os.path.join(root_dir, "../artifacts", "zenodo")
return data_dir
@@ -23,32 +23,31 @@ class TestCatalogGetter(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
os.makedirs(gcmt_dir(), exist_ok=True)
- cls._fname = os.path.join(gcmt_dir(), 'test_cat')
+ cls._fname = os.path.join(gcmt_dir(), "test_cat")
def test_gcmt_search(self):
- tape_file = os.path.join(gcmt_dir(), 'vcr_search.yaml')
+ tape_file = os.path.join(gcmt_dir(), "vcr_search.yaml")
with vcr.use_cassette(tape_file):
# Maule, Chile
- eventlist = \
- _query_gcmt(start_time=datetime(2010, 2, 26),
- end_time=datetime(2010, 3, 2),
- min_magnitude=6)
+ eventlist = _query_gcmt(
+ start_time=datetime(2010, 2, 26), end_time=datetime(2010, 3, 2), min_magnitude=6
+ )
event = eventlist[0]
- assert event[0] == '2844986'
+ assert event[0] == "2844986"
def test_gcmt_summary(self):
- tape_file = os.path.join(gcmt_dir(), 'vcr_summary.yaml')
+ tape_file = os.path.join(gcmt_dir(), "vcr_summary.yaml")
with vcr.use_cassette(tape_file):
- eventlist = \
- _query_gcmt(start_time=datetime(2010, 2, 26),
- end_time=datetime(2010, 3, 2),
- min_magnitude=7)
+ eventlist = _query_gcmt(
+ start_time=datetime(2010, 2, 26), end_time=datetime(2010, 3, 2), min_magnitude=7
+ )
event = eventlist[0]
cmp = "('2844986', 1267252514000, -35.98, -73.15, 23.2, 8.8)"
assert str(event) == cmp
- assert event[0] == '2844986'
- assert datetime.fromtimestamp(
- event[1] / 1000.) == datetime.fromtimestamp(1267252514)
+ assert event[0] == "2844986"
+ assert datetime.fromtimestamp(event[1] / 1000.0) == datetime.fromtimestamp(
+ 1267252514
+ )
assert event[2] == -35.98
assert event[3] == -73.15
assert event[4] == 23.2
@@ -57,54 +56,52 @@ def test_gcmt_summary(self):
def test_catalog_query_plot(self):
start_datetime = datetime(2020, 1, 1)
end_datetime = datetime(2020, 3, 1)
- catalog = query_gcmt(start_time=start_datetime,
- end_time=end_datetime,
- min_magnitude=5.95)
- catalog.plot(set_global=True, plot_args={'filename': self._fname,
- 'basemap': 'stock_img'})
- assert os.path.isfile(self._fname + '.png')
- assert os.path.isfile(self._fname + '.pdf')
+ catalog = query_gcmt(
+ start_time=start_datetime, end_time=end_datetime, min_magnitude=5.95
+ )
+ catalog.plot(
+ set_global=True, plot_args={"filename": self._fname, "basemap": "stock_img"}
+ )
+ assert os.path.isfile(self._fname + ".png")
+ assert os.path.isfile(self._fname + ".pdf")
@classmethod
def tearDownClass(cls) -> None:
try:
- os.remove(os.path.join(gcmt_dir(), cls._fname + '.pdf'))
- os.remove(os.path.join(gcmt_dir(), cls._fname + '.png'))
+ os.remove(os.path.join(gcmt_dir(), cls._fname + ".pdf"))
+ os.remove(os.path.join(gcmt_dir(), cls._fname + ".png"))
except OSError:
pass
+
class TestZenodoGetter(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
os.makedirs(zenodo_dir(), exist_ok=True)
- cls._txt = os.path.join(zenodo_dir(), 'dummy.txt')
- cls._tar = os.path.join(zenodo_dir(), 'dummy.tar')
+ cls._txt = os.path.join(zenodo_dir(), "dummy.txt")
+ cls._tar = os.path.join(zenodo_dir(), "dummy.tar")
def test_zenodo_query(self):
from_zenodo(4739912, zenodo_dir())
assert os.path.isfile(self._txt)
assert os.path.isfile(self._tar)
- with open(self._txt, 'r') as dummy:
- assert dummy.readline() == 'test'
- _check_hash(self._tar, 'md5:17f80d606ff085751998ac4050cc614c')
+ with open(self._txt, "r") as dummy:
+ assert dummy.readline() == "test"
+ _check_hash(self._tar, "md5:17f80d606ff085751998ac4050cc614c")
@classmethod
def tearDownClass(cls) -> None:
- os.remove(os.path.join(zenodo_dir(), 'dummy.txt'))
- os.remove(os.path.join(zenodo_dir(), 'dummy.tar'))
+ os.remove(os.path.join(zenodo_dir(), "dummy.txt"))
+ os.remove(os.path.join(zenodo_dir(), "dummy.tar"))
os.rmdir(zenodo_dir())
class TestGitter(unittest.TestCase):
- @mock.patch('floatcsep.accessors.git.Repo')
- @mock.patch('git.Git')
+ @mock.patch("floatcsep.accessors.git.Repo")
+ @mock.patch("git.Git")
def runTest(self, mock_git, mock_repo):
p = mock.PropertyMock(return_value=False)
type(mock_repo.clone_from.return_value).bare = p
- from_git(
- '/tmp/testrepo',
- 'git@github.com:github/testrepo.git',
- 'master'
- )
- mock_git.checkout.called_once_with('master')
+ from_git("/tmp/testrepo", "git@github.com:github/testrepo.git", "master")
+ mock_git.checkout.called_once_with("master")
diff --git a/tests/unit/test_environments.py b/tests/unit/test_environments.py
new file mode 100644
index 0000000..cef9511
--- /dev/null
+++ b/tests/unit/test_environments.py
@@ -0,0 +1,352 @@
+import venv
+import unittest
+import subprocess
+import os
+from unittest.mock import patch, MagicMock, call, mock_open
+import shutil
+import hashlib
+import logging
+from floatcsep.environments import (
+ CondaEnvironmentManager,
+ EnvironmentFactory,
+ VenvEnvironmentManager,
+ DockerEnvironmentManager,
+)
+
+
+class TestCondaEnvironmentManager(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ if not shutil.which("conda"):
+ raise unittest.SkipTest("Conda is not available in the environment.")
+
+ def setUp(self):
+ self.manager = CondaEnvironmentManager(
+ base_name="test_env", model_directory="/tmp/test_model"
+ )
+ os.makedirs("/tmp/test_model", exist_ok=True)
+ with open("/tmp/test_model/environment.yml", "w") as f:
+ f.write("name: test_env\ndependencies:\n - python=3.8\n - numpy")
+ with open("/tmp/test_model/setup.py", "w") as f:
+ f.write("from setuptools import setup\nsetup(name='test_model', version='0.1')")
+
+ def tearDown(self):
+ if self.manager.env_exists():
+ subprocess.run(
+ ["conda", "env", "remove", "--name", self.manager.env_name, "--yes"],
+ check=True,
+ )
+ if os.path.exists("/tmp/test_model"):
+ shutil.rmtree("/tmp/test_model")
+
+ @patch("subprocess.run")
+ @patch("shutil.which", return_value="conda")
+ def test_generate_env_name(self, mock_which, mock_run):
+ manager = CondaEnvironmentManager("test_base", "/path/to/model")
+ expected_name = "test_base_" + hashlib.md5("/path/to/model".encode()).hexdigest()[:8]
+ print(expected_name)
+ self.assertEqual(manager.generate_env_name(), expected_name)
+
+ @patch("subprocess.run")
+ def test_env_exists(self, mock_run):
+ hashed = hashlib.md5("/path/to/model".encode()).hexdigest()[:8]
+ mock_run.return_value.stdout = f"test_base_{hashed}\n".encode()
+
+ manager = CondaEnvironmentManager("test_base", "/path/to/model")
+ self.assertTrue(manager.env_exists())
+
+ @patch("subprocess.run")
+ @patch("os.path.exists", return_value=True)
+ def test_create_environment(self, mock_exists, mock_run):
+ manager = CondaEnvironmentManager("test_base", "/path/to/model")
+ manager.create_environment(force=False)
+ package_manager = manager.detect_package_manager()
+ expected_calls = [
+ call(["conda", "env", "list"], stdout=-1),
+ call().stdout.decode(),
+ call().stdout.decode().__contains__(manager.env_name),
+ call(
+ [
+ package_manager,
+ "env",
+ "create",
+ "--name",
+ manager.env_name,
+ "--file",
+ "/path/to/model/environment.yml",
+ ]
+ ),
+ call(
+ [
+ package_manager,
+ "run",
+ "-n",
+ manager.env_name,
+ "pip",
+ "install",
+ "-e",
+ "/path/to/model",
+ ],
+ check=True,
+ ),
+ ]
+
+ self.assertEqual(mock_run.call_count, 3)
+ mock_run.assert_has_calls(expected_calls, any_order=False)
+
+ @patch("subprocess.run")
+ def test_create_environment_force(self, mock_run):
+ manager = CondaEnvironmentManager("test_base", "/path/to/model")
+ manager.env_exists = MagicMock(return_value=True)
+ manager.create_environment(force=True)
+ self.assertEqual(mock_run.call_count, 2) # One for remove, one for create
+
+ @patch("subprocess.run")
+ @patch.object(CondaEnvironmentManager, "detect_package_manager", return_value="conda")
+ def test_install_dependencies(self, mock_detect_package_manager, mock_run):
+ manager = CondaEnvironmentManager("test_base", "/path/to/model")
+ manager.install_dependencies()
+ mock_run.assert_called_once_with(
+ [
+ "conda",
+ "run",
+ "-n",
+ manager.env_name,
+ "pip",
+ "install",
+ "-e",
+ "/path/to/model",
+ ],
+ check=True,
+ )
+
+ @patch("shutil.which", return_value="conda")
+ @patch("os.path.exists", side_effect=[False, False, True])
+ @patch(
+ "builtins.open",
+ new_callable=mock_open,
+ read_data="[metadata]\nname = test\n\n[options]\ninstall_requires =\n numpy\npython_requires = >=3.9,<3.12\n",
+ )
+ def test_detect_python_version_setup_cfg(self, mock_open, mock_exists, mock_which):
+ manager = CondaEnvironmentManager("test_base", "../artifacts/models/td_model")
+ python_version = manager.detect_python_version()
+
+ # Extract major and minor version parts
+ major_minor_version = ".".join(python_version.split(".")[:2])
+
+ self.assertIn(
+ major_minor_version, ["3.9", "3.10", "3.11"]
+ ) # Check if it falls within the specified range
+
+ def test_create_and_delete_environment(self):
+ # Create the environment
+ self.manager.create_environment(force=True)
+
+ # Check if the environment was created
+ result = subprocess.run(["conda", "env", "list"], stdout=subprocess.PIPE, check=True)
+ self.assertIn(self.manager.env_name, result.stdout.decode())
+
+ # Check if numpy is installed
+ result = subprocess.run(
+ [
+ "conda",
+ "run",
+ "-n",
+ self.manager.env_name,
+ "python",
+ "-c",
+ "import numpy",
+ ],
+ check=True,
+ )
+ self.assertEqual(result.returncode, 0)
+
+ # Delete the environment
+ self.manager.create_environment(
+ force=True
+ ) # This should remove and recreate the environment
+
+ # Check if the environment was recreated
+ result = subprocess.run(["conda", "env", "list"], stdout=subprocess.PIPE, check=True)
+ self.assertIn(self.manager.env_name, result.stdout.decode())
+
+
+class TestEnvironmentFactory(unittest.TestCase):
+
+ @patch("os.path.abspath", return_value="/absolute/path/to/model")
+ @patch.object(EnvironmentFactory, "check_environment_type", return_value="conda")
+ def test_get_env_conda(self, mock_check_env, mock_abspath):
+ env_manager = EnvironmentFactory.get_env(
+ build="conda", model_name="test_model", model_path="/path/to/model"
+ )
+ self.assertIsInstance(env_manager, CondaEnvironmentManager)
+ self.assertEqual(env_manager.base_name, "test_model")
+ self.assertEqual(env_manager.model_directory, "/absolute/path/to/model")
+
+ @patch("os.path.abspath", return_value="/absolute/path/to/model")
+ @patch.object(EnvironmentFactory, "check_environment_type", return_value="venv")
+ def test_get_env_venv(self, mock_check_env, mock_abspath):
+ env_manager = EnvironmentFactory.get_env(
+ build="venv", model_name="test_model", model_path="/path/to/model"
+ )
+ self.assertIsInstance(env_manager, VenvEnvironmentManager)
+ self.assertEqual(env_manager.base_name, "test_model")
+ self.assertEqual(env_manager.model_directory, "/absolute/path/to/model")
+
+ @patch("os.path.abspath", return_value="/absolute/path/to/model")
+ @patch.object(EnvironmentFactory, "check_environment_type", return_value=None)
+ def test_get_env_docker(self, mock_check_env, mock_abspath):
+ env_manager = EnvironmentFactory.get_env(
+ build="docker", model_name="test_model", model_path="/path/to/model"
+ )
+ self.assertIsInstance(env_manager, DockerEnvironmentManager)
+ self.assertEqual(env_manager.base_name, "test_model")
+ self.assertEqual(env_manager.model_directory, "/absolute/path/to/model")
+
+ @patch("os.path.abspath", return_value="/absolute/path/to/model")
+ @patch.object(EnvironmentFactory, "check_environment_type", return_value="conda")
+ def test_get_env_default_conda(self, mock_check_env, mock_abspath):
+ env_manager = EnvironmentFactory.get_env(
+ build=None, model_name="test_model", model_path="/path/to/model"
+ )
+ self.assertIsInstance(env_manager, CondaEnvironmentManager)
+ self.assertEqual(env_manager.base_name, "test_model")
+ self.assertEqual(env_manager.model_directory, "/absolute/path/to/model")
+
+ @patch("os.path.abspath", return_value="/absolute/path/to/model")
+ @patch.object(EnvironmentFactory, "check_environment_type", return_value="venv")
+ def test_get_env_default_venv(self, mock_check_env, mock_abspath):
+ env_manager = EnvironmentFactory.get_env(
+ build=None, model_name="test_model", model_path="/path/to/model"
+ )
+ self.assertIsInstance(env_manager, VenvEnvironmentManager)
+ self.assertEqual(env_manager.base_name, "test_model")
+ self.assertEqual(env_manager.model_directory, "/absolute/path/to/model")
+
+ @patch("os.path.abspath", return_value="/absolute/path/to/model")
+ @patch.object(EnvironmentFactory, "check_environment_type", return_value=None)
+ def test_get_env_invalid(self, mock_check_env, mock_abspath):
+ with self.assertRaises(Exception) as context:
+ EnvironmentFactory.get_env(
+ build="invalid", model_name="test_model", model_path="/path/to/model"
+ )
+ self.assertTrue("Wrong environment selection" in str(context.exception))
+
+ @patch("os.path.abspath", return_value="/absolute/path/to/model")
+ @patch.object(EnvironmentFactory, "check_environment_type", return_value="venv")
+ @patch("logging.Logger.warning")
+ def test_get_env_warning(self, mock_log_warning, mock_check_env, mock_abspath):
+ EnvironmentFactory.get_env(
+ build="conda", model_name="test_model", model_path="/path/to/model"
+ )
+ mock_log_warning.assert_called_once_with(
+ f"Selected build environment (conda) for this model is different than that of"
+ f" the experiment run. Consider selecting the same environment."
+ )
+
+
+class TestVenvEnvironmentManager(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # Check if venv is available (Python standard library)
+ if not hasattr(venv, "create"):
+ raise unittest.SkipTest("Venv is not available in the environment.")
+
+ def setUp(self):
+ self.model_directory = "/tmp/test_model"
+ self.manager = VenvEnvironmentManager(
+ base_name="test_env", model_directory=self.model_directory
+ )
+ os.makedirs(self.model_directory, exist_ok=True)
+ with open(os.path.join(self.model_directory, "setup.py"), "w") as f:
+ f.write("from setuptools import setup\nsetup(name='test_model', version='0.1')")
+ logging.disable(logging.CRITICAL)
+
+ def tearDown(self):
+ if self.manager.env_exists():
+ shutil.rmtree(self.manager.env_path)
+ if os.path.exists(self.model_directory):
+ shutil.rmtree(self.model_directory)
+
+ def test_create_and_delete_environment(self):
+ # Create the environment
+ self.manager.create_environment(force=True)
+
+ # Check if the environment was created
+ self.assertTrue(self.manager.env_exists())
+
+ # Check if pip is available in the environment
+ pip_executable = os.path.join(self.manager.env_path, "bin", "pip")
+ result = subprocess.run(
+ [pip_executable, "list"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
+ self.assertEqual(result.returncode, 0) # pip should run without errors
+
+ # Delete the environment
+ self.manager.create_environment(
+ force=True
+ ) # This should remove and recreate the environment
+
+ # Check if the environment was recreated
+ self.assertTrue(self.manager.env_exists())
+
+ def test_init(self):
+ self.assertEqual(self.manager.base_name, "test_env")
+ self.assertEqual(self.manager.model_directory, self.model_directory)
+ self.assertTrue(self.manager.env_name.startswith("test_env_"))
+
+ def test_env_exists(self):
+ self.assertFalse(self.manager.env_exists())
+ self.manager.create_environment(force=True)
+ self.assertTrue(self.manager.env_exists())
+
+ def test_create_environment(self):
+ self.manager.create_environment(force=True)
+ self.assertTrue(self.manager.env_exists())
+
+ def test_create_environment_force(self):
+ self.manager.create_environment(force=True)
+ env_path_before = self.manager.env_path
+ self.manager.create_environment(force=True)
+ self.assertTrue(self.manager.env_exists())
+ self.assertEqual(env_path_before, self.manager.env_path) # Ensure it's a new path
+
+ def test_install_dependencies(self):
+ self.manager.create_environment(force=True)
+ pip_executable = os.path.join(self.manager.env_path, "bin", "pip")
+ result = subprocess.run(
+ [pip_executable, "install", "-e", self.model_directory],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ self.assertEqual(result.returncode, 0) # pip should run without errors
+
+ @patch("subprocess.Popen")
+ def test_run_command(self, mock_popen):
+ # Arrange
+ mock_process = MagicMock()
+ mock_process.stdout = iter(["Output line 1\n", "Output line 2\n"])
+ mock_process.wait.return_value = None
+ mock_popen.return_value = mock_process
+
+ command = "echo test_command"
+
+ # Act
+ self.manager.run_command(command)
+
+ output_cmd = f"bash -c 'source {os.path.join(self.manager.env_path, 'bin', 'activate')}' && {command}"
+ # Assert
+ mock_popen.assert_called_once_with(
+ output_cmd,
+ shell=True,
+ env=unittest.mock.ANY,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ universal_newlines=True,
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/unit/test_evaluation.py b/tests/unit/test_evaluation.py
index 11aee79..4b26298 100644
--- a/tests/unit/test_evaluation.py
+++ b/tests/unit/test_evaluation.py
@@ -12,37 +12,36 @@ def setUpClass(cls) -> None:
def mock_eval():
return
- setattr(cls, 'mock_eval', mock_eval)
+ setattr(cls, "mock_eval", mock_eval)
@staticmethod
def init_noreg(name, func, **kwargs):
- """ Instantiates a model without using the @register deco,
- but mocks Model.Registry() attrs"""\
-
+ """Instantiates a model without using the @register deco,
+ but mocks Model.Registry() attrs"""
# deprecated
# evaluation = Evaluation.__new__(Evaluation)
# Evaluation.__init__.__wrapped__(self=evaluation,
# name=name,
# func=func,
# **kwargs)
- evaluation = Evaluation(name=name, func=func,
- **kwargs)
+ evaluation = Evaluation(name=name, func=func, **kwargs)
return evaluation
def test_init(self):
- name = 'N_test'
- eval_ = self.init_noreg(name=name,
- func=self.mock_eval)
+ name = "N_test"
+ eval_ = self.init_noreg(name=name, func=self.mock_eval)
self.assertIs(None, eval_.type)
- dict_ = {'name': 'N_test',
- 'func': self.mock_eval,
- 'func_kwargs': {},
- 'ref_model': None,
- 'plot_func': None,
- 'plot_args': None,
- 'plot_kwargs': None,
- 'markdown': '',
- '_type': None}
+ dict_ = {
+ "name": "N_test",
+ "func": self.mock_eval,
+ "func_kwargs": {},
+ "ref_model": None,
+ "plot_func": None,
+ "plot_args": None,
+ "plot_kwargs": None,
+ "markdown": "",
+ "_type": None,
+ }
self.assertEqual(dict_, eval_.__dict__)
def test_discrete_args(self):
@@ -56,32 +55,30 @@ def test_prepare_catalog(self):
def read_cat(_):
cat = Mock()
- cat.name = 'csep'
+ cat.name = "csep"
return cat
- with patch('csep.core.catalogs.CSEPCatalog.load_json', read_cat):
- region = 'CSEPRegion'
- forecast = MagicMock(name='forecast', region=region)
+ with patch("csep.core.catalogs.CSEPCatalog.load_json", read_cat):
+ region = "CSEPRegion"
+ forecast = MagicMock(name="forecast", region=region)
- catt = Evaluation.get_catalog('path_to_cat', forecast)
- self.assertEqual('csep', catt.name)
+ catt = Evaluation.get_catalog("path_to_cat", forecast)
+ self.assertEqual("csep", catt.name)
self.assertEqual(region, catt.region)
- region2 = 'definitelyNotCSEPregion'
- forecast2 = Mock(name='forecast', region=region2)
- cats = Evaluation.get_catalog(['path1', 'path2'],
- [forecast, forecast2])
+ region2 = "definitelyNotCSEPregion"
+ forecast2 = Mock(name="forecast", region=region2)
+ cats = Evaluation.get_catalog(["path1", "path2"], [forecast, forecast2])
self.assertIsInstance(cats, list)
- self.assertEqual(cats[0].name, 'csep')
- self.assertEqual(cats[0].region, 'CSEPRegion')
- self.assertEqual(cats[1].region, 'definitelyNotCSEPregion')
+ self.assertEqual(cats[0].name, "csep")
+ self.assertEqual(cats[0].region, "CSEPRegion")
+ self.assertEqual(cats[1].region, "definitelyNotCSEPregion")
with self.assertRaises(AttributeError):
- Evaluation.get_catalog('path1', [forecast, forecast2])
+ Evaluation.get_catalog("path1", [forecast, forecast2])
with self.assertRaises(IndexError):
- Evaluation.get_catalog(['path1', 'path2'],
- forecast)
+ Evaluation.get_catalog(["path1", "path2"], forecast)
assert True
def test_write_result(self):
diff --git a/tests/unit/test_experiment.py b/tests/unit/test_experiment.py
index 2f417a5..2a1ebc0 100644
--- a/tests/unit/test_experiment.py
+++ b/tests/unit/test_experiment.py
@@ -9,19 +9,18 @@
from csep.core.catalogs import CSEPCatalog
_dir = os.path.dirname(__file__)
-_model_cfg = os.path.normpath(os.path.join(_dir, '../artifacts', 'models',
- 'model_cfg.yml'))
-_region = os.path.normpath(os.path.join(_dir, '../artifacts', 'regions',
- 'mock_region'))
-_time_config = {'start_date': datetime(2021, 1, 1),
- 'end_date': datetime(2022, 1, 1)}
-_region_config = {'region': _region,
- 'mag_max': 10.0,
- 'mag_min': 1.0,
- 'mag_bin': 0.1,
- 'depth_min': 0,
- 'depth_max': 1}
-_cat = os.path.normpath(os.path.join(_dir, '../artifacts', 'catalog.json'))
+_model_cfg = os.path.normpath(os.path.join(_dir, "../artifacts", "models", "model_cfg.yml"))
+_region = os.path.normpath(os.path.join(_dir, "../artifacts", "regions", "mock_region"))
+_time_config = {"start_date": datetime(2021, 1, 1), "end_date": datetime(2022, 1, 1)}
+_region_config = {
+ "region": _region,
+ "mag_max": 10.0,
+ "mag_min": 1.0,
+ "mag_bin": 0.1,
+ "depth_min": 0,
+ "depth_max": 1,
+}
+_cat = os.path.normpath(os.path.join(_dir, "../artifacts", "catalog.json"))
class TestExperiment(TestCase):
@@ -46,68 +45,72 @@ def assertEqualExperiment(self, exp_a, exp_b):
def init_no_wrap(name, path, **kwargs):
model = Experiment.__new__(Experiment)
- Experiment.__init__.__wrapped__(self=model, name=name,
- path=path, **kwargs)
+ Experiment.__init__.__wrapped__(self=model, name=name, path=path, **kwargs)
return Experiment
def test_init(self):
- exp_a = Experiment(**_time_config, **_region_config,
- catalog=_cat)
- exp_b = Experiment(time_config=_time_config,
- region_config=_region_config,
- catalog=_cat)
+ exp_a = Experiment(**_time_config, **_region_config, catalog=_cat)
+ exp_b = Experiment(time_config=_time_config, region_config=_region_config, catalog=_cat)
self.assertEqualExperiment(exp_a, exp_b)
def test_to_dict(self):
- time_config = {'start_date': datetime(2020, 1, 1),
- 'end_date': datetime(2021, 1, 1),
- 'horizon': '6 month',
- 'growth': 'cumulative'}
-
- region_config = {'region': 'california_relm_region',
- 'mag_max': 9.0,
- 'mag_min': 3.0,
- 'mag_bin': 0.1,
- 'depth_min': -2,
- 'depth_max': 70}
-
- exp_a = Experiment(name='test', **time_config, **region_config,
- catalog=_cat)
- dict_ = {'name': 'test',
- 'path': os.getcwd(),
- 'rundir': 'results',
- 'time_config':
- {'exp_class': 'ti',
- 'start_date': datetime(2020, 1, 1),
- 'end_date': datetime(2021, 1, 1),
- 'horizon': '6-months',
- 'growth': 'cumulative'},
- 'region_config': {
- 'region': 'california_relm_region',
- 'mag_max': 9.0,
- 'mag_min': 3.0,
- 'mag_bin': 0.1,
- 'depth_min': -2,
- 'depth_max': 70
- },
- 'catalog': os.path.relpath(_cat, os.getcwd())
- }
+ time_config = {
+ "start_date": datetime(2020, 1, 1),
+ "end_date": datetime(2021, 1, 1),
+ "horizon": "6 month",
+ "growth": "cumulative",
+ }
+
+ region_config = {
+ "region": "california_relm_region",
+ "mag_max": 9.0,
+ "mag_min": 3.0,
+ "mag_bin": 0.1,
+ "depth_min": -2,
+ "depth_max": 70,
+ }
+
+ exp_a = Experiment(name="test", **time_config, **region_config, catalog=_cat)
+ dict_ = {
+ "name": "test",
+ "path": os.getcwd(),
+ "rundir": "results",
+ "time_config": {
+ "exp_class": "ti",
+ "start_date": datetime(2020, 1, 1),
+ "end_date": datetime(2021, 1, 1),
+ "horizon": "6-months",
+ "growth": "cumulative",
+ },
+ "region_config": {
+ "region": "california_relm_region",
+ "mag_max": 9.0,
+ "mag_min": 3.0,
+ "mag_bin": 0.1,
+ "depth_min": -2,
+ "depth_max": 70,
+ },
+ "catalog": os.path.relpath(_cat, os.getcwd()),
+ }
self.assertEqual(dict_, exp_a.as_dict())
def test_to_yml(self):
- time_config = {'start_date': datetime(2021, 1, 1),
- 'end_date': datetime(2022, 1, 1),
- 'intervals': 12}
-
- region_config = {'region': 'california_relm_region',
- 'mag_max': 9.0,
- 'mag_min': 3.0,
- 'mag_bin': 0.1,
- 'depth_min': -2,
- 'depth_max': 70}
-
- exp_a = Experiment(**time_config, **region_config,
- catalog=_cat)
+ time_config = {
+ "start_date": datetime(2021, 1, 1),
+ "end_date": datetime(2022, 1, 1),
+ "intervals": 12,
+ }
+
+ region_config = {
+ "region": "california_relm_region",
+ "mag_max": 9.0,
+ "mag_min": 3.0,
+ "mag_bin": 0.1,
+ "depth_min": -2,
+ "depth_max": 70,
+ }
+
+ exp_a = Experiment(**time_config, **region_config, catalog=_cat)
file_ = tempfile.mkstemp()[1]
exp_a.to_yml(file_)
exp_b = Experiment.from_yml(file_)
@@ -120,49 +123,49 @@ def test_to_yml(self):
self.assertEqualExperiment(exp_a, exp_c)
def test_set_models(self):
- exp = Experiment(**_time_config, **_region_config,
- model_config=_model_cfg,
- catalog=_cat)
+ exp = Experiment(
+ **_time_config, **_region_config, model_config=_model_cfg, catalog=_cat
+ )
names = [i.name for i in exp.models]
- self.assertEqual(['mock', 'qtree@team10', 'qtree@team25'], names)
+ self.assertEqual(["mock", "qtree@team10", "qtree@team25"], names)
m1_path = os.path.normpath(
- os.path.join(_dir, '../artifacts', 'models', 'qtree',
- 'TEAM=N10L11.csv'))
+ os.path.join(_dir, "../artifacts", "models", "qtree", "TEAM=N10L11.csv")
+ )
def test_stage_models(self):
- exp = Experiment(**_time_config, **_region_config,
- model_config=_model_cfg,
- catalog=_cat)
+ exp = Experiment(
+ **_time_config, **_region_config, model_config=_model_cfg, catalog=_cat
+ )
exp.stage_models()
- dbpath = os.path.relpath(
- os.path.join(_dir, '../artifacts', 'models', 'model.hdf5'))
+ dbpath = os.path.relpath(os.path.join(_dir, "../artifacts", "models", "model.hdf5"))
self.assertEqual(exp.models[0].path.database, dbpath)
def test_set_tests(self):
test_cfg = os.path.normpath(
- os.path.join(_dir, '../artifacts', 'evaluations',
- 'tests_cfg.yml'))
- exp = Experiment(**_time_config, **_region_config,
- test_config=test_cfg,
- catalog=_cat)
+ os.path.join(_dir, "../artifacts", "evaluations", "tests_cfg.yml")
+ )
+ exp = Experiment(**_time_config, **_region_config, test_config=test_cfg, catalog=_cat)
funcs = [i.func for i in exp.tests]
- funcs_expected = [poisson_evaluations.number_test,
- poisson_evaluations.spatial_test,
- poisson_evaluations.paired_t_test]
+ funcs_expected = [
+ poisson_evaluations.number_test,
+ poisson_evaluations.spatial_test,
+ poisson_evaluations.paired_t_test,
+ ]
for i, j in zip(funcs, funcs_expected):
self.assertIs(i, j)
def test_prepare_subcatalog(self):
time_config = {**_time_config}
- exp = Experiment(**time_config, **_region_config,
- catalog=_cat)
- tstring = '2020-08-01_2021-01-02'
+ exp = Experiment(**time_config, **_region_config, catalog=_cat)
+ tstring = "2020-08-01_2021-01-02"
with tempfile.NamedTemporaryFile() as file_:
+
def filetree(*args):
return file_.name
+
exp.path = filetree
# with patch.object(exp, 'filetree', filetree):
# print(file_.name)
@@ -172,6 +175,6 @@ def filetree(*args):
@classmethod
def tearDownClass(cls) -> None:
- path_ = os.path.join(_dir, '../artifacts', 'models', 'model.hdf5')
+ path_ = os.path.join(_dir, "../artifacts", "models", "model.hdf5")
if os.path.isfile(path_):
os.remove(path_)
diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py
index 965804f..c57dadf 100644
--- a/tests/unit/test_model.py
+++ b/tests/unit/test_model.py
@@ -9,6 +9,7 @@
import csep.core.regions
import numpy.testing
+from floatcsep.environments import EnvironmentManager
from floatcsep.model import TimeIndependentModel, TimeDependentModel
from floatcsep.utils import str2timewindow
@@ -229,14 +230,14 @@ def init_model(name, path, **kwargs):
return model
- def test_from_git(self):
+ @patch.object(EnvironmentManager, "create_environment")
+ def test_from_git(self, mock_create_environment):
"""clones model from git, checks with test artifacts"""
name = "mock_git"
_dir = "git_template"
path_ = os.path.join(tempfile.tempdir, _dir)
giturl = (
- "https://git.gfz-potsdam.de/csep-group/"
- "rise_italy_experiment/models/template.git"
+ "https://git.gfz-potsdam.de/csep-group/" "rise_italy_experiment/models/template.git"
)
model_a = self.init_model(name=name, path=path_, giturl=giturl)
model_a.stage()
@@ -335,42 +336,6 @@ def test_get_forecast_multiple(self, mock_load_forecast):
self.assertEqual(result[0], mock_forecast1)
self.assertEqual(result[1], mock_forecast2)
- @patch("subprocess.run") # Mock subprocess.run
- @patch("subprocess.Popen") # Mock subprocess.Popen
- @patch("os.path.exists") # Mock os.path.exists
- def test_build_model_creates_venv(self, mock_exists, mock_popen, mock_run):
- # Arrange
- model_path = "../artifacts/models/td_model"
- model = self.init_model(name="TestModel", path=model_path, build="venv")
- mock_exists.return_value = False # Simulate that the venv does not exist
-
- # Act
- model.build_model()
-
- # Assert
- mock_run.assert_called_once_with(
- ["python", "-m", "venv", model.path("path") + "/venv"]
- )
- mock_popen.assert_called_once() # Ensure Popen was called to install dependencies
- self.assertIn(f"cd {os.path.abspath(model_path)} && source", model.run_prefix)
-
- @patch("subprocess.run")
- @patch("subprocess.Popen")
- @patch("os.path.exists")
- def test_build_model_when_venv_exists(self, mock_exists, mock_popen, mock_run):
- # Arrange
- model_path = "../artifacts/models/td_model"
- model = self.init_model(name="TestModel", path=model_path, build="venv")
- mock_exists.return_value = True # Simulate that the venv already exists
-
- # Act
- model.build_model()
-
- # Assert
- mock_run.assert_not_called()
- mock_popen.assert_not_called()
- self.assertIn(f"cd {os.path.abspath(model_path)} && source", model.run_prefix)
-
def test_argprep(self):
model_path = os.path.join(self._dir, "td_model")
with open(os.path.join(model_path, "input", "args.txt"), "w") as args:
@@ -396,12 +361,8 @@ def test_argprep(self):
@patch("floatcsep.model.open", new_callable=mock_open, read_data='{"key": "value"}')
@patch("json.dump")
def test_argprep_json(self, mock_json_dump, mock_file):
- model = self.init_model(
- name="TestModel", path=os.path.join(self._dir, "td_model")
- )
- model.path = MagicMock(
- return_value="path/to/model/args.json"
- ) # Mock path method
+ model = self.init_model(name="TestModel", path=os.path.join(self._dir, "td_model"))
+ model.path = MagicMock(return_value="path/to/model/args.json")
start = MagicMock()
end = MagicMock()
start.isoformat.return_value = "2023-01-01"
diff --git a/tests/unit/test_readers.py b/tests/unit/test_readers.py
index c4c6a41..1ed0648 100644
--- a/tests/unit/test_readers.py
+++ b/tests/unit/test_readers.py
@@ -14,30 +14,27 @@ class TestForecastParsers(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls._path = os.path.dirname(__file__)
- cls._dir = os.path.join(cls._path, '../artifacts', 'models')
+ cls._dir = os.path.join(cls._path, "../artifacts", "models")
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ fname = os.path.join(cls._dir, 'model.hdf5')
+ if os.path.isfile(fname):
+ os.remove(fname)
def test_parse_csv(self):
- fname = os.path.join(self._dir, 'model.csv')
+ fname = os.path.join(self._dir, "model.csv")
numpy.seterr(all="ignore")
rates, region, mags = readers.ForecastParsers.csv(fname)
- rts = numpy.array([[1., 0.1],
- [1., 0.1],
- [1., 0.1],
- [1., 0.1]])
- orgs = numpy.array([[0., 0.],
- [0.1, 0],
- [0., 0.1],
- [0.1, 0.1]])
- poly_2 = numpy.array([[0., 0.1],
- [0., 0.2],
- [0.1, 0.2],
- [0.1, 0.1]])
+ rts = numpy.array([[1.0, 0.1], [1.0, 0.1], [1.0, 0.1], [1.0, 0.1]])
+ orgs = numpy.array([[0.0, 0.0], [0.1, 0], [0.0, 0.1], [0.1, 0.1]])
+ poly_2 = numpy.array([[0.0, 0.1], [0.0, 0.2], [0.1, 0.2], [0.1, 0.1]])
numpy.testing.assert_allclose(rts, rates)
numpy.testing.assert_allclose(orgs, region.origins())
numpy.testing.assert_almost_equal(0.1, region.dh)
- numpy.testing.assert_allclose([5., 5.1], mags)
+ numpy.testing.assert_allclose([5.0, 5.1], mags)
numpy.testing.assert_allclose(poly_2, region.polygons[2].points)
def test_parse_dat(self):
@@ -50,41 +47,41 @@ def test_parse_dat(self):
numpy.testing.assert_allclose(forecast.data, rates)
def test_parse_csv_qtree(self):
- fname = os.path.join(self._dir, 'qtree', 'TEAM=N10L11.csv')
+ fname = os.path.join(self._dir, "qtree", "TEAM=N10L11.csv")
numpy.seterr(all="ignore")
rates, region, mags = readers.ForecastParsers.csv(fname)
- poly = numpy.array([[-180., 66.51326],
- [-180., 79.171335],
- [-135., 79.171335],
- [-135., 66.51326]])
+ poly = numpy.array(
+ [[-180.0, 66.51326], [-180.0, 79.171335], [-135.0, 79.171335], [-135.0, 66.51326]]
+ )
numpy.testing.assert_allclose(115.96694121688556, rates.sum())
- numpy.testing.assert_allclose([-177.1875, 51.179343],
- region.origins()[123])
+ numpy.testing.assert_allclose([-177.1875, 51.179343], region.origins()[123])
self.assertEqual(8089, rates.shape[0])
numpy.testing.assert_allclose(poly, region.polygons[2].points)
rates2, region2, mags2 = readers.ForecastParsers.quadtree(fname)
numpy.testing.assert_allclose(rates, rates2)
- numpy.testing.assert_allclose([i.points for i in region.polygons],
- [i.points for i in region2.polygons])
+ numpy.testing.assert_allclose(
+ [i.points for i in region.polygons], [i.points for i in region2.polygons]
+ )
numpy.testing.assert_allclose(mags, mags2)
def test_parse_xml(self):
- fname = os.path.join(self._path, '../../examples', 'case_e',
- 'models',
- 'gulia-wiemer.ALM.italy.10yr.2010-01-01.xml')
+ fname = os.path.join(
+ self._path,
+ "../../examples",
+ "case_e",
+ "models",
+ "gulia-wiemer.ALM.italy.10yr.2010-01-01.xml",
+ )
numpy.seterr(all="ignore")
rates, region, mags = readers.ForecastParsers.xml(fname)
orgs = numpy.array([12.6, 38.3])
- poly = numpy.array([[12.6, 38.3],
- [12.6, 38.4],
- [12.7, 38.4],
- [12.7, 38.3]])
+ poly = numpy.array([[12.6, 38.3], [12.6, 38.4], [12.7, 38.4], [12.7, 38.3]])
mags_ = numpy.arange(5, 9.05, 0.1)
numpy.testing.assert_allclose(16.185424321406536, rates.sum())
@@ -96,40 +93,34 @@ def test_parse_xml(self):
def test_serialize_hdf5(self):
numpy.seterr(all="ignore")
- fname = os.path.join(self._dir, 'model.csv')
+ fname = os.path.join(self._dir, "model.csv")
rates, region, mags = readers.ForecastParsers.csv(fname)
- fname_db = os.path.join(self._dir, 'model.hdf5')
- readers.HDF5Serializer.grid2hdf5(rates, region, mags,
- hdf5_filename=fname_db)
+ fname_db = os.path.join(self._dir, "model.hdf5")
+ readers.HDF5Serializer.grid2hdf5(rates, region, mags, hdf5_filename=fname_db)
self.assertTrue(os.path.isfile(fname_db))
size = os.path.getsize(fname_db)
- self.assertEqual(4640, size)
+ self.assertLessEqual(4500, size)
+ self.assertGreaterEqual(5000, size)
def test_parse_hdf5(self):
- fname = os.path.join(self._dir, 'model_h5.hdf5')
+ fname = os.path.join(self._dir, "model_h5.hdf5")
rates, region, mags = readers.ForecastParsers.hdf5(fname)
- orgs = numpy.array([[0., 0.],
- [0.1, 0],
- [0., 0.1],
- [0.1, 0.1]])
- poly_3 = numpy.array([[0.1, 0.1],
- [0.1, 0.2],
- [0.2, 0.2],
- [0.2, 0.1]])
+ orgs = numpy.array([[0.0, 0.0], [0.1, 0], [0.0, 0.1], [0.1, 0.1]])
+ poly_3 = numpy.array([[0.1, 0.1], [0.1, 0.2], [0.2, 0.2], [0.2, 0.1]])
numpy.testing.assert_allclose(4.4, rates.sum())
numpy.testing.assert_allclose(orgs, region.origins())
numpy.testing.assert_almost_equal(0.1, region.dh)
- numpy.testing.assert_allclose([5., 5.1], mags)
+ numpy.testing.assert_allclose([5.0, 5.1], mags)
numpy.testing.assert_allclose(poly_3, region.polygons[3].points)
def test_checkformat_xml(self):
def save(xml_list):
- name_ = os.path.join(tempfile.tempdir, 'tmpxml.xml')
- with open(name_, 'w') as file_:
+ name_ = os.path.join(tempfile.tempdir, "tmpxml.xml")
+ with open(name_, "w") as file_:
for i in xml_list:
- file_.write(i + '\n')
+ file_.write(i + "\n")
return name_
forecast_xml = [
@@ -148,34 +139,34 @@ def save(xml_list):
filename = save(forecast_xml)
try:
- readers.check_format(filename, fmt='xml')
+ readers.check_format(filename, fmt="xml")
except (IndentationError, IndexError, KeyError):
- self.fail('Format check failed')
+ self.fail("Format check failed")
xml_fail = copy.deepcopy(forecast_xml)
xml_fail[3] = ""
xml_fail[-3] = ""
filename = save(xml_fail)
with pytest.raises(LookupError):
- readers.check_format(filename, fmt='xml')
+ readers.check_format(filename, fmt="xml")
xml_fail = copy.deepcopy(forecast_xml)
xml_fail[4] = ""
filename = save(xml_fail)
with pytest.raises(KeyError):
- readers.check_format(filename, fmt='xml')
+ readers.check_format(filename, fmt="xml")
xml_fail = copy.deepcopy(forecast_xml)
xml_fail[5] = "1.6773966e-008"
filename = save(xml_fail)
with pytest.raises(LookupError):
- readers.check_format(filename, fmt='xml')
+ readers.check_format(filename, fmt="xml")
xml_fail = copy.deepcopy(forecast_xml)
xml_fail[5] = "1.6773966e-008"
filename = save(xml_fail)
with pytest.raises(KeyError):
- readers.check_format(filename, fmt='xml')
+ readers.check_format(filename, fmt="xml")
xml_fail = copy.deepcopy(forecast_xml)
xml_fail[2] = ""
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index f70a696..2e08832 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -10,8 +10,13 @@
import floatcsep
import floatcsep.accessors
import floatcsep.extras
-from floatcsep.utils import parse_timedelta_string, timewindows_ti, \
- read_time_cfg, read_region_cfg, parse_csep_func
+from floatcsep.utils import (
+ parse_timedelta_string,
+ timewindows_ti,
+ read_time_cfg,
+ read_region_cfg,
+ parse_csep_func,
+)
root_dir = os.path.dirname(os.path.abspath(__file__))
@@ -19,73 +24,82 @@
class CsepFunctionTest(unittest.TestCase):
def test_parse_csep_func(self):
- self.assertIsInstance(parse_csep_func('load_gridded_forecast'),
- csep.load_gridded_forecast.__class__)
- self.assertIsInstance(parse_csep_func('join_struct_arrays'),
- csep.utils.join_struct_arrays.__class__)
- self.assertIsInstance(parse_csep_func('plot_poisson_consistency_test'),
- csep.utils.plots.plot_poisson_consistency_test.__class__)
- self.assertIsInstance(parse_csep_func('italy_csep_region'),
- csep.core.regions.italy_csep_region.__class__)
- self.assertIsInstance(parse_csep_func('plot_forecast_lowres'),
- floatcsep.utils.plot_forecast_lowres.__class__)
- self.assertIsInstance(parse_csep_func('from_zenodo'),
- floatcsep.accessors.from_zenodo.__class__)
- self.assertIsInstance(parse_csep_func('from_zenodo'),
- floatcsep.extras.vector_poisson_t_w_test.__class__)
- self.assertRaises(AttributeError, parse_csep_func, 'panic_button')
+ self.assertIsInstance(
+ parse_csep_func("load_gridded_forecast"), csep.load_gridded_forecast.__class__
+ )
+ self.assertIsInstance(
+ parse_csep_func("join_struct_arrays"), csep.utils.join_struct_arrays.__class__
+ )
+ self.assertIsInstance(
+ parse_csep_func("plot_poisson_consistency_test"),
+ csep.utils.plots.plot_poisson_consistency_test.__class__,
+ )
+ self.assertIsInstance(
+ parse_csep_func("italy_csep_region"), csep.core.regions.italy_csep_region.__class__
+ )
+ self.assertIsInstance(
+ parse_csep_func("plot_forecast_lowres"),
+ floatcsep.utils.plot_forecast_lowres.__class__,
+ )
+ self.assertIsInstance(
+ parse_csep_func("from_zenodo"), floatcsep.accessors.from_zenodo.__class__
+ )
+ self.assertIsInstance(
+ parse_csep_func("from_zenodo"), floatcsep.extras.vector_poisson_t_w_test.__class__
+ )
+ self.assertRaises(AttributeError, parse_csep_func, "panic_button")
class TimeUtilsTest(unittest.TestCase):
def test_parse_time_window(self):
- dt = '1Year'
- self.assertEqual(parse_timedelta_string(dt), '1-years')
- dt = '7-Days'
- self.assertEqual(parse_timedelta_string(dt), '7-days')
- dt = '1- mOnThS'
- self.assertEqual(parse_timedelta_string(dt), '1-months')
- dt = '20 days'
- self.assertEqual(parse_timedelta_string(dt), '20-days')
- dt = '1decade'
+ dt = "1Year"
+ self.assertEqual(parse_timedelta_string(dt), "1-years")
+ dt = "7-Days"
+ self.assertEqual(parse_timedelta_string(dt), "7-days")
+ dt = "1- mOnThS"
+ self.assertEqual(parse_timedelta_string(dt), "1-months")
+ dt = "20 days"
+ self.assertEqual(parse_timedelta_string(dt), "20-days")
+ dt = "1decade"
self.assertRaises(ValueError, parse_timedelta_string, dt)
def test_timewindows_ti(self):
start = datetime(2014, 1, 1)
end = datetime(2022, 1, 1)
- self.assertEqual(timewindows_ti(start_date=start,
- end_date=end), [(start, end)])
-
- t1 = [(datetime(2014, 1, 1), datetime(2018, 1, 1)),
- (datetime(2018, 1, 1), datetime(2022, 1, 1))]
- self.assertEqual(timewindows_ti(start_date=start,
- end_date=end,
- intervals=2), t1)
- self.assertEqual(timewindows_ti(start_date=start,
- end_date=end,
- horizon='4-years'), t1)
- self.assertEqual(timewindows_ti(start_date=start,
- intervals=2,
- horizon='4-years'), t1)
-
- t2 = [(datetime(2014, 1, 1, 0, 0),
- datetime(2015, 2, 22, 10, 17, 8, 571428)),
- (datetime(2015, 2, 22, 10, 17, 8, 571428),
- datetime(2016, 4, 14, 20, 34, 17, 142857)),
- (datetime(2016, 4, 14, 20, 34, 17, 142857),
- datetime(2017, 6, 6, 6, 51, 25, 714285)),
- (datetime(2017, 6, 6, 6, 51, 25, 714285),
- datetime(2018, 7, 28, 17, 8, 34, 285714)),
- (datetime(2018, 7, 28, 17, 8, 34, 285714),
- datetime(2019, 9, 19, 3, 25, 42, 857142)),
- (datetime(2019, 9, 19, 3, 25, 42, 857142),
- datetime(2020, 11, 9, 13, 42, 51, 428571)),
- (datetime(2020, 11, 9, 13, 42, 51, 428571),
- datetime(2022, 1, 1, 0, 0))]
- self.assertEqual(timewindows_ti(start_date=start,
- end_date=end,
- intervals=7), t2)
+ self.assertEqual(timewindows_ti(start_date=start, end_date=end), [(start, end)])
+
+ t1 = [
+ (datetime(2014, 1, 1), datetime(2018, 1, 1)),
+ (datetime(2018, 1, 1), datetime(2022, 1, 1)),
+ ]
+ self.assertEqual(timewindows_ti(start_date=start, end_date=end, intervals=2), t1)
+ self.assertEqual(timewindows_ti(start_date=start, end_date=end, horizon="4-years"), t1)
+ self.assertEqual(timewindows_ti(start_date=start, intervals=2, horizon="4-years"), t1)
+
+ t2 = [
+ (datetime(2014, 1, 1, 0, 0), datetime(2015, 2, 22, 10, 17, 8, 571428)),
+ (
+ datetime(2015, 2, 22, 10, 17, 8, 571428),
+ datetime(2016, 4, 14, 20, 34, 17, 142857),
+ ),
+ (
+ datetime(2016, 4, 14, 20, 34, 17, 142857),
+ datetime(2017, 6, 6, 6, 51, 25, 714285),
+ ),
+ (datetime(2017, 6, 6, 6, 51, 25, 714285), datetime(2018, 7, 28, 17, 8, 34, 285714)),
+ (
+ datetime(2018, 7, 28, 17, 8, 34, 285714),
+ datetime(2019, 9, 19, 3, 25, 42, 857142),
+ ),
+ (
+ datetime(2019, 9, 19, 3, 25, 42, 857142),
+ datetime(2020, 11, 9, 13, 42, 51, 428571),
+ ),
+ (datetime(2020, 11, 9, 13, 42, 51, 428571), datetime(2022, 1, 1, 0, 0)),
+ ]
+ self.assertEqual(timewindows_ti(start_date=start, end_date=end, intervals=7), t2)
def test_timewindows_ti_td(self):
pass
@@ -94,20 +108,15 @@ def test_read_time_config(self):
start = datetime(2014, 1, 1)
end = datetime(2022, 1, 1)
intervals = 2
- config = {'start_date': start,
- 'end_date': end,
- 'intervals': intervals}
- self.assertEqual(read_time_cfg(None, **config),
- read_time_cfg(config))
-
- short_config = {'start_date': start,
- 'end_date': end}
- time_config = {'intervals': 2}
- full_config = {'start_date': start,
- 'end_date': end,
- 'intervals': 2}
- self.assertEqual(read_time_cfg(time_config, **short_config),
- read_time_cfg(None, **full_config))
+ config = {"start_date": start, "end_date": end, "intervals": intervals}
+ self.assertEqual(read_time_cfg(None, **config), read_time_cfg(config))
+
+ short_config = {"start_date": start, "end_date": end}
+ time_config = {"intervals": 2}
+ full_config = {"start_date": start, "end_date": end, "intervals": 2}
+ self.assertEqual(
+ read_time_cfg(time_config, **short_config), read_time_cfg(None, **full_config)
+ )
class RegionUtilsTest(unittest.TestCase):
@@ -117,50 +126,51 @@ def test_magnitudes_depth(self):
mag_min = 1
mag_max = 1.2
mag_bin = 0.1
- depth_max = 1.
- depth_min = 0.
+ depth_max = 1.0
+ depth_min = 0.0
- config = {'mag_min': mag_min,
- 'mag_max': mag_max,
- 'mag_bin': mag_bin,
- 'depth_min': depth_min,
- 'depth_max': depth_max}
+ config = {
+ "mag_min": mag_min,
+ "mag_max": mag_max,
+ "mag_bin": mag_bin,
+ "depth_min": depth_min,
+ "depth_max": depth_max,
+ }
region_config = read_region_cfg(config)
self.assertEqual(8, len(region_config))
- numpy.testing.assert_equal(magnitudes,
- region_config['magnitudes'])
- numpy.testing.assert_equal(numpy.array([depth_min, depth_max]),
- region_config['depths'])
+ numpy.testing.assert_equal(magnitudes, region_config["magnitudes"])
+ numpy.testing.assert_equal(numpy.array([depth_min, depth_max]), region_config["depths"])
def test_region(self):
- region_origins = numpy.array([[0, 0],
- [0.1, 0],
- [0.1, 0.1],
- [0, 0.1]])
- region_path = os.path.join(os.path.dirname(__file__), '../artifacts',
- 'regions', 'mock_region')
- config = {'region': region_path,
- 'mag_min': 1,
- 'mag_max': 1.2,
- 'mag_bin': 0.1,
- 'depth_min': 0,
- 'depth_max': 1}
+ region_origins = numpy.array([[0, 0], [0.1, 0], [0.1, 0.1], [0, 0.1]])
+ region_path = os.path.join(
+ os.path.dirname(__file__), "../artifacts", "regions", "mock_region"
+ )
+ config = {
+ "region": region_path,
+ "mag_min": 1,
+ "mag_max": 1.2,
+ "mag_bin": 0.1,
+ "depth_min": 0,
+ "depth_max": 1,
+ }
region_config = read_region_cfg(config)
self.assertEqual(9, len(region_config))
- numpy.testing.assert_equal(region_origins,
- region_config['region'].origins())
-
- config = {'region': 'italy_csep_region',
- 'mag_min': 1,
- 'mag_max': 1.2,
- 'mag_bin': 0.1,
- 'depth_min': 0,
- 'depth_max': 1}
- region_path = os.path.join(os.path.dirname(__file__), '../artifacts',
- 'regions', 'italy_midpoints')
+ numpy.testing.assert_equal(region_origins, region_config["region"].origins())
+
+ config = {
+ "region": "italy_csep_region",
+ "mag_min": 1,
+ "mag_max": 1.2,
+ "mag_bin": 0.1,
+ "depth_min": 0,
+ "depth_max": 1,
+ }
+ region_path = os.path.join(
+ os.path.dirname(__file__), "../artifacts", "regions", "italy_midpoints"
+ )
midpoints = numpy.genfromtxt(region_path)
region_config = read_region_cfg(config)
- numpy.testing.assert_almost_equal(midpoints,
- region_config['region'].midpoints())
+ numpy.testing.assert_almost_equal(midpoints, region_config["region"].midpoints())
|