Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save nexus #97

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8, 3.9]
python-version: [3.8, 3.9, '3.10', '3.11']

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

Expand Down Expand Up @@ -47,12 +47,12 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7]
python-version: [3.7]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

Expand Down
97 changes: 93 additions & 4 deletions orsopy/fileio/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def _noop(self, *args, **kw):
pass


JSON_MIMETYPE = "application/json"

yaml.emitter.Emitter.process_tag = _noop

# make sure that datetime strings get loaded as str not datetime instances
Expand Down Expand Up @@ -82,7 +84,12 @@ def _custom_init_fn(fieldsarg, frozen, has_post_init, self_name, globals):
)


# register all ORSO classes here:
ORSO_DATACLASSES = dict()


def orsodataclass(cls: type):
ORSO_DATACLASSES[cls.__name__] = cls
attrs = cls.__dict__
bases = cls.__bases__
if "__annotations__" in attrs and len([k for k in attrs["__annotations__"].keys() if not k.startswith("_")]) > 0:
Expand Down Expand Up @@ -275,7 +282,8 @@ def _resolve_type(hint: type, item: Any) -> Any:
return item
else:
warnings.warn(
f"Has to be one of {get_args(hint)} got {item}", RuntimeWarning,
f"Has to be one of {get_args(hint)} got {item}",
RuntimeWarning,
)
return str(item)
return None
Expand Down Expand Up @@ -376,6 +384,67 @@ def yaml_representer_compact(self, dumper: yaml.Dumper):
output = self._to_object_dict()
return dumper.represent_mapping(dumper.DEFAULT_MAPPING_TAG, output, flow_style=True)

def to_nexus(self, root=None, name=None):
    """
    Write an HDF5/NeXus representation of this Header object into *root*,
    skipping any optional attributes whose value is :code:`None`.

    :param root: an open ``h5py.Group`` in which this object's group is created.
    :param name: name of the created group; defaults to the class name.
    :return: the newly created ``h5py.Group``
    """
    import h5py  # deferred import: h5py is only needed for NeXus output

    classname = self.__class__.__name__
    assert isinstance(root, h5py.Group)
    group = root.create_group(classname if name is None else name)
    group.attrs["ORSO_class"] = classname

    def write_plain(parent, dset_name, raw, label):
        # Serialize a non-ORSO value as an HDF5 dataset and return it, or
        # return None (after warning) when the value cannot be serialized.
        # This consolidates the logic previously duplicated for list items
        # and scalar attributes.
        t_value = nexus_value_converter(raw)
        if isinstance(t_value, (str, float, int, bool, np.ndarray)):
            return parent.create_dataset(dset_name, data=t_value)
        if t_value is None:
            # special handling for null datasets: dtype only, no data
            return parent.create_dataset(dset_name, dtype="f")
        if isinstance(t_value, dict):
            dset = parent.create_dataset(dset_name, data=json.dumps(t_value))
            dset.attrs["mimetype"] = JSON_MIMETYPE
            return dset
        warnings.warn(f"unserializable attribute found: {label} = {t_value}")
        return None

    for child_name, value in self.__dict__.items():
        # private attributes and unset optionals are not persisted
        if child_name.startswith("_") or (value is None and child_name in self._orso_optionals):
            continue

        if value.__class__ in ORSO_DATACLASSES.values():
            # nested ORSO objects recurse into their own sub-group
            value.to_nexus(root=group, name=child_name)
        elif isinstance(value, (list, tuple)):
            child_group = group.create_group(child_name)
            child_group.attrs["sequence"] = 1
            for index, item in enumerate(value):
                # use the 'name' attribute of children if it exists, else index:
                sub_name = getattr(item, "name", str(index))
                if item.__class__ in ORSO_DATACLASSES.values():
                    item_out = item.to_nexus(root=child_group, name=sub_name)
                else:
                    item_out = write_plain(child_group, sub_name, item, f"{child_name}[{index}]")
                    if item_out is None:
                        continue
                # record the original list position so readers can restore order
                item_out.attrs["sequence_index"] = index
        else:
            write_plain(group, child_name, value, child_name)
    return group

@staticmethod
def _check_unit(unit: str):
"""
Expand Down Expand Up @@ -431,6 +500,8 @@ def represent_data(self, data):
elif isinstance(data, datetime.datetime):
value = data.isoformat("T")
return super().represent_scalar("tag:yaml.org,2002:timestamp", value)
elif isinstance(data, np.floating):
return super().represent_data(float(data))
else:
return super().represent_data(data)

Expand Down Expand Up @@ -798,7 +869,7 @@ def _read_header_data(file: Union[TextIO, str], validate: bool = False) -> Tuple
# numerical array and start collecting the numbers for this
# dataset
_d = np.array([np.fromstring(v, dtype=float, sep=" ") for v in _ds_lines])
data.append(_d)
data.append(_d.T)
_ds_lines = []

# append '---' to signify the start of a new yaml document
Expand All @@ -811,7 +882,7 @@ def _read_header_data(file: Union[TextIO, str], validate: bool = False) -> Tuple

# append the last numerical array
_d = np.array([np.fromstring(v, dtype=float, sep=" ") for v in _ds_lines])
data.append(_d)
data.append(_d.T)

yml = "".join(header)

Expand Down Expand Up @@ -924,7 +995,7 @@ def _todict(obj: Any, classkey: Any = None) -> dict:
"""
if isinstance(obj, dict):
data = {}
for (k, v) in obj.items():
for k, v in obj.items():
data[k] = _todict(v, classkey)
return data
elif isinstance(obj, Enum):
Expand All @@ -949,6 +1020,24 @@ def _todict(obj: Any, classkey: Any = None) -> dict:
return obj


def json_datetime_trap(obj):
    """Return *obj* as an ISO-8601 string when it is a datetime, otherwise unchanged."""
    return obj.isoformat() if isinstance(obj, datetime.datetime) else obj


def enum_trap(obj):
    """Unwrap an :class:`Enum` member to its underlying value; pass anything else through."""
    if not isinstance(obj, Enum):
        return obj
    return obj.value


def nexus_value_converter(obj):
    """
    Normalize *obj* for HDF5 output: datetimes become ISO strings and
    Enum members become their underlying values; anything else is returned as-is.
    """
    return enum_trap(json_datetime_trap(obj))


def _nested_update(d: dict, u: dict) -> dict:
"""
Nested dictionary update.
Expand Down
118 changes: 110 additions & 8 deletions orsopy/fileio/orso.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
Implementation of the top level class for the ORSO header.
"""

from dataclasses import dataclass
from typing import Any, List, Optional, TextIO, Union
from dataclasses import dataclass, fields
from typing import BinaryIO, List, Optional, Sequence, TextIO, Union

import numpy as np
import yaml

from .base import (Column, ErrorColumn, Header, _dict_diff, _nested_update, _possibly_open_file, _read_header_data,
orsodataclass)
from .base import (JSON_MIMETYPE, ORSO_DATACLASSES, Column, ErrorColumn, Header, _dict_diff, _nested_update,
_possibly_open_file, _read_header_data, orsodataclass)
from .data_source import DataSource
from .reduction import Reduction

Expand Down Expand Up @@ -163,11 +163,14 @@ class OrsoDataset:
"""

info: Orso
data: Any
data: Union[np.ndarray, Sequence[np.ndarray], Sequence[Sequence]]

def __post_init__(self):
    """Validate that the data block matches the header's column description."""
    if len(self.data) != len(self.info.columns):
        raise ValueError("Data has to have the same number of columns as header")
    # every column must contain the same number of points
    lengths = {len(column) for column in self.data}
    if len(lengths) > 1:
        raise ValueError("Columns must all have the same length in first dimension")

def header(self) -> str:
"""
Expand Down Expand Up @@ -210,6 +213,9 @@ def __eq__(self, other: "OrsoDataset"):
return self.info == other.info and (self.data == other.data).all()


ORSO_DATACLASSES["OrsoDataset"] = OrsoDataset


def save_orso(
datasets: List[OrsoDataset], fname: Union[TextIO, str], comment: Optional[str] = None, data_separator: str = ""
) -> None:
Expand Down Expand Up @@ -249,13 +255,13 @@ def save_orso(

ds1 = datasets[0]
header += ds1.header()
np.savetxt(f, ds1.data, header=header, fmt="%-22.16e")
np.savetxt(f, np.asarray(ds1.data).T, header=header, fmt="%-22.16e")

for dsi in datasets[1:]:
# write an optional spacer string between dataset e.g. \n
f.write(data_separator)
hi = ds1.diff_header(dsi)
np.savetxt(f, dsi.data, header=hi, fmt="%-22.16e")
np.savetxt(f, np.asarray(dsi.data).T, header=hi, fmt="%-22.16e")


def load_orso(fname: Union[TextIO, str]) -> List[OrsoDataset]:
Expand All @@ -273,3 +279,99 @@ def load_orso(fname: Union[TextIO, str]) -> List[OrsoDataset]:
od = OrsoDataset(o, data)
ods.append(od)
return ods


def _from_nexus_group(group):
    """
    Reconstruct a Python object from an HDF5 group written by ``Header.to_nexus``.

    Groups flagged with a ``sequence`` attribute become lists, ordered by each
    child's ``sequence_index``.  Other groups become plain dicts, or instances
    of the ORSO dataclass named by their ``ORSO_class`` attribute.
    """
    if group.attrs.get("sequence", None) is not None:
        # Sort by sequence_index only.  Sorting [index, h5py_object] pairs
        # (as before) would compare the h5py objects themselves on a tie,
        # raising TypeError, since they define no ordering.
        children = sorted(group.values(), key=lambda v: v.attrs["sequence_index"])
        return [_get_nexus_item(v) for v in children]

    dct = dict()
    for name, value in group.items():
        if value.attrs.get('NX_class', None) == 'NXdata':
            # remove NXdata folder, which exists only for NeXus plotting
            continue
        dct[name] = _get_nexus_item(value)

    if "ORSO_class" in group.attrs:
        cls = ORSO_DATACLASSES[group.attrs["ORSO_class"]]
        return cls(**dct)
    return dct


def _get_nexus_item(value):
    """
    Convert a single HDF5 node back into a Python value.

    Groups recurse into :func:`_from_nexus_group`.  Datasets are read out,
    with JSON-typed datasets decoded, null (empty) datasets mapped to None,
    and byte strings decoded to str.
    """
    import json

    import h5py

    if isinstance(value, h5py.Group):
        return _from_nexus_group(value)
    elif isinstance(value, h5py.Dataset):
        v = value[()]
        # Reading a null dataset returns an h5py.Empty instance.  The check
        # must be on the read-out value `v`, not the Dataset `value` (the
        # original tested `value`, which is never Empty, so None attributes
        # written by to_nexus were never restored as None).
        if isinstance(v, h5py.Empty):
            return None
        elif value.attrs.get("mimetype", None) == JSON_MIMETYPE:
            return json.loads(v)
        elif hasattr(v, "decode"):
            # it is a bytes object, should be string
            return v.decode()
        else:
            return v


def load_nexus(fname: Union[str, BinaryIO]) -> List[OrsoDataset]:
    """
    Load ORSO datasets from a NeXus/HDF5 file written by :func:`save_nexus`.

    :param fname: path to, or open binary handle of, the HDF5 file.
    :return: one OrsoDataset per top-level entry flagged with
        ``ORSO_class == "OrsoDataset"``.
    """
    import h5py

    # _get_nexus_item reads every dataset fully into memory, so the file can
    # safely be closed before returning (the original leaked the open handle).
    with h5py.File(fname, "r") as f:
        return [_from_nexus_group(g) for g in f.values() if g.attrs.get("ORSO_class", None) == "OrsoDataset"]


def save_nexus(datasets: List[OrsoDataset], fname: Union[str, BinaryIO], comment: Optional[str] = None) -> BinaryIO:
    """
    Write a list of OrsoDataset objects to a NeXus/HDF5 file.

    Each dataset becomes one NXentry group (named by ``info.data_set``)
    containing the serialized header (``info``), the raw column data
    (``data``), and an NXdata group of hard links for NeXus-aware plotting.

    :param datasets: the datasets to write; ``info.data_set`` of each must be
        unique (unset/empty ones are filled in with the list index).
    :param fname: path to, or open binary handle of, the output file.
    :param comment: optional file-level comment stored as a root attribute.
    :raises ValueError: if two datasets share the same ``data_set`` value.
    """
    import h5py

    for idx, dataset in enumerate(datasets):
        info = dataset.info
        data_set = info.data_set
        if data_set is None or (isinstance(data_set, str) and len(data_set) == 0):
            # it's not set, or is zero length string
            # NOTE(review): this mutates the caller's info object in place,
            # and assigns an int where data_set is otherwise a str — confirm
            # h5py accepts a non-str group name below.
            info.data_set = idx

    # data_set values become group names, so they must be unique
    dsets = [dataset.info.data_set for dataset in datasets]
    if len(set(dsets)) != len(dsets):
        raise ValueError("All `OrsoDataset.info.data_set` values must be unique")

    with h5py.File(fname, mode="w") as f:
        f.attrs["NX_class"] = "NXroot"
        if comment is not None:
            f.attrs["comment"] = comment

        for dsi in datasets:
            info = dsi.info
            entry = f.create_group(info.data_set)
            entry.attrs["ORSO_class"] = "OrsoDataset"
            entry.attrs["NX_class"] = "NXentry"
            # point NeXus viewers at the plottable NXdata group by default
            entry.attrs["default"] = "plottable_data"
            info.to_nexus(root=entry, name="info")
            # 'data' holds the columns in file order ("sequence" marks it as a list)
            data_group = entry.create_group("data")
            data_group.attrs["sequence"] = 1
            # track_order preserves insertion order of the linked columns
            plottable_data_group = entry.create_group("plottable_data", track_order=True)
            plottable_data_group.attrs["NX_class"] = "NXdata"
            plottable_data_group.attrs["sequence"] = 1
            # NOTE(review): assumes at least two columns — columns[0] is the
            # axis and columns[1] the signal; confirm against the ORSO spec.
            plottable_data_group.attrs["axes"] = [info.columns[0].name]
            plottable_data_group.attrs["signal"] = info.columns[1].name
            plottable_data_group.attrs[f"{info.columns[0].name}_indices"] = [0]
            for column_index, column in enumerate(info.columns):
                # assume that dataset.data has dimension == ncolumns along first dimension
                # (note that this is not how data would be loaded from e.g. load_orso, which is row-first)
                col_data = data_group.create_dataset(column.name, data=dsi.data[column_index])
                col_data.attrs["sequence_index"] = column_index
                # 'target' records the canonical HDF5 path for NeXus linking
                col_data.attrs["target"] = col_data.name
                if isinstance(column, ErrorColumn):
                    # error columns are exposed under NeXus naming: <name>_errors
                    nexus_colname = column.error_of + "_errors"
                else:
                    nexus_colname = column.name
                if column.unit is not None:
                    col_data.attrs["units"] = column.unit

                # hard-link the column dataset into the NXdata plotting group
                plottable_data_group[nexus_colname] = col_data
14 changes: 7 additions & 7 deletions tests/test_fileio/test_orso.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def test_write_read(self):
# test write and read of multiple datasets
info = fileio.Orso.empty()
info2 = fileio.Orso.empty()
data = np.zeros((100, 3))
data[:] = np.arange(100.0)[:, None]
data = np.zeros((3, 100))
data[:] = np.arange(100.0)[None, :]

info.columns = [
fileio.Column("Qz", "1/angstrom"),
Expand Down Expand Up @@ -177,14 +177,14 @@ def test_unique_dataset(self):
info2.data_set = 0
info2.columns = [Column("stuff")] * 4

ds = OrsoDataset(info, np.empty((2, 4)))
ds2 = OrsoDataset(info2, np.empty((2, 4)))
ds = OrsoDataset(info, np.empty((4, 2)))
ds2 = OrsoDataset(info2, np.empty((4, 2)))

with pytest.raises(ValueError):
fileio.save_orso([ds, ds2], "test_data_set.ort")

with pytest.raises(ValueError):
OrsoDataset(info, np.empty((2, 5)))
OrsoDataset(info, np.empty((5, 2)))

def test_user_data(self):
# test write and read of userdata
Expand All @@ -195,8 +195,8 @@ def test_user_data(self):
fileio.ErrorColumn("R"),
]

data = np.zeros((100, 3))
data[:] = np.arange(100.0)[:, None]
data = np.zeros((3, 100))
data[:] = np.arange(100.0)[None, :]
dct = {"ci": "1", "foo": ["bar", 1, 2, 3.5]}
info.user_data = dct
ds = fileio.OrsoDataset(info, data)
Expand Down
Loading