Skip to content

Commit

Permalink
feat: misc
Browse files Browse the repository at this point in the history
  • Loading branch information
dmyersturnbull committed Nov 1, 2021
1 parent 520e833 commit 81ca098
Show file tree
Hide file tree
Showing 12 changed files with 412 additions and 181 deletions.
20 changes: 18 additions & 2 deletions pocketutils/core/_internal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import bz2
import gzip
import logging
import operator
Expand Down Expand Up @@ -113,20 +114,33 @@ def look(obj: Y, attrs: Union[str, Iterable[str], Callable[[Y], Z]]) -> Optional
return None


GZ_BZ2_SUFFIXES = {".gz", ".gzip", ".bz2", ".bzip2"}
JSON_SUFFIXES = {".json" + s for s in {"", *GZ_BZ2_SUFFIXES}}
TOML_SUFFIXES = {".toml" + s for s in {"", *GZ_BZ2_SUFFIXES}}


def read_txt_or_gz(path: PathLike) -> str:
path = Path(path)
if path.name.endswith(".bz2") or path.name.endswith(".bzip2"):
return bz2.decompress(path.read_bytes()).decode(encoding="utf8")
if path.name.endswith(".gz") or path.name.endswith(".gzip"):
return gzip.decompress(path.read_bytes()).decode(encoding="utf8")
return Path(path).read_text(encoding="utf8")


def write_txt_or_gz(txt: str, path: PathLike) -> None:
def write_txt_or_gz(txt: str, path: PathLike, *, mkdirs: bool = False) -> str:
path = Path(path)
if path.name.endswith(".gz") or path.name.endswith(".gzip"):
if mkdirs:
path.parent.mkdir(parents=True, exist_ok=True)
if path.name.endswith(".bz2") or path.name.endswith(".bzip2"):
data = bz2.compress(txt.encode(encoding="utf8"))
path.write_bytes(data)
elif path.name.endswith(".gz") or path.name.endswith(".gzip"):
data = gzip.compress(txt.encode(encoding="utf8"))
path.write_bytes(data)
else:
path.write_text(txt)
return txt


def null_context():
Expand All @@ -142,4 +156,6 @@ def null_context():
"PathLikeUtils",
"read_txt_or_gz",
"write_txt_or_gz",
"GZ_SUFFIXES",
"BZ2_SUFFIXES",
]
90 changes: 66 additions & 24 deletions pocketutils/core/dot_dict.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import gzip
import json
import pickle
from copy import copy
from datetime import date, datetime
Expand Down Expand Up @@ -35,23 +33,22 @@ class NestedDotDict(Mapping):
Keys must be strings that do not contain a dot (.).
A dot is reserved for splitting values to traverse the tree.
For example, ``dotdict["pet.species.name"]``.
"""

@classmethod
def read_toml(cls, path: PathLike) -> NestedDotDict:
return NestedDotDict(toml.loads(read_txt_or_gz(path)))

@classmethod
def read_json(cls, path: Union[PurePath, str]) -> NestedDotDict:
def read_json(cls, path: PathLike) -> NestedDotDict:
"""
Reads JSON from a file, into a NestedDotDict.
If the JSON data is a list type, converts into a dict with the key ``data``.
If the JSON data is a list type, converts into a dict with keys ``"1", "2", ...`` .
Can read .json or .json.gz.
"""
data = orjson.loads(read_txt_or_gz(path))
if isinstance(data, list):
data = {"data": data}
data = dict(enumerate(data))
return cls(data)

@classmethod
Expand All @@ -75,9 +72,9 @@ def parse_json(cls, data: str) -> NestedDotDict:
Parses JSON from a string, into a NestedDotDict.
If the JSON data is a list type, converts into a dict with the key ``data``.
"""
data = json.loads(data)
data = orjson.loads(data.encode(encoding="utf8"))
if isinstance(data, list):
data = {"data": data}
data = dict(enumerate(data))
return cls(data)

@classmethod
Expand All @@ -104,24 +101,45 @@ def __init__(self, x: Mapping[str, Any]) -> None:
if len(bad) > 0:
raise XValueError(f"Keys contained dots (.) for these values: {bad}", value=bad)
self._x = x
# Let's make sure this constructor gets called on subdicts:
# Let's make sure this constructor gets called on sub-dicts:
self.leaves()

def write_json(self, path: PathLike, *, indent: bool = False) -> None:
write_txt_or_gz(self.to_json(indent=indent), path)
def write_json(self, path: PathLike, *, indent: bool = False, mkdirs: bool = False) -> str:
"""
Writes to a json or .json.gz file.
Returns:
The JSON text
"""
return write_txt_or_gz(self.to_json(indent=indent), path, mkdirs=mkdirs)

def write_toml(self, path: PathLike, mkdirs: bool = False) -> str:
"""
Writes to a toml or .toml.gz file.
def write_toml(self, path: PathLike) -> None:
write_txt_or_gz(self.to_toml(), path)
Returns:
The JSON text
"""
return write_txt_or_gz(self.to_toml(), path, mkdirs=mkdirs)

def write_pickle(self, path: PathLike) -> None:
"""
Writes to a pickle file.
"""
Path(path).write_bytes(pickle.dumps(self._x, protocol=PICKLE_PROTOCOL))

def to_json(self, *, indent: bool = False) -> str:
"""
Returns JSON text.
"""
kwargs = dict(option=orjson.OPT_INDENT_2) if indent else {}
encoded = orjson.dumps(self._x, default=_json_encode_default, **kwargs)
return encoded.decode(encoding="utf8")

def to_toml(self) -> str:
"""
Returns TOML text.
"""
return toml.dumps(self._x)

def leaves(self) -> Mapping[str, Any]:
Expand All @@ -142,13 +160,29 @@ def leaves(self) -> Mapping[str, Any]:
return mp

def sub(self, items: str) -> NestedDotDict:
"""
Returns the dictionary under (dotted) keys ``items``.
See Also:
:meth:`sub_opt`
"""
return NestedDotDict(self[items])

def sub_opt(self, items: str) -> NestedDotDict:
"""
Returns the dictionary under (dotted) keys ``items``, or empty if a key is not found.
See Also:
:meth:`sub`
"""
try:
return NestedDotDict(self[items])
except XKeyError:
return NestedDotDict({})

def exactly(self, items: str, astype: Type[T]) -> T:
"""
Gets the key ``items`` from the dict if it has type ``astype``.
Calling ``dotdict.exactly(k, t) is equivalent to calling ``t(dotdict[k])``,
but a raised ``TypeError`` will note the key, making this a useful shorthand for the above within a try-except.
Args:
items: The key hierarchy, with a dot (.) as a separator
Expand All @@ -158,7 +192,7 @@ def exactly(self, items: str, astype: Type[T]) -> T:
The value in the required type
Raises:
TypeError: If not ``isinstance(value, astype)``
XTypeError: If not ``isinstance(value, astype)``
"""
z = self[items]
if not isinstance(z, astype):
Expand All @@ -174,7 +208,11 @@ def get_as(
) -> Optional[T]:
"""
Gets the value of an *optional* key, or ``default`` if it doesn't exist.
Also see ``req_as``.
Calls ``astype(value)`` on the value before returning.
See Also:
:meth:`req_as`
:meth:`exactly`
Args:
items: The key hierarchy, with a dot (.) as a separator.
Expand Down Expand Up @@ -203,7 +241,11 @@ def get_as(
def req_as(self, items: str, astype: Optional[Callable[[Any], T]]) -> T:
"""
Gets the value of a *required* key.
Also see ``get_as`` and ``exactly``.
Calls ``astype(value)`` on the value before returning.
See Also:
:meth:`req_as`
:meth:`exactly`
Args:
items: The key hierarchy, with a dot (.) as a separator.
Expand Down Expand Up @@ -317,12 +359,9 @@ def pretty_str(self) -> str:
Returns:
A multi-line string
"""
return json.dumps(
self.leaves(),
sort_keys=True,
indent=4,
ensure_ascii=False,
)
return orjson.dumps(
self.leaves(), option=orjson.OPT_SORT_KEYS | orjson.OPT_INDENT_2 | orjson.OPT_UTC_Z
).decode(encoding="utf8")

def __len__(self) -> int:
"""
Expand All @@ -331,6 +370,9 @@ def __len__(self) -> int:
"""
return len(self._x)

def is_empty(self) -> bool:
return len(self._x) == 0

def __iter__(self):
"""
Iterates over values in this dict.
Expand Down
4 changes: 4 additions & 0 deletions pocketutils/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,10 @@ class IllegalPathError(PathError, ValueError):
"""Not a valid path (e.g. not ok on the filesystem)."""


class InvalidFileType(PathError):
"""Not a valid file type."""


class FileDoesNotExistError(PathError):
"""A file is expected, but the path does not exist."""

Expand Down
47 changes: 27 additions & 20 deletions pocketutils/core/iterators.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import abc
from collections.abc import Iterator as _Iterator
from typing import Any, Generic, Iterable, TypeVar
from typing import Iterable, Sequence, Tuple, TypeVar

import numpy as np

T = TypeVar("T")
IX = TypeVar("IX")


# noinspection PyAbstractClass
class SizedIterator(_Iterator, metaclass=abc.ABCMeta):
class SizedIterator(_Iterator[T], metaclass=abc.ABCMeta):
"""
An iterator with size and progress.
"""
Expand All @@ -30,42 +29,46 @@ def __len__(self) -> int:
return self.total()

def __repr__(self):
return "{}({} items)".format(self.__class__, self.total())
return f"It({self.position()}/{self.total()})"

def __str__(self):
return repr(self)


class SeqIterator(SizedIterator, Generic[IX]):
class SeqIterator(SizedIterator[T]):
"""
A concrete SizedIterator backed by a list.
"""

def __init__(self, it: Iterable[Any]):
self.seq, self.__i, self.current = list(it), 0, None
def __init__(self, it: Iterable[T]):
self.__seq, self.__i, self.__current = list(it), 0, None

@property
def seq(self) -> Sequence[T]:
return self.__seq

def reset(self) -> None:
self.__i, self.current = 0, None
self.__i, self.__current = 0, None

def peek(self) -> None:
return self.seq[self.__i]
def peek(self) -> T:
return self.__seq[self.__i]

def position(self) -> int:
return self.__i

def total(self) -> int:
return len(self.seq)
return len(self.__seq)

def __next__(self) -> T:
try:
self.current = self.seq[self.__i]
self.current = self.__seq[self.__i]
except IndexError:
raise StopIteration(f"Size is {len(self)}")
self.__i += 1
return self.current


class TieredIterator(SeqIterator):
class TieredIterator(SeqIterator[Tuple[IX]]):
"""
A SizedIterator that iterates over every tuples of combination from multiple sequences.
Expand All @@ -76,18 +79,22 @@ class TieredIterator(SeqIterator):
"""

# noinspection PyMissingConstructor
def __init__(self, sequence):
self.seqs = list([SeqIterator(s) for s in reversed(sequence)])
def __init__(self, sequence: Sequence[Sequence[IX]]):
self.__seqs = list([SeqIterator(s) for s in reversed(sequence)])
self.__total = 0 if len(self.seqs) == 0 else int(np.product([i.total() for i in self.seqs]))
self.__i = 0

@property
def seqs(self) -> Sequence[SeqIterator[IX]]:
return self.__seqs

def position(self) -> int:
return self.__i

def total(self) -> int:
return self.__total

def __next__(self):
def __next__(self) -> Tuple[IX]:
if not self.has_next():
raise StopIteration(f"Length is {self.total()}")
t = tuple((seq.peek() for seq in reversed(self.seqs)))
Expand All @@ -101,13 +108,13 @@ def __set(self, i: int):
next(seq)
else:
seq.reset()
if (
i < len(self.seqs) - 1
): # to avoid index error for last case, after which self.has_next() is False
# to avoid index error for last case, after which self.has_next() is False
if i < len(self.seqs) - 1:
self.__set(i + 1) # recurse

def __repr__(self):
return "{}({} items)".format(self.__class__.__name__, [it.total() for it in self.seqs])
sizes = ", ".join([f"{it.current}/{it.total()}" for it in self.seqs])
return f"Iter({sizes})"

def __str__(self):
return repr(self)
Expand Down
Loading

0 comments on commit 81ca098

Please sign in to comment.