Skip to content

Commit

Permalink
ft: A forecast factory class is created to handle the creation of gri…
Browse files Browse the repository at this point in the history
…dded forecast according to the file/forecast format.

build: added h5py as requirement
  • Loading branch information
pabloitu committed Sep 2, 2024
1 parent 62d50c7 commit 19f2d50
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 3 deletions.
255 changes: 252 additions & 3 deletions csep/core/forecasts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import itertools
import time
import os
import datetime
from typing import Optional

# third-party imports
import numpy
import xml.etree.ElementTree as eTree
import h5py
import pandas

from csep.utils.log import LoggingMixin
from csep.core.regions import CartesianGrid2D, create_space_magnitude_region
from csep.core.regions import CartesianGrid2D, create_space_magnitude_region, QuadtreeGrid2D
from csep.models import Polygon
from csep.utils.calc import bin1d_vec
from csep.utils.time_utils import decimal_year, datetime_to_utc_epoch
Expand Down Expand Up @@ -753,4 +756,250 @@ def load_ascii(cls, fname, **kwargs):
Returns:
:class:`csep.core.forecasts.CatalogForecast
"""
raise NotImplementedError("load_ascii is not implemented!")
raise NotImplementedError("load_ascii is not implemented!")


class GriddedForecastFactory:
    """Static loaders that build gridded forecasts from different file formats.

    ``from_dat``, ``from_xml``, ``from_quadtree`` and ``from_hdf5`` return a
    :class:`GriddedForecast`. ``from_csv`` returns the raw
    ``(rates, region, magnitudes)`` tuple instead.
    """

    @staticmethod
    def _bounds_to_bboxes(bounds):
        """Convert rows of ``(x_min, x_max, y_min, y_max)`` into corner tuples.

        Each row becomes ``((x_min, y_min), (x_min, y_max), (x_max, y_max),
        (x_max, y_min))``.
        """
        return [
            ((i[0], i[2]), (i[0], i[3]), (i[1], i[3]), (i[1], i[2])) for i in bounds
        ]

    @staticmethod
    def _forecast_name(name):
        """Return ``name`` as a string, keeping ``None`` rather than ``"None"``."""
        return None if name is None else f"{name}"

    @staticmethod
    def from_dat(filename: str,
                 swap_latlon: bool = False,
                 name: Optional[str] = None,
                 start_date: Optional[datetime.datetime] = None,
                 end_date: Optional[datetime.datetime] = None,
                 **kwargs) -> GriddedForecast:
        """Create a :class:`GriddedForecast` from a whitespace-delimited ``.dat`` file.

        The first four columns are read as the cell bounds, the fourth-to-last
        column as the magnitude, the second-to-last column as the rate and the
        last column as the cell mask. Rows are expected to be grouped by cell,
        with one row per magnitude bin.

        Arguments:
            filename: Path of the ``.dat`` file.
            swap_latlon: If False, the first four columns are used as
                ``(x_min, x_max, y_min, y_max)``. If True, the first pair of
                columns is used as the second coordinate of each corner and the
                second pair as the first coordinate.
            name: Name of the forecast. ``None`` is passed through unchanged.
            start_date: Start time of the forecast.
            end_date: End time of the forecast.
            **kwargs: Extra keyword arguments forwarded to :class:`GriddedForecast`.

        Returns:
            :class:`GriddedForecast`

        Raises:
            ValueError: If the number of rows is not the number of unique cells
                times the number of unique magnitudes.
        """
        # ndmin=2 keeps a single-row file two-dimensional so column slicing works.
        data = numpy.loadtxt(filename, ndmin=2)
        all_polys = data[:, :4]
        all_poly_mask = data[:, -1]

        # numpy.unique returns sorted values; sorting the first-occurrence indices
        # restores the order in which the cells appear in the file.
        cell_idx = numpy.sort(
            numpy.unique(all_polys, return_index=True, axis=0)[1], kind="stable"
        )
        unique_poly = all_polys[cell_idx]
        poly_mask = all_poly_mask[cell_idx]

        all_mws = data[:, -4]
        mag_idx = numpy.sort(numpy.unique(all_mws, return_index=True)[1], kind="stable")
        magnitudes = all_mws[mag_idx]

        if swap_latlon:
            bboxes = [((i[2], i[0]), (i[3], i[0]),
                       (i[3], i[1]), (i[2], i[1])) for i in unique_poly]
        else:
            bboxes = GriddedForecastFactory._bounds_to_bboxes(unique_poly)

        dh = float(unique_poly[0, 3] - unique_poly[0, 2])

        n_mag_bins = len(magnitudes)
        if data.shape[0] != len(bboxes) * n_mag_bins:
            raise ValueError(
                f"Cannot arrange {data.shape[0]} rows into {len(bboxes)} cells "
                f"x {n_mag_bins} magnitude bins"
            )
        rates = data[:, -2].reshape(len(bboxes), n_mag_bins)

        region = CartesianGrid2D([Polygon(bbox) for bbox in bboxes], dh, mask=poly_mask)

        return GriddedForecast(
            name=GriddedForecastFactory._forecast_name(name),
            data=rates,
            region=region,
            magnitudes=magnitudes,
            start_time=start_date,
            end_time=end_date,
            **kwargs
        )

    @staticmethod
    def from_xml(filename: str,
                 **kwargs):
        """Create a :class:`GriddedForecast` from an XML forecast file.

        Metadata is read from the children of the first element under the root.
        The forecast name and start/end dates come from the ``modelName``,
        ``forecastStartDate`` and ``forecastEndDate`` elements, and are ``None``
        when the element is absent. The cell size comes from the ``lonRange`` and
        ``latRange`` attributes of ``defaultCellDimension``. Cells are the
        children of the ``depthLayer`` element; if several are present, the last
        one is used.

        Arguments:
            filename: Path of the XML file.
            **kwargs: Extra keyword arguments forwarded to :class:`GriddedForecast`.

        Returns:
            :class:`GriddedForecast`

        Raises:
            ValueError: If no cells are found, or if the cells do not all have
                the same number of magnitude bins.
            KeyError: If ``defaultCellDimension`` does not provide ``lonRange``
                and ``latRange``.
        """
        tree = eTree.parse(filename)
        root = tree.getroot()
        metadata = {}
        cells = []
        cell_dim = {}

        # Tags are matched by substring so that a namespace prefix on the tag
        # does not prevent the match.
        for child in list(root[0]):
            tag = child.tag
            if "modelName" in tag:
                metadata["name"] = child.text
            elif "author" in tag:
                metadata["author"] = child.text
            elif "forecastStartDate" in tag:
                # The trailing "Z" is dropped before the ISO string is parsed.
                metadata["forecastStartDate"] = child.text.replace("Z", "")
            elif "forecastEndDate" in tag:
                metadata["forecastEndDate"] = child.text.replace("Z", "")
            elif "defaultMagBinDimension" in tag:
                metadata["defaultMagBinDimension"] = float(child.text)
            elif "lastMagBinOpen" in tag:
                metadata["lastMagBinOpen"] = float(child.text)
            elif "defaultCellDimension" in tag:
                cell_dim = {key: float(value) for key, value in child.attrib.items()}
                metadata["defaultCellDimension"] = cell_dim
            elif "depthLayer" in tag:
                metadata["depthLayer"] = {
                    key: float(value) for key, value in child.attrib.items()
                }
                cells = child

        data_ijm = []
        m_bins = []
        for cell in cells:
            # Element.iter() yields the cell itself first, then its bin elements.
            bins = list(cell.iter())[1:]
            row = [float(cell.attrib["lon"]), float(cell.attrib["lat"])]
            row.extend(float(bin_.text) for bin_ in bins)
            data_ijm.append(row)
            m_bins.append([float(bin_.attrib["m"]) for bin_ in bins])

        if not data_ijm:
            raise ValueError(f"No forecast cells found in '{filename}'")
        # Check explicitly: older numpy builds ragged object arrays silently
        # instead of raising.
        if len({len(row) for row in m_bins}) != 1:
            raise ValueError("Data is not square")

        data_ijm = numpy.array(data_ijm)
        m_bins = numpy.array(m_bins)

        magnitudes = m_bins[0, :]
        rates = data_ijm[:, -len(magnitudes):]

        # Cell bounds are the cell coordinates +/- half the cell dimension.
        half_lon = cell_dim["lonRange"] / 2.0
        half_lat = cell_dim["latRange"] / 2.0
        all_polys = numpy.vstack(
            (
                data_ijm[:, 0] - half_lon,
                data_ijm[:, 0] + half_lon,
                data_ijm[:, 1] - half_lat,
                data_ijm[:, 1] + half_lat,
            )
        ).T
        bboxes = GriddedForecastFactory._bounds_to_bboxes(all_polys)
        dh = float(all_polys[0, 3] - all_polys[0, 2])
        poly_mask = numpy.ones(len(bboxes))
        region = CartesianGrid2D([Polygon(bbox) for bbox in bboxes], dh, mask=poly_mask)

        start = metadata.get("forecastStartDate")
        end = metadata.get("forecastEndDate")

        return GriddedForecast(
            name=GriddedForecastFactory._forecast_name(metadata.get("name")),
            data=rates,
            region=region,
            magnitudes=magnitudes,
            start_time=None if start is None else datetime.datetime.fromisoformat(start),
            end_time=None if end is None else datetime.datetime.fromisoformat(end),
            **kwargs
        )

    @staticmethod
    def from_quadtree(filename: str,
                      name: Optional[str] = None,
                      start_date: Optional[datetime.datetime] = None,
                      end_date: Optional[datetime.datetime] = None,
                      **kwargs) -> GriddedForecast:
        """Create a :class:`GriddedForecast` from a comma-separated quadtree file.

        The file must have a header row. The ``tile`` column holds the quadkeys,
        and every column from the fourth onwards is a magnitude bin whose header
        is the magnitude value.

        Arguments:
            filename: Path of the quadtree CSV file.
            name: Name of the forecast. ``None`` is passed through unchanged.
            start_date: Start time of the forecast.
            end_date: End time of the forecast.
            **kwargs: Extra keyword arguments forwarded to :class:`GriddedForecast`.

        Returns:
            :class:`GriddedForecast`
        """
        with open(filename, "r") as file_:
            # Drop the line terminator so the last column name matches the name
            # pandas reads.
            qt_header = file_.readline().rstrip("\r\n").split(",")

        # The first column is read as text so leading zeros survive; every other
        # column is read as float.
        qt_formats = {
            column: (str if idx == 0 else float) for idx, column in enumerate(qt_header)
        }
        data = pandas.read_csv(filename, header=0, dtype=qt_formats)

        # Drop non-ASCII characters but keep plain ``str`` quadkeys: calling
        # ``str()`` on a bytes object would produce "b'...'".
        quadkeys = [tile.encode("ascii", "ignore").decode("ascii") for tile in data.tile]
        magnitudes = numpy.array(data.keys()[3:]).astype(float)
        # Select by position: a header such as "5" would not match str(5.0).
        rates = data.iloc[:, 3:].to_numpy()

        region = QuadtreeGrid2D.from_quadkeys(quadkeys, magnitudes=magnitudes)
        region.get_cell_area()

        return GriddedForecast(
            name=GriddedForecastFactory._forecast_name(name),
            data=rates,
            region=region,
            magnitudes=magnitudes,
            start_time=start_date,
            end_time=end_date,
            **kwargs
        )

    @staticmethod
    def from_csv(filename):
        """Read rates, region and magnitudes from a delimited text file.

        The header must contain the columns ``lon_min``, ``lon_max``,
        ``lat_min`` and ``lat_max``. Every column whose name parses as a number
        strictly between -1 and 12 is taken as a magnitude bin. An optional
        ``mask`` column gives the cell mask; without it every cell is set to 1.
        The separator is a comma when the first line has more than three
        comma-separated fields, otherwise a single space.

        Unlike the other loaders, this method does not build a
        :class:`GriddedForecast`.

        Arguments:
            filename: Path of the file.

        Returns:
            Tuple ``(rates, region, magnitudes)``, where ``rates`` and
            ``magnitudes`` are numpy arrays and ``region`` is a
            :class:`CartesianGrid2D`.
        """
        def is_mag(num):
            """Return True if ``num`` parses as a float strictly between -1 and 12."""
            try:
                return -1 < float(num) < 12.0
            except ValueError:
                return False

        with open(filename, "r") as file_:
            first_line = file_.readline()
        sep = "," if len(first_line.split(",")) > 3 else " "

        data = pandas.read_csv(
            filename, header=0, sep=sep, escapechar="#", skipinitialspace=True
        )

        data.columns = [i.strip() for i in data.columns]
        mag_columns = [i for i in data.columns if is_mag(i)]
        magnitudes = numpy.array([float(i) for i in mag_columns])
        rates = data[mag_columns].to_numpy()
        all_polys = data[["lon_min", "lon_max", "lat_min", "lat_max"]].to_numpy()
        bboxes = GriddedForecastFactory._bounds_to_bboxes(all_polys)
        dh = float(all_polys[0, 3] - all_polys[0, 2])

        # Always hand the region a numpy array, whichever branch is taken.
        if "mask" in data.columns:
            poly_mask = data["mask"].to_numpy()
        else:
            poly_mask = numpy.ones(len(bboxes))

        region = CartesianGrid2D([Polygon(bbox) for bbox in bboxes], dh, mask=poly_mask)

        return rates, region, magnitudes

    @staticmethod
    def from_hdf5(filename: str,
                  group: str = "",
                  name: Optional[str] = None,
                  start_date: Optional[datetime.datetime] = None,
                  end_date: Optional[datetime.datetime] = None,
                  **kwargs) -> GriddedForecast:
        """Load a gridded forecast from an HDF5 file.

        The datasets ``rates``, ``magnitudes``, ``dh``, ``bboxes`` and
        ``poly_mask`` are read from ``group``.

        Arguments:
            filename: The name of the HDF5 file.
            group: The HDF5 group to load the forecast from. Usually represents
                the forecast time.
            name: The name of the gridded forecast. ``None`` is passed through
                unchanged.
            start_date: The start date of the forecast.
            end_date: The end date of the forecast.
            **kwargs: Extra keyword arguments forwarded to :class:`GriddedForecast`.

        Returns:
            :class:`GriddedForecast`
        """
        with h5py.File(filename, "r") as db:
            rates = db[f"{group}/rates"][:]
            magnitudes = db[f"{group}/magnitudes"][:]
            # ``[()]`` reads scalar and array datasets alike; ravel then gives a
            # 1-D view whose first element is the cell size.
            dh = float(numpy.ravel(db[f"{group}/dh"][()])[0])
            bboxes = db[f"{group}/bboxes"][:]
            poly_mask = db[f"{group}/poly_mask"][:]

        region = CartesianGrid2D([Polygon(bbox) for bbox in bboxes], dh, mask=poly_mask)

        return GriddedForecast(
            name=GriddedForecastFactory._forecast_name(name),
            data=rates,
            region=region,
            magnitudes=magnitudes,
            start_time=start_date,
            end_time=end_date,
            **kwargs
        )
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ scipy
pandas
matplotlib
cartopy
h5py
obspy
pyproj
python-dateutil
Expand Down
1 change: 1 addition & 0 deletions requirements.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dependencies:
- numpy
- pandas
- scipy
- h5py
- matplotlib
- pyproj
- obspy
Expand Down
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pandas
matplotlib
cartopy
obspy
h5py
pyproj
python-dateutil
pytest
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def get_version():
'numpy',
'scipy',
'pandas',
'h5py',
'matplotlib',
'cartopy',
'obspy',
Expand Down

0 comments on commit 19f2d50

Please sign in to comment.