diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3614ed4..ed1e0ee 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -19,7 +19,7 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v3 with: - version: "0.4.21" + version: "0.4.27" enable-cache: true cache-dependency-glob: "**/pyproject.toml" - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index cd46b55..29e9038 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -22,7 +22,7 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v3 with: - version: "0.4.21" + version: "0.4.27" enable-cache: true cache-dependency-glob: "**/pyproject.toml" - name: Set up Python @@ -43,7 +43,7 @@ jobs: needs: [build] environment: name: release - url: https://pypi.org/p/coqui-tts-coqpit + url: https://pypi.org/p/coqpit-config permissions: id-token: write steps: diff --git a/.gitignore b/.gitignore index 33a2f6e..a3bbda1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +uv.lock + WadaSNR/ .idea/ *.pyc diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 727d5cc..5d5f10e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,3 +12,10 @@ repos: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.12.0 + hooks: + - id: mypy + args: [--strict] + additional_dependencies: + - "pytest" diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 8b8dab6..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,9 +0,0 @@ -include README.md -include LICENSE.txt -recursive-include coqpit *.json -recursive-include coqpit *.html -recursive-include coqpit *.png -recursive-include coqpit *.md -recursive-include coqpit *.py -recursive-include coqpit *.pyx -recursive-include images *.png diff --git a/README.md b/README.md index 762ab5f..4b0dea8 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,12 @@ [![CI](https://github.com/idiap/coqui-ai-coqpit/actions/workflows/main.yml/badge.svg?branch=main)](https://github.com/idiap/coqui-ai-coqpit/actions/workflows/main.yml) -Simple, light-weight and no dependency config handling through python data classes with to/from JSON serialization/deserialization. +Simple, light-weight and no dependency config handling through python data +classes with to/from JSON serialization/deserialization. -Currently it is being used by [🐸TTS](https://github.com/idiap/coqui-ai-TTS). +Fork of the [original, unmaintained repository](https://github.com/coqui-ai/coqpit). New PyPI package: [coqpit-config](https://pypi.org/project/coqpit-config) + +Currently it is being used by [coqui-tts](https://github.com/idiap/coqui-ai-TTS). ## ❔ Why I need this What I need from a ML configuration library... diff --git a/coqpit/__init__.py b/coqpit/__init__.py index aa97fe8..0635173 100644 --- a/coqpit/__init__.py +++ b/coqpit/__init__.py @@ -1,8 +1,7 @@ import importlib.metadata -from dataclasses import dataclass from coqpit.coqpit import MISSING, Coqpit, check_argument -__all__ = ["dataclass", "MISSING", "Coqpit", "check_argument"] +__all__ = ["MISSING", "Coqpit", "check_argument"] __version__ = importlib.metadata.version("coqpit") diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 4c4330b..73d34f6 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -1,113 +1,122 @@ +"""Simple, light-weight config handling through Python data classes.""" + +from __future__ import annotations + import argparse import contextlib -import functools import json import operator -from collections.abc import MutableMapping +import sys +import typing +from collections.abc import ItemsView, Iterable, Iterator, MutableMapping from dataclasses import MISSING as _MISSING from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace from pathlib import Path from pprint import pprint -from typing import Any, Generic, Optional, TypeVar, Union +from types import GenericAlias +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union, overload + +from typing_extensions import Self, TypeAlias, TypeGuard, TypeIs + +# TODO: Available from Python 3.10 +if sys.version_info >= (3, 10): + from types import UnionType +else: + UnionType: TypeAlias = Union -T = TypeVar("T") +if TYPE_CHECKING: # pragma: no cover + import os + from dataclasses import _MISSING_TYPE + + from _typeshed import SupportsKeysAndGetItem + +_T = TypeVar("_T") MISSING: Any = "???" -class _NoDefault(Generic[T]): +class _NoDefault(Generic[_T]): pass -NoDefaultVar = Union[_NoDefault[T], T] -no_default: NoDefaultVar = _NoDefault() +NoDefaultVar: TypeAlias = Union[_NoDefault[_T], _T] +no_default: NoDefaultVar[Any] = _NoDefault() +FieldType: TypeAlias = Union[str, type, "UnionType"] -def is_primitive_type(arg_type: Any) -> bool: + +def _is_primitive_type(field_type: FieldType) -> TypeGuard[type]: """Check if the input type is one of `int, float, str, bool`. Args: - arg_type (typing.Any): input type to check. + field_type: input type to check. Returns: bool: True if input type is one of `int, float, str, bool`. """ - try: - return isinstance(arg_type(), (int, float, str, bool)) - except (AttributeError, TypeError): - return False + return field_type is int or field_type is float or field_type is str or field_type is bool -def is_list(arg_type: Any) -> bool: +def _is_list(field_type: FieldType) -> TypeGuard[type]: """Check if the input type is `list`. Args: - arg_type (typing.Any): input type. + field_type: input type. Returns: bool: True if input type is `list` """ - try: - return arg_type is list or arg_type is list or arg_type.__origin__ is list or arg_type.__origin__ is list - except AttributeError: - return False + return field_type is list or typing.get_origin(field_type) is list -def is_dict(arg_type: Any) -> bool: +def _is_dict(field_type: FieldType) -> TypeGuard[type]: """Check if the input type is `dict`. Args: - arg_type (typing.Any): input type. + field_type: input type. Returns: bool: True if input type is `dict` """ - try: - return arg_type is dict or arg_type is dict or arg_type.__origin__ is dict - except AttributeError: - return False + return field_type is dict or typing.get_origin(field_type) is dict -def is_union(arg_type: Any) -> bool: +def _is_union(field_type: FieldType) -> TypeIs[UnionType]: """Check if the input type is `Union`. Args: - arg_type (typing.Any): input type. + field_type: input type. Returns: bool: True if input type is `Union` """ - try: - return safe_issubclass(arg_type.__origin__, Union) - except AttributeError: - return False + origin = typing.get_origin(field_type) + is_union = origin is Union + if sys.version_info >= (3, 10): + is_union = is_union or origin is UnionType + return is_union -def safe_issubclass(cls, classinfo) -> bool: - """Check if the input type is a subclass of the given class. +def _is_union_and_not_simple_optional(field_type: FieldType) -> TypeGuard[UnionType]: + """Check if the input type is `Union`. + + Note: `int | None` would be of type Union, but here we don't consider such + cases where the only other accepted type is None. Args: - cls (type): input type. - classinfo (type): parent class. + field_type: input type. Returns: - bool: True if the input type is a subclass of the given class + bool: True if input type is `Union` and not optional type like `int | None` """ - try: - r = issubclass(cls, classinfo) - except Exception: # pylint: disable=broad-except - return cls is classinfo - else: - return r - - -def _coqpit_json_default(obj: Any) -> Any: - if isinstance(obj, Path): - return str(obj) - msg = f"Can't encode object of type {type(obj).__name__}" - raise TypeError(msg) + args = typing.get_args(field_type) + is_python_union = _is_union(field_type) + if is_python_union and len(args) == 2 and type(None) in args: # noqa: PLR2004 + # This is an Optional type like `int | None` + return False + return is_python_union -def _default_value(x: Field): +def _default_value(x: Field[_T]) -> _T | Literal[_MISSING_TYPE.MISSING]: """Return the default value of the input Field. Args: @@ -116,27 +125,43 @@ def _default_value(x: Field): Returns: object: default value of the input Field. """ - if x.default not in (MISSING, _MISSING): + if x.default != MISSING and x.default is not _MISSING: return x.default - if x.default_factory not in (MISSING, _MISSING): + if x.default_factory != MISSING and x.default_factory is not _MISSING: return x.default_factory() return x.default -def _is_optional_field(field) -> bool: - """Check if the input field is optional. +def _is_optional_field(field_type: FieldType) -> TypeGuard[UnionType]: + """Check if the input field type is optional. Args: - field (Field): input Field to check. + field_type: input Field's type to check. Returns: - bool: True if the input field is optional. + bool: True if the input field type is optional. """ - # return isinstance(field.type, _GenericAlias) and type(None) in getattr(field.type, "__args__") - return type(None) in field.type.__args__ + return type(None) in typing.get_args(field_type) -def _serialize(x): +def _drop_none_type(field_type: FieldType) -> FieldType: + """Drop None from Union-like types. + + >>> _drop_none_type(str | int | None) + str | int + """ + if not _is_union(field_type): + return field_type + origin = typing.get_origin(field_type) + args = list(typing.get_args(field_type)) + if type(None) in args: + args.remove(type(None)) + if len(args) == 1: + return typing.cast(type, args[0]) + return typing.cast("UnionType", GenericAlias(origin, args)) + + +def _serialize(x: Any) -> Any: """Pick the right serialization for the datatype of the given input. Args: @@ -154,11 +179,11 @@ def _serialize(x): if isinstance(x, Serializable) or issubclass(type(x), Serializable): return x.serialize() if isinstance(x, type) and issubclass(x, Serializable): - return x.serialize(x) + return x.serialize(x()) return x -def _deserialize_dict(x: dict) -> dict: +def _deserialize_dict(x: dict[Any, Any]) -> dict[Any, Any]: """Deserialize dict. Args: @@ -167,7 +192,7 @@ def _deserialize_dict(x: dict) -> dict: Returns: Dict: deserialized dictionary. """ - out_dict = {} + out_dict: dict[Any, Any] = {} for k, v in x.items(): if v is None: # if {'key':None} out_dict[k] = None @@ -176,7 +201,7 @@ def _deserialize_dict(x: dict) -> dict: return out_dict -def _deserialize_list(x: list, field_type: type) -> list: +def _deserialize_list(x: list[Any], field_type: FieldType) -> list[Any]: """Deserialize values for List typed fields. Args: @@ -189,25 +214,20 @@ def _deserialize_list(x: list, field_type: type) -> list: Returns: [List]: deserialized list. """ - field_args = None - if hasattr(field_type, "__args__") and field_type.__args__: - field_args = field_type.__args__ - elif hasattr(field_type, "__parameters__") and field_type.__parameters__: - # bandaid for python 3.6 - field_args = field_type.__parameters__ - if field_args: - if len(field_args) > 1: - msg = " [!] Coqpit does not support multi-type hinted 'List'" - raise ValueError(msg) - field_arg = field_args[0] - # if field type is TypeVar set the current type by the value's type. - if isinstance(field_arg, TypeVar): - field_arg = type(x) - return [_deserialize(xi, field_arg) for xi in x] - return x + field_args = typing.get_args(field_type) + if len(field_args) == 0: + return x + if len(field_args) > 1: + msg = "Coqpit does not support multi-type hinted 'List'" + raise ValueError(msg) + field_arg = field_args[0] + # if field type is TypeVar set the current type by the value's type. + if isinstance(field_arg, TypeVar): + field_arg = type(x) + return [_deserialize(xi, field_arg) for xi in x] -def _deserialize_union(x: Any, field_type: type) -> Any: +def _deserialize_union(x: Any, field_type: UnionType) -> Any: """Deserialize values for Union typed fields. Args: @@ -215,9 +235,9 @@ def _deserialize_union(x: Any, field_type: type) -> Any: field_type (Type): field type. Returns: - [Any]: desrialized value. + [Any]: deserialized value. """ - for arg in field_type.__args__: + for arg in typing.get_args(field_type): # stop after first matching type in Union try: x = _deserialize(x, arg) @@ -227,7 +247,10 @@ def _deserialize_union(x: Any, field_type: type) -> Any: return x -def _deserialize_primitive_types(x: Union[int, float, str, bool], field_type: type) -> Union[int, float, str, bool]: +def _deserialize_primitive_types( + x: int | float | str | bool | None, # noqa: PYI041 + field_type: FieldType, +) -> int | float | str | bool | None: """Deserialize python primitive types (float, int, str, bool). It handles `inf` values exclusively and keeps them float against int fields since int does not support inf values. @@ -242,16 +265,26 @@ def _deserialize_primitive_types(x: Union[int, float, str, bool], field_type: ty if isinstance(x, (str, bool)): return x if isinstance(x, (int, float)): + base_type = _drop_none_type(field_type) + if base_type is not float and base_type is not int and base_type is not str and base_type is not bool: + raise TypeError + base_type = typing.cast(type[Union[int, float, str, bool]], base_type) if x == float("inf") or x == float("-inf"): # if value type is inf return regardless. return x - return field_type(x) - # TODO: Raise an error when x does not match the types. + return base_type(x) return None -def _deserialize(x: Any, field_type: Any) -> Any: - """Pick the right desrialization for the given object and the corresponding field type. +def _deserialize_path(x: Any, field_type: FieldType) -> Path | None: + """Deserialize to a Path.""" + if x is None and _is_optional_field(field_type): + return None + return Path(x) + + +def _deserialize(x: Any, field_type: FieldType) -> Any: + """Pick the right deserialization for the given object and the corresponding field type. Args: x (object): object to be deserialized. @@ -261,63 +294,76 @@ def _deserialize(x: Any, field_type: Any) -> Any: object: deserialized object """ - # pylint: disable=too-many-return-statements - if is_dict(field_type): + if isinstance(field_type, str): + msg = "Strings as type hints are not supported." + raise NotImplementedError(msg) + if _is_dict(_drop_none_type(field_type)): return _deserialize_dict(x) - if is_list(field_type): - return _deserialize_list(x, field_type) - if is_union(field_type): + if _is_list(_drop_none_type(field_type)): + return _deserialize_list(x, _drop_none_type(field_type)) + if _is_union_and_not_simple_optional(field_type): return _deserialize_union(x, field_type) - if issubclass(field_type, Serializable): + if not _is_union(field_type) and isinstance(field_type, type) and issubclass(field_type, Serializable): return field_type.deserialize_immutable(x) - if is_primitive_type(field_type): + if _drop_none_type(field_type) is Path: + return _deserialize_path(x, field_type) + if _is_primitive_type(_drop_none_type(field_type)): return _deserialize_primitive_types(x, field_type) msg = f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type." raise ValueError(msg) -# Recursive setattr (supports dotted attr names) -def rsetattr(obj, attr, val): - def _setitem(obj, attr, val): - return operator.setitem(obj, int(attr), val) +CoqpitType: TypeAlias = MutableMapping[str, "CoqpitNestedValue"] +CoqpitNestedValue: TypeAlias = Union["CoqpitValue", CoqpitType] +CoqpitValue: TypeAlias = Union[str, int, float, bool, None] - pre, _, post = attr.rpartition(".") - setfunc = _setitem if post.isnumeric() else setattr - - return setfunc(rgetattr(obj, pre) if pre else obj, post, val) +# TODO: It should be possible to get rid of the next 3 `type: ignore`. At +# nested levels, the key can be `str | int` as well, not just `str`. +def _rsetattr(obj: CoqpitType, keys: str, val: CoqpitValue) -> None: + """Recursive setattr (supports dotted key names).""" + pre, _, post = keys.rpartition(".") + target = _rgetattr(obj, pre) if pre else obj + if post.isnumeric(): + operator.setitem(target, int(post), val) # type: ignore[misc] + else: + setattr(target, post, val) -# Recursive getattr (supports dotted attr names) -def rgetattr(obj, attr, *args): - def _getitem(obj, attr): - return operator.getitem(obj, int(attr), *args) - def _getattr(obj, attr): - getfunc = _getitem if attr.isnumeric() else getattr - return getfunc(obj, attr, *args) +def _rgetattr(obj: CoqpitType, keys: str) -> CoqpitType: + """Recursive getattr (supports dotted key names).""" + v = obj + for k in keys.split("."): + v = operator.getitem(v, int(k)) if k.isnumeric() else getattr(v, k) # type: ignore[arg-type] + return v - return functools.reduce(_getattr, [obj, *attr.split(".")]) +def _rsetitem(obj: CoqpitType, keys: str, value: CoqpitValue) -> None: + """Recursive setitem (supports dotted key names). -# Recursive setitem (supports dotted attr names) -def rsetitem(obj, attr, val): - pre, _, post = attr.rpartition(".") - return operator.setitem(rgetitem(obj, pre) if pre else obj, post, val) + _rsetitem(a, "b.c", 1) => a["b"]["c"] = 1 + """ + pre, _, post = keys.rpartition(".") + operator.setitem(_rgetitem(obj, pre) if pre else obj, post, value) -# Recursive getitem (supports dotted attr names) -def rgetitem(obj, attr, *args): - def _getitem(obj, attr): - return operator.getitem(obj, int(attr) if attr.isnumeric() else attr, *args) +def _rgetitem(obj: CoqpitType, keys: str) -> CoqpitType: + """Recursive getitem (supports dotted key names). - return functools.reduce(_getitem, [obj, *attr.split(".")]) + _rgetitem(a, "b.c") => a["b"]["c"] + """ + v = obj + for k in keys.split("."): + v = operator.getitem(v, int(k) if k.isnumeric() else k) # type: ignore[arg-type] + return v @dataclass class Serializable: """Gives serialization ability to any inheriting dataclass.""" - def __post_init__(self): + def __post_init__(self) -> None: + """Validate contracts and check required arguments are specified.""" self._validate_contracts() for key, value in self.__dict__.items(): if value is no_default: @@ -325,12 +371,13 @@ def __post_init__(self): raise TypeError(msg) def _validate_contracts(self) -> None: + """Validate contracts specified in the dataclass.""" dataclass_fields = fields(self) for field in dataclass_fields: value = getattr(self, field.name) - if value is None and not _is_optional_field(field): + if value is None and not _is_optional_field(field.type): msg = f"{field.name} is not optional" raise TypeError(msg) @@ -343,13 +390,11 @@ def _validate_contracts(self) -> None: def validate(self) -> None: """Validate if object can serialize / deserialize correctly.""" self._validate_contracts() - if self != self.__class__.deserialize( # pylint: disable=no-value-for-parameter - json.loads(json.dumps(self.serialize())), - ): + if self != self.__class__().deserialize(json.loads(json.dumps(self.serialize()))): msg = "could not be deserialized with same value" raise ValueError(msg) - def to_dict(self) -> dict: + def to_dict(self) -> dict[str, Any]: """Transform serializable object to dict.""" cls_fields = fields(self) o = {} @@ -357,7 +402,7 @@ def to_dict(self) -> dict: o[cls_field.name] = getattr(self, cls_field.name) return o - def serialize(self) -> dict: + def serialize(self) -> dict[str, Any]: """Serialize object to be json serializable representation.""" if not is_dataclass(self): msg = "need to be decorated as dataclass" @@ -373,8 +418,8 @@ def serialize(self) -> dict: o[field.name] = value return o - def deserialize(self, data: dict) -> "Serializable": - """Parse input dictionary and desrialize its fields to a dataclass. + def deserialize(self, data: dict[str, Any]) -> Self: + """Parse input dictionary and deserialize its fields to a dataclass. Returns: self: deserialized `self`. @@ -396,7 +441,7 @@ def deserialize(self, data: dict) -> "Serializable": init_kwargs[field.name] = value continue if value == MISSING: - msg = f"deserialized with unknown value for {field.name} in {self.__name__}" + msg = f"deserialized with unknown value for {field.name} in {self.__class__.__name__}" raise ValueError(msg) value = _deserialize(value, field.type) init_kwargs[field.name] = value @@ -405,8 +450,8 @@ def deserialize(self, data: dict) -> "Serializable": return self @classmethod - def deserialize_immutable(cls, data: dict) -> "Serializable": - """Parse input dictionary and desrialize its fields to a dataclass. + def deserialize_immutable(cls, data: dict[str, Any]) -> Self: + """Parse input dictionary and deserialize its fields to a dataclass. Returns: Newly created deserialized object. @@ -445,41 +490,49 @@ def deserialize_immutable(cls, data: dict) -> "Serializable": # ---------------------------------------------------------------------------- # -def _get_help(field): +def _get_help(field: Field[Any]) -> str: try: - field_help = field.metadata["help"] + return str(field.metadata["help"]) except KeyError: - field_help = "" - return field_help - - -def _init_argparse( - parser, - field_name, - field_type, - field_default, - field_default_factory, - field_help, - arg_prefix="", - help_prefix="", + return "" + + +def _add_argument( # noqa: C901, PLR0913, PLR0912, PLR0915 + parser: argparse.ArgumentParser, + field_name: str, + field_type: FieldType, + field_default: Any, + field_default_factory: Callable[[], Any] | Literal[_MISSING_TYPE.MISSING], + field_help: str, + arg_prefix: str = "", + help_prefix: str = "", *, relaxed_parser: bool = False, -): - has_default = False +) -> argparse.ArgumentParser: + """Add a new argument to the argparse parser, matching the given field.""" + if isinstance(field_type, str): + msg = "Strings as type hints are not supported." + raise NotImplementedError(msg) default = None + has_default = False if field_default: has_default = True default = field_default - elif field_default_factory not in (None, _MISSING): + elif field_default_factory is not None and field_default_factory is not _MISSING: has_default = True default = field_default_factory() - if not has_default and not is_primitive_type(field_type) and not is_list(field_type): - # aggregate types (fields with a Coqpit subclass as type) are not supported without None + if ( + not has_default + and not _is_primitive_type(_drop_none_type(field_type)) + and not _is_list(_drop_none_type(field_type)) + ): + # aggregate types (fields with a Coqpit subclass as type) are not + # supported without None return parser arg_prefix = field_name if arg_prefix == "" else f"{arg_prefix}.{field_name}" help_prefix = field_help if help_prefix == "" else f"{help_prefix} - {field_help}" - if is_dict(field_type): # pylint: disable=no-else-raise + if _is_dict(field_type): # NOTE: accept any string in json format as input to dict field. parser.add_argument( f"--{arg_prefix}", @@ -487,23 +540,23 @@ def _init_argparse( default=json.dumps(field_default) if field_default else None, type=json.loads, ) - elif is_list(field_type): + elif _is_list(_drop_none_type(field_type)): # TODO: We need a more clear help msg for lists. - if hasattr(field_type, "__args__"): # if the list is hinted - if len(field_type.__args__) > 1 and not relaxed_parser: - msg = " [!] Coqpit does not support multi-type hinted 'List'" - raise ValueError(msg) - list_field_type = field_type.__args__[0] - else: - msg = " [!] Coqpit does not support un-hinted 'List'" + field_args = typing.get_args(_drop_none_type(field_type)) + if len(field_args) > 1 and not relaxed_parser: + msg = "Coqpit does not support multi-type hinted 'List'" + raise ValueError(msg) + if len(field_args) == 0: + msg = "Coqpit does not support un-hinted 'List'" raise ValueError(msg) + list_field_type = field_args[0] # TODO: handle list of lists - if is_list(list_field_type) and relaxed_parser: + if _is_list(list_field_type) and relaxed_parser: return parser if not has_default or field_default_factory is list: - if not is_primitive_type(list_field_type) and not relaxed_parser: + if not _is_primitive_type(list_field_type) and not relaxed_parser: msg = " [!] Empty list with non primitive inner type is currently not supported." raise NotImplementedError(msg) @@ -517,8 +570,11 @@ def _init_argparse( else: # If a default value is defined, just enable editing the values from argparse # TODO: allow inserting a new value/obj to the end of the list. + if not isinstance(default, list): + msg = f"Default value must be a list, got {default}" + raise TypeError(msg) for idx, fv in enumerate(default): - parser = _init_argparse( + parser = _add_argument( parser, str(idx), list_field_type, @@ -529,21 +585,25 @@ def _init_argparse( arg_prefix=f"{arg_prefix}", relaxed_parser=relaxed_parser, ) - elif is_union(field_type): + elif _is_union_and_not_simple_optional(field_type): # TODO: currently I don't know how to handle Union type on argparse if not relaxed_parser: msg = " [!] Parsing `Union` field from argparse is not yet implemented. Please create an issue." raise NotImplementedError(msg) - elif issubclass(field_type, Serializable): + elif not _is_union(field_type) and issubclass(field_type, Coqpit): + if not isinstance(default, Coqpit): + msg = f"Default value must be a Coqpit instance, got {default}" + raise TypeError(msg) return default.init_argparse( - parser, + instance=default, + parser=parser, arg_prefix=arg_prefix, help_prefix=help_prefix, relaxed_parser=relaxed_parser, ) - elif isinstance(field_type(), bool): + elif field_type is bool: - def parse_bool(x): + def parse_bool(x: str) -> bool: if x not in ("true", "false"): msg = f' [!] Value for boolean field must be either "true" or "false". Got "{x}".' raise ValueError(msg) @@ -556,11 +616,14 @@ def parse_bool(x): help=f"Coqpit Field: {help_prefix}", metavar="true/false", ) - elif is_primitive_type(field_type): + elif _is_primitive_type(_drop_none_type(field_type)): + base_type = _drop_none_type(field_type) + if _is_union(base_type): + raise TypeError parser.add_argument( f"--{arg_prefix}", default=field_default, - type=field_type, + type=base_type, help=f"Coqpit Field: {help_prefix}", ) elif not relaxed_parser: @@ -575,17 +638,19 @@ def parse_bool(x): @dataclass -class Coqpit(Serializable, MutableMapping): +class Coqpit(Serializable, CoqpitType): """Coqpit base class to be inherited by any Coqpit dataclasses. It overrides Python `dict` interface and provides `dict` compatible API. - It also enables serializing/deserializing a dataclass to/from a json file, plus some semi-dynamic type and value check. + It also enables serializing/deserializing a dataclass to/from a json file, + plus some semi-dynamic type and value check. + Note that it does not support all datatypes and likely to fail in some cases. """ _initialized = False - def _is_initialized(self): + def _is_initialized(self) -> bool: """Check if Coqpit is initialized. Useful to prevent running some aux functions @@ -593,35 +658,40 @@ def _is_initialized(self): """ return "_initialized" in vars(self) and self._initialized - def __post_init__(self): + def __post_init__(self) -> None: + """Check values if a check_values() method is defined.""" self._initialized = True with contextlib.suppress(AttributeError): self.check_values() ## `dict` API functions - def __iter__(self): + def __iter__(self) -> Iterator[str]: + """Return iterator over the Coqpit.""" return iter(asdict(self)) def __len__(self) -> int: + """Return the number of fields in the Coqpit.""" return len(fields(self)) def __setitem__(self, arg: str, value: Any) -> None: + """Set the value for the given attribute.""" setattr(self, arg, value) - def __getitem__(self, arg: str): + def __getitem__(self, arg: str) -> Any: """Access class attributes with ``[arg]``.""" return self.__dict__[arg] def __delitem__(self, arg: str) -> None: + """Remove an attribute.""" delattr(self, arg) - def _keytransform(self, key): # pylint: disable=no-self-use + def _keytransform(self, key: str) -> str: return key ## end `dict` API functions - def __getattribute__(self, arg: str): # pylint: disable=no-self-use + def __getattribute__(self, arg: str) -> Any: """Check if the mandatory field is defined when accessing it.""" value = super().__getattribute__(arg) if isinstance(value, str) and value == "???": @@ -629,18 +699,21 @@ def __getattribute__(self, arg: str): # pylint: disable=no-self-use raise AttributeError(msg) return value - def __contains__(self, arg: str) -> bool: + def __contains__(self, arg: object) -> bool: + """Check whether the Coqpit contains the given attribute.""" return arg in self.to_dict() - def get(self, key: str, default: Any = None): + def get(self, key: str, default: Any = None) -> Any: + """Return value of the given attribute if present, otherwise the default.""" if self.has(key): return asdict(self)[key] return default - def items(self): + def items(self) -> ItemsView[str, Any]: + """Return (key, value) items of the Coqpit.""" return asdict(self).items() - def merge(self, coqpits: Union["Coqpit", list["Coqpit"]]) -> None: + def merge(self, coqpits: Coqpit | list[Coqpit]) -> None: """Merge a coqpit instance or a list of coqpit instances to self. Note that it does not pass the fields and overrides attributes with @@ -651,7 +724,7 @@ def merge(self, coqpits: Union["Coqpit", list["Coqpit"]]) -> None: coqpits (Union[Coqpit, List[Coqpit]]): coqpit instance or list of instances to be merged. """ - def _merge(coqpit) -> None: + def _merge(coqpit: Coqpit) -> None: self.__dict__.update(coqpit.__dict__) self.__annotations__.update(coqpit.__annotations__) self.__dataclass_fields__.update(coqpit.__dataclass_fields__) @@ -663,57 +736,75 @@ def _merge(coqpit) -> None: _merge(coqpits) def check_values(self) -> None: - pass + """Perform data validation after initialization. + + Can be implemented in subclasses. + """ def has(self, arg: str) -> bool: + """Check whether the Coqpit has the given attribute.""" return arg in vars(self) - def copy(self): + def copy(self) -> Self: + """Return a copy of the Coqpit.""" return replace(self) - def update(self, new: dict, allow_new: bool = False) -> None: + @overload + def update(self, other: SupportsKeysAndGetItem[str, CoqpitNestedValue], /, **kwargs: CoqpitNestedValue) -> None: ... + @overload + def update(self, other: Iterable[tuple[str, CoqpitNestedValue]], /, **kwargs: CoqpitNestedValue) -> None: ... + @overload + def update(self, /, **kwargs: CoqpitNestedValue) -> None: ... + def update(self, other: Any = (), /, **kwargs: CoqpitNestedValue) -> None: """Update Coqpit fields by the input ```dict```. Args: - new (dict): dictionary with new values. - allow_new (bool, optional): allow new fields to add. Defaults to False. + other: dictionary or iterable with new values. + **kwargs: alternative way to pass new keys and values. """ - for key, value in new.items(): - if allow_new or hasattr(self, key): + if isinstance(other, dict): + for key in other: + setattr(self, key, other[key]) + elif hasattr(other, "keys"): + for key in other.keys(): # noqa: SIM118 + setattr(self, key, other[key]) + else: + for key, value in other: setattr(self, key, value) - else: - msg = f" [!] No key - {key}" - raise KeyError(msg) + for key, value in kwargs.items(): + setattr(self, key, value) def pprint(self) -> None: """Print Coqpit fields in a format.""" - pprint(asdict(self)) + pprint(asdict(self)) # noqa: T203 - def to_dict(self) -> dict: - # return asdict(self) + def to_dict(self) -> dict[str, Any]: + """Convert the Coqpit to a dictionary, serializing any values.""" return self.serialize() - def from_dict(self, data: dict) -> None: - self = self.deserialize(data) # pylint: disable=self-cls-assignment + def from_dict(self, data: dict[str, Any]) -> None: + """Update Coqpit from the dictionary.""" + self.deserialize(data) @classmethod - def new_from_dict(cls: Serializable, data: dict) -> "Coqpit": + def new_from_dict(cls, data: dict[str, Any]) -> Self: + """Create a new Coqpit from a dictionary.""" return cls.deserialize_immutable(data) def to_json(self) -> str: - """Returns a JSON string representation.""" - return json.dumps(asdict(self), indent=4, default=_coqpit_json_default) + """Return a JSON string representation.""" + return json.dumps(self.to_dict(), indent=4) - def save_json(self, file_name: str) -> None: + def save_json(self, file_name: str | os.PathLike[Any]) -> None: """Save Coqpit to a json file. Args: file_name (str): path to the output json file. """ - with open(file_name, "w", encoding="utf8") as f: - json.dump(asdict(self), f, indent=4) + with Path(file_name).open("w", encoding="utf8") as f: + json.dump(self.to_dict(), f, indent=4) - def load_json(self, file_name: str) -> None: + def load_json(self, file_name: str | os.PathLike[Any]) -> None: """Load a json file and update matching config fields with type checking. Non-matching parameters in the json file are ignored. @@ -724,38 +815,42 @@ def load_json(self, file_name: str) -> None: Returns: Coqpit: new Coqpit with updated config fields. """ - with open(file_name, encoding="utf8") as f: + with Path(file_name).open(encoding="utf8") as f: input_str = f.read() dump_dict = json.loads(input_str) - # TODO: this looks stupid 💆 - self = self.deserialize(dump_dict) # pylint: disable=self-cls-assignment + self.deserialize(dump_dict) self.check_values() @classmethod def init_from_argparse( cls, - args: Optional[Union[argparse.Namespace, list[str]]] = None, + args: argparse.Namespace | list[str] | None = None, arg_prefix: str = "coqpit", - ) -> "Coqpit": + ) -> Self: """Create a new Coqpit instance from argparse input. Args: - args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. - arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. + args (namespace or list of str, optional): parsed argparse.Namespace + or list of command line parameters. If unspecified will use a + newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to + ```init_argparse``` when ```args``` is not passed. """ if not args: # If args was not specified, parse from sys.argv - parser = cls.init_argparse(cls, arg_prefix=arg_prefix) - args = parser.parse_args() # pylint: disable=E1120, E1111 + parser = cls.init_argparse(arg_prefix=arg_prefix) + args = parser.parse_args() if isinstance(args, list): - # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace - parser = cls.init_argparse(cls, arg_prefix=arg_prefix) - args = parser.parse_args(args) # pylint: disable=E1120, E1111 + # If a list was passed in (eg. the second result of + # `parse_known_args`, run that through argparse first to get a + # parsed Namespace + parser = cls.init_argparse(arg_prefix=arg_prefix) + args = parser.parse_args(args) # Handle list and object attributes with defaults, which can be modified # directly (eg. --coqpit.list.0.val_a 1), by constructing real objects # from defaults and passing those to `cls.__init__` - args_with_lists_processed = {} + args_with_lists_processed: CoqpitType = {} class_fields = fields(cls) for field in class_fields: has_default = False @@ -769,58 +864,61 @@ def init_from_argparse( has_default = True default = field_default_factory() - if has_default and (not is_primitive_type(field.type) or is_list(field.type)): + if has_default and (not _is_primitive_type(field.type) or _is_list(field.type)): args_with_lists_processed[field.name] = default args_dict = vars(args) - for k, v in args_dict.items(): + for key, v in args_dict.items(): # Remove argparse prefix (eg. "--coqpit." if present) - if k.startswith(f"{arg_prefix}."): - k = k[len(f"{arg_prefix}.") :] - - rsetitem(args_with_lists_processed, k, v) + k = key.removeprefix(f"{arg_prefix}.") + _rsetitem(args_with_lists_processed, k, v) return cls(**args_with_lists_processed) def parse_args( self, - args: Optional[Union[argparse.Namespace, list[str]]] = None, + args: argparse.Namespace | list[str] | None = None, arg_prefix: str = "coqpit", ) -> None: """Update config values from argparse arguments with some meta-programming ✨. Args: - args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. - arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. + args (namespace or list of str, optional): parsed argparse.Namespace + or list of command line parameters. If unspecified will use a + newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to + ```init_argparse``` when ```args``` is not passed. """ if not args: # If args was not specified, parse from sys.argv - parser = self.init_argparse(arg_prefix=arg_prefix) + parser = self.init_argparse(instance=self, arg_prefix=arg_prefix) args = parser.parse_args() if isinstance(args, list): - # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace - parser = self.init_argparse(arg_prefix=arg_prefix) + # If a list was passed in (eg. the second result of + # `parse_known_args`, run that through argparse first + # to get a parsed Namespace + parser = self.init_argparse(instance=self, arg_prefix=arg_prefix) args = parser.parse_args(args) args_dict = vars(args) - for k, v in args_dict.items(): - if k.startswith(f"{arg_prefix}."): - k = k[len(f"{arg_prefix}.") :] + for key, v in args_dict.items(): + k = key.removeprefix(f"{arg_prefix}.") try: - rgetattr(self, k) + _rgetattr(self, k) except (TypeError, AttributeError) as e: msg = f" [!] '{k}' not exist to override from argparse." - raise Exception(msg) from e + raise TypeError(msg) from e - rsetattr(self, k, v) + _rsetattr(self, k, v) self.check_values() def parse_known_args( self, - args: Optional[Union[argparse.Namespace, list[str]]] = None, + args: argparse.Namespace | list[str] | None = None, arg_prefix: str = "coqpit", + *, relaxed_parser: bool = False, ) -> list[str]: """Update config values from argparse arguments. Ignore unknown arguments. @@ -828,54 +926,75 @@ def parse_known_args( This is analog to argparse.ArgumentParser.parse_known_args (vs parse_args). Args: - args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. - arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. - relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. Defaults to False. + args (namespace or list of str, optional): parsed argparse.Namespace + or list of command line parameters. If unspecified will use a + newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to + ```init_argparse``` when ```args``` is not passed. + relaxed_parser (bool, optional): If True, do not force all the fields + to have compatible types with the argparser. Defaults to False. Returns: List of unknown parameters. """ if not args: # If args was not specified, parse from sys.argv - parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) + parser = self.init_argparse(instance=self, arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) args, unknown = parser.parse_known_args() if isinstance(args, list): - # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace - parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) + # If a list was passed in (eg. the second result of + # `parse_known_args`, run that through argparse first to get a + # parsed Namespace + parser = self.init_argparse(instance=self, arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) args, unknown = parser.parse_known_args(args) self.parse_args(args) return unknown + @classmethod def init_argparse( - self, - parser: Optional[argparse.ArgumentParser] = None, - arg_prefix="coqpit", - help_prefix="", + cls, + *, + instance: Self | None = None, + parser: argparse.ArgumentParser | None = None, + arg_prefix: str = "coqpit", + help_prefix: str = "", relaxed_parser: bool = False, ) -> argparse.ArgumentParser: - """Pass Coqpit fields as argparse arguments. This allows to edit values through command-line. + """Create an argparse parser that can parse the Coqpit fields. + + This allows to edit values through command-line. Args: - parser (argparse.ArgumentParser, optional): argparse.ArgumentParser instance. If unspecified a new one will be created. - arg_prefix (str, optional): Prefix to be used for the argument name. Defaults to 'coqpit'. - help_prefix (str, optional): Prefix to be used for the argument description. Defaults to ''. - relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. Defaults to False. + instance (Coqpit, optional): instance of the given Coqpit class + to initialize any default values. + parser (argparse.ArgumentParser, optional): argparse.ArgumentParser + instance. If unspecified a new one will be created. + arg_prefix (str, optional): Prefix to be used for the argument name. + Defaults to 'coqpit'. + help_prefix (str, optional): Prefix to be used for the argument + description. Defaults to ''. + relaxed_parser (bool, optional): If True, do not force all the fields + to have compatible types with the argparser. Defaults to False. Returns: argparse.ArgumentParser: parser instance with the new arguments. """ if not parser: parser = argparse.ArgumentParser() - class_fields = fields(self) + cls_or_instance = cls if instance is None else instance + class_fields = fields(cls_or_instance) for field in class_fields: # use the current value of the field to prevent dropping the current value, # else use the default value of the field - field_default = vars(self).get(field.name, field.default if field.default is not _MISSING else None) + field_default = vars(cls_or_instance).get( + field.name, + field.default if field.default is not _MISSING else None, + ) field_type = field.type field_default_factory = field.default_factory field_help = _get_help(field) - _init_argparse( + _add_argument( parser, field.name, field_type, @@ -889,16 +1008,17 @@ def init_argparse( return parser -def check_argument( - name, - c, +def check_argument( # noqa: C901, PLR0913 + name: str, + c: dict[str, Any], + *, is_path: bool = False, - prerequest: Optional[str] = None, - enum_list: Optional[list] = None, - max_val: Optional[float] = None, - min_val: Optional[float] = None, + prerequest: list[str] | str | None = None, + enum_list: list[Any] | None = None, + max_val: float | None = None, + min_val: float | None = None, restricted: bool = False, - alternative: Optional[str] = None, + alternative: str | None = None, allow_none: bool = True, ) -> None: """Simple type and value checking for Coqpit. @@ -926,29 +1046,38 @@ def check_argument( >>> check_argument('fft_size', c, restricted=True, min_val=128, max_val=4058) """ # check if None allowed - if allow_none and c[name] is None: - return - if not allow_none: - assert c[name] is not None, f" [!] None value is not allowed for {name}." + if c[name] is None: + if allow_none: + return + msg = f" [!] None value is not allowed for {name}." + raise TypeError(msg) # check if restricted and it it is check if it exists - if isinstance(restricted, bool) and restricted: - assert name in c, f" [!] {name} not defined in config.json" + if isinstance(restricted, bool) and restricted and name not in c: + msg = f" [!] {name} not defined in config.json" + raise KeyError(msg) # check prerequest fields are defined if isinstance(prerequest, list): - assert any(f not in c for f in prerequest), f" [!] prequested fields {prerequest} for {name} are not defined." - else: - assert prerequest is None or prerequest in c, f" [!] prequested fields {prerequest} for {name} are not defined." + if any(f not in c for f in prerequest): + msg = f" [!] prequested fields {prerequest} for {name} are not defined." + raise KeyError(msg) + elif prerequest is not None and prerequest not in c: + msg = f" [!] prequested field {prerequest} for {name} is not defined." + raise KeyError(msg) # check if the path exists - if is_path: - assert Path(c[name]).exists(), f' [!] path for {name} ("{c[name]}") does not exist.' + if is_path and not Path(c[name]).exists(): + msg = f' [!] path for {name} ("{c[name]}") does not exist.' + raise FileNotFoundError(msg) # skip the rest if the alternative field is defined. - if alternative in c and c[alternative] is not None: + if alternative is not None and alternative in c and c[alternative] is not None: return # check value constraints if name in c: - if max_val is not None: - assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}" - if min_val is not None: - assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}" - if enum_list is not None: - assert c[name].lower() in enum_list, f" [!] {name} is not a valid value" + if max_val is not None and c[name] > max_val: + msg = f" [!] {name} is larger than max value {max_val}" + raise ValueError + if min_val is not None and c[name] < min_val: + msg = f" [!] {name} is smaller than min value {min_val}" + raise ValueError + if enum_list is not None and c[name].lower() not in enum_list: + msg = f" [!] {name} is not a valid value" + raise ValueError diff --git a/coqpit/py.typed b/coqpit/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 9587bfb..3cadc39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,10 @@ [build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" - -[tool.setuptools.packages.find] -include = ["coqpit*"] +requires = ["hatchling"] +build-backend = "hatchling.build" [project] name = "coqpit" -version = "0.0.17" +version = "0.1.0" description = "Simple (maybe too simple), light-weight config management through python data-classes." readme = "README.md" requires-python = ">=3.9" @@ -32,11 +29,14 @@ classifiers = [ "Operating System :: MacOS", "Operating System :: Microsoft :: Windows", ] -dependencies = [] +dependencies = [ + "typing_extensions>=4.10", +] -[tool.uv] -dev-dependencies = [ +[dependency-groups] +dev = [ "coverage>=7", + "mypy>=1.12.0", "pre-commit>=3", "pytest>=8", "ruff==0.6.9", @@ -46,10 +46,35 @@ dev-dependencies = [ Repository = "https://github.com/idiap/coqui-ai-coqpit" Issues = "https://github.com/idiap/coqui-ai-coqpit/issues" +[tool.hatch.build] +exclude = [ + "/.github", + "/.gitignore", + "/.pre-commit-config.yaml", + "/Makefile", + "/tests", +] + [tool.ruff] target-version = "py39" line-length = 120 -lint.extend-select = ["I", "UP", "B", "W", "A", "PLC", "PLE"] +lint.select = ["ALL"] +lint.ignore = [ + "ANN401", + "D104", + "FIX", + "TD", +] [tool.ruff.lint.pydocstyle] convention = "google" + +[tool.ruff.lint.per-file-ignores] +"tests/**" = [ + "D", + "FA100", + "PLR2004", + "S101", + "SLF001", + "T201", +] diff --git a/setup.py b/setup.py deleted file mode 100644 index beda28e..0000000 --- a/setup.py +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env python - -from setuptools import setup - -setup() diff --git a/tests/test_copying.py b/tests/test_copying.py index c3dbf1d..ccd2120 100644 --- a/tests/test_copying.py +++ b/tests/test_copying.py @@ -9,7 +9,7 @@ class SimpleConfig(Coqpit): val_a: int = 10 -def test_copying(): +def test_copying() -> None: config = SimpleConfig() config_new = config.copy() diff --git a/tests/test_init_from_dict.py b/tests/test_init_from_dict.py index f7b9ffc..1e0a735 100644 --- a/tests/test_init_from_dict.py +++ b/tests/test_init_from_dict.py @@ -1,12 +1,15 @@ from dataclasses import dataclass, field +from typing import Optional + +import pytest from coqpit import Coqpit @dataclass class Person(Coqpit): - name: str = None - age: int = None + name: Optional[str] = None + age: Optional[int] = None @dataclass @@ -18,7 +21,7 @@ class Reference(Coqpit): Person(name="Eren", age=11), Person(name="Geren", age=12), Person(name="Ceren", age=15), - ] + ], ) people_ids: list[int] = field(default_factory=lambda: [1, 2, 3]) @@ -28,7 +31,7 @@ class WithRequired(Coqpit): name: str -def test_new_from_dict(): +def test_new_from_dict() -> None: ref_config = Reference(name="Fancy", size=3**10, people=[Person(name="Anonymous", age=42)]) new_config = Reference.new_from_dict({"name": "Fancy", "size": 3**10, "people": [{"name": "Anonymous", "age": 42}]}) @@ -40,7 +43,5 @@ def test_new_from_dict(): assert ref_config.people[0].name == new_config.people[0].name assert ref_config.people[0].age == new_config.people[0].age - try: + with pytest.raises(ValueError, match="Missing required field"): WithRequired.new_from_dict({}) - except ValueError as e: - assert "Missing required field" in e.args[0] diff --git a/tests/test_merge_configs.py b/tests/test_merge_configs.py index 8f5a52d..f5e1ae4 100644 --- a/tests/test_merge_configs.py +++ b/tests/test_merge_configs.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional from coqpit.coqpit import Coqpit @@ -6,7 +7,7 @@ @dataclass class CoqpitA(Coqpit): val_a: int = 10 - val_b: int = None + val_b: Optional[int] = None val_c: str = "Coqpit is great!" val_same: float = 10.21 @@ -22,21 +23,21 @@ class CoqpitB(Coqpit): @dataclass class Reference(Coqpit): val_a: int = 10 - val_b: int = None + val_b: Optional[int] = None val_c: str = "Coqpit is great!" val_e: int = 257 val_f: float = -10.21 val_g: str = "Coqpit is really great!" - val_same: int = 10.21 # duplicate fields are override by the merged Coqpit class. + val_same: float = 10.21 # duplicate fields are override by the merged Coqpit class. -def test_config_merge(): +def test_config_merge() -> None: coqpit_ref = Reference() coqpita = CoqpitA() coqpitb = CoqpitB() coqpitb.merge(coqpita) print(coqpitb.val_a) - print(coqpitb.pprint()) + coqpitb.pprint() assert coqpit_ref.val_a == coqpitb.val_a assert coqpit_ref.val_b == coqpitb.val_b diff --git a/tests/test_nested_configs.py b/tests/test_nested_configs.py index ea654a0..4117fe6 100644 --- a/tests/test_nested_configs.py +++ b/tests/test_nested_configs.py @@ -1,6 +1,6 @@ -import os from dataclasses import asdict, dataclass, field -from typing import Union +from pathlib import Path +from typing import Optional, Union from coqpit import Coqpit, check_argument @@ -8,12 +8,10 @@ @dataclass class SimpleConfig(Coqpit): val_a: int = 10 - val_b: int = None + val_b: Optional[int] = None val_c: str = "Coqpit is great!" - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) @@ -24,15 +22,13 @@ def check_values( @dataclass class NestedConfig(Coqpit): val_d: int = 10 - val_e: int = None + val_e: Optional[int] = None val_f: str = "Coqpit is great!" - sc_list: list[SimpleConfig] = None + sc_list: Optional[list[SimpleConfig]] = None sc: SimpleConfig = field(default_factory=lambda: SimpleConfig()) union_var: Union[list[SimpleConfig], SimpleConfig] = field(default_factory=lambda: [SimpleConfig(), SimpleConfig()]) - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_d", c, restricted=True, min_val=10, max_val=2056) @@ -42,27 +38,27 @@ def check_values( check_argument("sc", c, restricted=True, allow_none=True) -def test_nested(): - file_path = os.path.dirname(os.path.abspath(__file__)) +def test_nested() -> None: + file_path = Path(__file__).resolve().parent / "example_config.json" # init 🐸 dataclass config = NestedConfig() # save to a json file - config.save_json(os.path.join(file_path, "example_config.json")) + config.save_json(file_path) # load a json file - config2 = NestedConfig(val_d=None, val_e=500, val_f=None, sc_list=None, sc=None, union_var=None) + config2 = NestedConfig(val_e=500) # update the config with the json file. - config2.load_json(os.path.join(file_path, "example_config.json")) + config2.load_json(file_path) # now they should be having the same values. assert config == config2 # pretty print the dataclass - print(config.pprint()) + config.pprint() # export values to a dict config_dict = config.to_dict() # crate a new config with different values than the defaults - config2 = NestedConfig(val_d=None, val_e=500, val_f=None, sc_list=None, sc=None, union_var=None) + config2 = NestedConfig(val_e=500) # update the config with the exported valuess from the previous config. config2.from_dict(config_dict) # now they should be having the same values. diff --git a/tests/test_parse_argparse.py b/tests/test_parse_argparse.py index d8d4756..ec7d137 100644 --- a/tests/test_parse_argparse.py +++ b/tests/test_parse_argparse.py @@ -1,36 +1,37 @@ from dataclasses import asdict, dataclass, field +from typing import Optional from coqpit.coqpit import Coqpit, check_argument @dataclass class SimplerConfig(Coqpit): - val_a: int = field(default=None, metadata={"help": "this is val_a"}) + val_a: Optional[int] = field(default=None, metadata={"help": "this is val_a"}) @dataclass class SimpleConfig(Coqpit): val_a: int = field(default=10, metadata={"help": "this is val_a of SimpleConfig"}) - val_b: int = field(default=None, metadata={"help": "this is val_b"}) + val_b: Optional[int] = field(default=None, metadata={"help": "this is val_b"}) val_c: str = "Coqpit is great!" - val_dict: dict = field(default_factory=lambda: {"val_a": 100, "val_b": 200, "val_c": 300}) + val_dict: dict[str, int] = field(default_factory=lambda: {"val_a": 100, "val_b": 200, "val_c": 300}) mylist_with_default: list[SimplerConfig] = field( default_factory=lambda: [SimplerConfig(val_a=100), SimplerConfig(val_a=999)], metadata={"help": "list of SimplerConfig"}, ) int_list: list[int] = field(default_factory=lambda: [1, 2, 3], metadata={"help": "int"}) str_list: list[str] = field(default_factory=lambda: ["veni", "vidi", "vici"], metadata={"help": "str"}) - empty_int_list: list[int] = field(default=None, metadata={"help": "int list without default value"}) - empty_str_list: list[str] = field(default=None, metadata={"help": "str list without default value"}) + empty_int_list: Optional[list[int]] = field(default=None, metadata={"help": "int list without default value"}) + empty_str_list: Optional[list[str]] = field(default=None, metadata={"help": "str list without default value"}) list_with_default_factory: list[str] = field( - default_factory=list, metadata={"help": "str list with default factory"} + default_factory=list, + metadata={"help": "str list with default factory"}, ) - # mylist_without_default: List[SimplerConfig] = field(default=None, metadata={'help': 'list of SimplerConfig'}) # NOT SUPPORTED YET! + # TODO: not supported yet + # mylist_without_default: List[SimplerConfig] = field(default=None) noqa: ERA001 - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) @@ -38,7 +39,7 @@ def check_values( check_argument("val_c", c, restricted=True) -def test_parse_argparse(): +def test_parse_argparse() -> None: args = [] args.extend(["--coqpit.val_a", "222"]) args.extend(["--coqpit.val_b", "999"]) @@ -54,7 +55,7 @@ def test_parse_argparse(): # initial config config = SimpleConfig() - print(config.pprint()) + config.pprint() # reference config that we like to match with the config above config_ref = SimpleConfig( @@ -71,7 +72,7 @@ def test_parse_argparse(): ) # create and init argparser with Coqpit - parser = config.init_argparse() + parser = config.init_argparse(instance=config) parser.print_help() # parse the argsparser @@ -81,7 +82,7 @@ def test_parse_argparse(): assert config == config_ref -def test_boolean_parse(): +def test_boolean_parse() -> None: @dataclass class Config(Coqpit): boolean_without_default: bool = field() @@ -113,8 +114,8 @@ class Config(Coqpit): try: config.parse_args(args) - assert False, "should not reach this" # noqa: B011 - except: # noqa: E722 + raise AssertionError # pragma: no cover, should not reach this + except SystemExit: pass @@ -123,12 +124,11 @@ class ArgparseWithRequiredField(Coqpit): val_a: int -def test_argparse_with_required_field(): +def test_argparse_with_required_field() -> None: args = ["--coqpit.val_a", "10"] try: - c = ArgparseWithRequiredField() # pylint: disable=no-value-for-parameter - c.parse_args(args) - assert False # noqa: B011 + c = ArgparseWithRequiredField() # type: ignore[call-arg] + raise AssertionError # pragma: no cover, should not reach this except TypeError: # __init__ should fail due to missing val_a pass @@ -137,27 +137,26 @@ def test_argparse_with_required_field(): assert c.val_a == 10 -def test_init_argparse_list_and_nested(): +def test_init_argparse_list_and_nested() -> None: @dataclass class SimplerConfig2(Coqpit): - val_a: int = field(default=None, metadata={"help": "this is val_a"}) + val_a: Optional[int] = field(default=None, metadata={"help": "this is val_a"}) @dataclass class SimpleConfig2(Coqpit): val_req: str # required field val_a: int = field(default=10, metadata={"help": "this is val_a of SimpleConfig2"}) - val_b: int = field(default=None, metadata={"help": "this is val_b"}) + val_b: Optional[int] = field(default=None, metadata={"help": "this is val_b"}) nested_config: SimplerConfig2 = field(default_factory=lambda: SimplerConfig2()) mylist_with_default: list[SimplerConfig2] = field( default_factory=lambda: [SimplerConfig2(val_a=100), SimplerConfig2(val_a=999)], metadata={"help": "list of SimplerConfig2"}, ) - # mylist_without_default: List[SimplerConfig2] = field(default=None, metadata={'help': 'list of SimplerConfig2'}) # NOT SUPPORTED YET! + # TODO: not supported yet + # mylist_without_default: List[SimplerConfig2] = field(default=None) # noqa: ERA001 - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) diff --git a/tests/test_parse_known_argparse.py b/tests/test_parse_known_argparse.py index 4855bda..a1b3218 100644 --- a/tests/test_parse_known_argparse.py +++ b/tests/test_parse_known_argparse.py @@ -1,26 +1,25 @@ from dataclasses import asdict, dataclass, field +from typing import Optional from coqpit.coqpit import Coqpit, check_argument @dataclass class SimplerConfig(Coqpit): - val_a: int = field(default=None, metadata={"help": "this is val_a"}) + val_a: Optional[int] = field(default=None, metadata={"help": "this is val_a"}) @dataclass class SimpleConfig(Coqpit): val_a: int = field(default=10, metadata={"help": "this is val_a of SimpleConfig"}) - val_b: int = field(default=None, metadata={"help": "this is val_b"}) + val_b: Optional[int] = field(default=None, metadata={"help": "this is val_b"}) val_c: str = "Coqpit is great!" mylist_with_default: list[SimplerConfig] = field( default_factory=lambda: [SimplerConfig(val_a=100), SimplerConfig(val_a=999)], metadata={"help": "list of SimplerConfig"}, ) - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) @@ -28,7 +27,7 @@ def check_values( check_argument("val_c", c, restricted=True) -def test_parse_argparse(): +def test_parse_argparse() -> None: unknown_args = ["--coqpit.arg_does_not_exist", "111"] args = [] args.extend(["--coqpit.val_a", "222"]) @@ -40,7 +39,7 @@ def test_parse_argparse(): # initial config config = SimpleConfig() - print(config.pprint()) + config.pprint() # reference config that we like to match with the config above config_ref = SimpleConfig( @@ -62,10 +61,10 @@ def test_parse_argparse(): assert unknown == unknown_args -def test_parse_edited_argparse(): - """calling `parse_known_argparse` after some modifications in the config values. - `parse_known_argparse` should keep the modified values if not defined in argv""" - +def test_parse_edited_argparse() -> None: + """Calling `parse_known_argparse` after some modifications in the config values. + `parse_known_argparse` should keep the modified values if not defined in argv + """ unknown_args = ["--coqpit.arg_does_not_exist", "111"] args = [] args.extend(["--coqpit.mylist_with_default.1.val_a", "111"]) @@ -77,7 +76,7 @@ def test_parse_edited_argparse(): config.val_b = 444 config.val_c = "this is different" config.mylist_with_default[0].val_a = 777 - print(config.pprint()) + config.pprint() # reference config that we like to match with the config above config_ref = SimpleConfig( diff --git a/tests/test_relaxed_parse_known_argparse.py b/tests/test_relaxed_parse_known_argparse.py index 93c7ef7..9d096b2 100644 --- a/tests/test_relaxed_parse_known_argparse.py +++ b/tests/test_relaxed_parse_known_argparse.py @@ -1,31 +1,24 @@ from dataclasses import asdict, dataclass, field -from typing import Union +from typing import Any, Optional, Union from coqpit.coqpit import Coqpit, check_argument -@dataclass -class SimplerConfig(Coqpit): - val_a: int = field(default=None, metadata={"help": "this is val_a"}) - - @dataclass class SimpleConfig(Coqpit): val_a: int = field(default=10, metadata={"help": "this is val_a of SimpleConfig"}) - val_b: int = field(default=None, metadata={"help": "this is val_b"}) - val_c: Union[int, str] = None - val_d: list[list] = None + val_b: Optional[int] = field(default=None, metadata={"help": "this is val_b"}) + val_c: Optional[Union[int, str]] = None + val_d: Optional[list[list[Any]]] = None - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) check_argument("val_b", c, restricted=True, min_val=128, max_val=4058, allow_none=True) -def test_parse_argparse(): +def test_parse_argparse() -> None: unknown_args = ["--coqpit.arg_does_not_exist", "111"] args = [] args.extend(["--coqpit.val_a", "222"]) @@ -34,7 +27,7 @@ def test_parse_argparse(): # initial config config = SimpleConfig() - print(config.pprint()) + config.pprint() # reference config that we like to match with the config above config_ref = SimpleConfig(val_a=222, val_b=999, val_c=None, val_d=None) diff --git a/tests/test_serialization.json b/tests/test_serialization.json index e90dd3a..81be55d 100644 --- a/tests/test_serialization.json +++ b/tests/test_serialization.json @@ -1,6 +1,7 @@ { "name": "Coqpit", "size": 3, + "path": "a/b", "people": [ { "name": "Eren", @@ -14,5 +15,10 @@ "name": "Ceren", "age": 15 } - ] + ], + "some_dict": { + "a": 1, + "b": 2, + "c": null + } } diff --git a/tests/test_serialization.py b/tests/test_serialization.py index ef96618..557d261 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,37 +1,42 @@ -import os from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional -from coqpit import Coqpit +from coqpit.coqpit import Coqpit, _deserialize_list @dataclass class Person(Coqpit): - name: str = None - age: int = None + name: Optional[str] = None + age: Optional[int] = None @dataclass class Group(Coqpit): - name: str = None - size: int = None - people: list[Person] = None + name: Optional[str] = None + size: Optional[int] = None + path: Optional[Path] = None + people: list[Person] = field(default_factory=list) + some_dict: dict[str, Optional[int]] = field(default_factory=dict) @dataclass class Reference(Coqpit): - name: str = "Coqpit" - size: int = 3 + name: Optional[str] = "Coqpit" + size: Optional[int] = 3 + path: Path = Path("a/b") people: list[Person] = field( default_factory=lambda: [ Person(name="Eren", age=11), Person(name="Geren", age=12), Person(name="Ceren", age=15), - ] + ], ) + some_dict: dict[str, Optional[int]] = field(default_factory=lambda: {"a": 1, "b": 2, "c": None}) -def test_serizalization(): - file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_serialization.json") +def test_serialization() -> None: + file_path = Path(__file__).resolve().parent / "test_serialization.json" ref_config = Reference() ref_config.save_json(file_path) @@ -50,3 +55,14 @@ def test_serizalization(): assert ref_config.people[0].age == new_config.people[0].age assert ref_config.people[1].age == new_config.people[1].age assert ref_config.people[2].age == new_config.people[2].age + assert ref_config.path == new_config.path + assert ref_config.some_dict["a"] == new_config.some_dict["a"] + assert ref_config.some_dict["b"] == new_config.some_dict["b"] + assert ref_config.some_dict["c"] == new_config.some_dict["c"] + + +def test_deserialize_list() -> None: + assert _deserialize_list([1, 2, 3], list) == [1, 2, 3] + assert _deserialize_list([1, 2, 3], list[int]) == [1, 2, 3] + assert _deserialize_list([1, 2, 3], list[float]) == [1.0, 2.0, 3.0] + assert _deserialize_list([1, 2, 3], list[str]) == ["1", "2", "3"] diff --git a/tests/test_simple_config.py b/tests/test_simple_config.py index bd57889..cc2d17a 100644 --- a/tests/test_simple_config.py +++ b/tests/test_simple_config.py @@ -1,6 +1,6 @@ -import os from dataclasses import asdict, dataclass, field -from typing import Union +from pathlib import Path +from typing import Any, Optional, Union from coqpit.coqpit import MISSING, Coqpit, check_argument @@ -8,7 +8,7 @@ @dataclass class SimpleConfig(Coqpit): val_a: int = 10 - val_b: int = None + val_b: Optional[int] = None val_d: float = 10.21 val_c: str = "Coqpit is great!" vol_e: bool = True @@ -16,16 +16,16 @@ class SimpleConfig(Coqpit): # raise an error when accessing the value if it is not changed. It is a way to define val_k: int = MISSING # optional field - val_dict: dict = field(default_factory=lambda: {"val_aa": 10, "val_ss": "This is in a dict."}) + val_dict: dict[str, Any] = field(default_factory=lambda: {"val_aa": 10, "val_ss": "This is in a dict."}) # list of list - val_listoflist: list[list] = field(default_factory=lambda: [[1, 2], [3, 4]]) + val_listoflist: list[list[int]] = field(default_factory=lambda: [[1, 2], [3, 4]]) val_listofunion: list[list[Union[str, int, bool]]] = field( - default_factory=lambda: [[1, 3], [1, "Hi!"], [True, False]] + default_factory=lambda: [[1, 3], [1, "Hi!"], [True, False]], ) def check_values( self, - ): # you can define explicit constraints on the fields using `check_argument()` + ) -> None: # you can define explicit constraints on the fields using `check_argument()` """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) @@ -33,10 +33,12 @@ def check_values( check_argument("val_c", c, restricted=True) -def test_simple_config(): - file_path = os.path.dirname(os.path.abspath(__file__)) +def test_simple_config() -> None: + file_path = Path(__file__).resolve().parent / "example_config.json" config = SimpleConfig() + assert config._is_initialized() + # try MISSING class argument try: _ = config.val_k @@ -44,12 +46,15 @@ def test_simple_config(): print(" val_k needs a different value before accessing it.") config.val_k = 1000 + assert "val_a" in config + assert config.has("val_a") + # try serialization and deserialization print(config.serialize()) print(config.to_json()) - config.save_json(os.path.join(file_path, "example_config.json")) - config.load_json(os.path.join(file_path, "example_config.json")) - print(config.pprint()) + config.save_json(file_path) + config.load_json(file_path) + config.pprint() # try `dict` interface print(*config)