From 1f3d114051d0b64184bcc6b469a12bae5e02667f Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 01:52:54 +0200 Subject: [PATCH 01/27] build: add typing_extensions --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9587bfb..2bba169 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,9 @@ classifiers = [ "Operating System :: MacOS", "Operating System :: Microsoft :: Windows", ] -dependencies = [] +dependencies = [ + "typing_extensions>=4.10", +] [tool.uv] dev-dependencies = [ From f901d57f7d3ffa1c918742087e945ccf2cf8c6a8 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 01:35:32 +0200 Subject: [PATCH 02/27] docs: fix typos, improve docstrings --- coqpit/coqpit.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 4c4330b..f1294ac 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -215,7 +215,7 @@ def _deserialize_union(x: Any, field_type: type) -> Any: field_type (Type): field type. Returns: - [Any]: desrialized value. + [Any]: deserialized value. """ for arg in field_type.__args__: # stop after first matching type in Union @@ -251,7 +251,7 @@ def _deserialize_primitive_types(x: Union[int, float, str, bool], field_type: ty def _deserialize(x: Any, field_type: Any) -> Any: - """Pick the right desrialization for the given object and the corresponding field type. + """Pick the right deserialization for the given object and the corresponding field type. Args: x (object): object to be deserialized. @@ -374,7 +374,7 @@ def serialize(self) -> dict: return o def deserialize(self, data: dict) -> "Serializable": - """Parse input dictionary and desrialize its fields to a dataclass. + """Parse input dictionary and deserialize its fields to a dataclass. Returns: self: deserialized `self`. @@ -406,7 +406,7 @@ def deserialize(self, data: dict) -> "Serializable": @classmethod def deserialize_immutable(cls, data: dict) -> "Serializable": - """Parse input dictionary and desrialize its fields to a dataclass. + """Parse input dictionary and deserialize its fields to a dataclass. Returns: Newly created deserialized object. @@ -854,7 +854,9 @@ def init_argparse( help_prefix="", relaxed_parser: bool = False, ) -> argparse.ArgumentParser: - """Pass Coqpit fields as argparse arguments. This allows to edit values through command-line. + """Create an argparse parser that can parse the Coqpit fields. + + This allows to edit values through command-line. Args: parser (argparse.ArgumentParser, optional): argparse.ArgumentParser instance. If unspecified a new one will be created. From 4e3ee4aed39bbb27d9d8934c6f58c7df50f62289 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 01:52:30 +0200 Subject: [PATCH 03/27] chore: improve types --- coqpit/coqpit.py | 73 ++++++++++++---------- tests/test_copying.py | 2 +- tests/test_init_from_dict.py | 7 ++- tests/test_merge_configs.py | 11 ++-- tests/test_nested_configs.py | 24 +++---- tests/test_parse_argparse.py | 33 +++++----- tests/test_parse_known_argparse.py | 17 +++-- tests/test_relaxed_parse_known_argparse.py | 21 +++---- tests/test_simple_config.py | 14 ++--- 9 files changed, 97 insertions(+), 105 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index f1294ac..5616860 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -1,14 +1,19 @@ +from __future__ import annotations + import argparse import contextlib import functools import json import operator -from collections.abc import MutableMapping +import os +import typing +from collections.abc import ItemsView, Iterable, Iterator, MutableMapping from dataclasses import MISSING as _MISSING from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace from pathlib import Path from pprint import pprint -from typing import Any, Generic, Optional, TypeVar, Union +from types import GenericAlias, NoneType, UnionType +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union, overload T = TypeVar("T") MISSING: Any = "???" @@ -158,7 +163,7 @@ def _serialize(x): return x -def _deserialize_dict(x: dict) -> dict: +def _deserialize_dict(x: dict[Any, Any]) -> dict[Any, Any]: """Deserialize dict. Args: @@ -167,7 +172,7 @@ def _deserialize_dict(x: dict) -> dict: Returns: Dict: deserialized dictionary. """ - out_dict = {} + out_dict: dict[Any, Any] = {} for k, v in x.items(): if v is None: # if {'key':None} out_dict[k] = None @@ -207,7 +212,7 @@ def _deserialize_list(x: list, field_type: type) -> list: return x -def _deserialize_union(x: Any, field_type: type) -> Any: +def _deserialize_union(x: Any, field_type: UnionType) -> Any: """Deserialize values for Union typed fields. Args: @@ -317,7 +322,7 @@ def _getitem(obj, attr): class Serializable: """Gives serialization ability to any inheriting dataclass.""" - def __post_init__(self): + def __post_init__(self) -> None: self._validate_contracts() for key, value in self.__dict__.items(): if value is no_default: @@ -349,7 +354,7 @@ def validate(self) -> None: msg = "could not be deserialized with same value" raise ValueError(msg) - def to_dict(self) -> dict: + def to_dict(self) -> dict[str, Any]: """Transform serializable object to dict.""" cls_fields = fields(self) o = {} @@ -357,7 +362,7 @@ def to_dict(self) -> dict: o[cls_field.name] = getattr(self, cls_field.name) return o - def serialize(self) -> dict: + def serialize(self) -> dict[str, Any]: """Serialize object to be json serializable representation.""" if not is_dataclass(self): msg = "need to be decorated as dataclass" @@ -445,7 +450,7 @@ def deserialize_immutable(cls, data: dict) -> "Serializable": # ---------------------------------------------------------------------------- # -def _get_help(field): +def _get_help(field: Field[Any]) -> str: try: field_help = field.metadata["help"] except KeyError: @@ -543,7 +548,7 @@ def _init_argparse( ) elif isinstance(field_type(), bool): - def parse_bool(x): + def parse_bool(x: str) -> bool: if x not in ("true", "false"): msg = f' [!] Value for boolean field must be either "true" or "false". Got "{x}".' raise ValueError(msg) @@ -585,7 +590,7 @@ class Coqpit(Serializable, MutableMapping): _initialized = False - def _is_initialized(self): + def _is_initialized(self) -> bool: """Check if Coqpit is initialized. Useful to prevent running some aux functions @@ -593,14 +598,14 @@ def _is_initialized(self): """ return "_initialized" in vars(self) and self._initialized - def __post_init__(self): + def __post_init__(self) -> None: self._initialized = True with contextlib.suppress(AttributeError): self.check_values() ## `dict` API functions - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(asdict(self)) def __len__(self) -> int: @@ -609,19 +614,19 @@ def __len__(self) -> int: def __setitem__(self, arg: str, value: Any) -> None: setattr(self, arg, value) - def __getitem__(self, arg: str): + def __getitem__(self, arg: str) -> Any: """Access class attributes with ``[arg]``.""" return self.__dict__[arg] def __delitem__(self, arg: str) -> None: delattr(self, arg) - def _keytransform(self, key): # pylint: disable=no-self-use + def _keytransform(self, key: str) -> str: # pylint: disable=no-self-use return key ## end `dict` API functions - def __getattribute__(self, arg: str): # pylint: disable=no-self-use + def __getattribute__(self, arg: str) -> Any: # pylint: disable=no-self-use """Check if the mandatory field is defined when accessing it.""" value = super().__getattribute__(arg) if isinstance(value, str) and value == "???": @@ -629,18 +634,18 @@ def __getattribute__(self, arg: str): # pylint: disable=no-self-use raise AttributeError(msg) return value - def __contains__(self, arg: str) -> bool: + def __contains__(self, arg: object) -> bool: return arg in self.to_dict() - def get(self, key: str, default: Any = None): + def get(self, key: str, default: Any = None) -> Any: if self.has(key): return asdict(self)[key] return default - def items(self): + def items(self) -> ItemsView[str, Any]: return asdict(self).items() - def merge(self, coqpits: Union["Coqpit", list["Coqpit"]]) -> None: + def merge(self, coqpits: Coqpit | list[Coqpit]) -> None: """Merge a coqpit instance or a list of coqpit instances to self. Note that it does not pass the fields and overrides attributes with @@ -651,7 +656,7 @@ def merge(self, coqpits: Union["Coqpit", list["Coqpit"]]) -> None: coqpits (Union[Coqpit, List[Coqpit]]): coqpit instance or list of instances to be merged. """ - def _merge(coqpit) -> None: + def _merge(coqpit: Coqpit) -> None: self.__dict__.update(coqpit.__dict__) self.__annotations__.update(coqpit.__annotations__) self.__dataclass_fields__.update(coqpit.__dataclass_fields__) @@ -689,7 +694,7 @@ def pprint(self) -> None: """Print Coqpit fields in a format.""" pprint(asdict(self)) - def to_dict(self) -> dict: + def to_dict(self) -> dict[str, Any]: # return asdict(self) return self.serialize() @@ -704,7 +709,7 @@ def to_json(self) -> str: """Returns a JSON string representation.""" return json.dumps(asdict(self), indent=4, default=_coqpit_json_default) - def save_json(self, file_name: str) -> None: + def save_json(self, file_name: str | os.PathLike[Any]) -> None: """Save Coqpit to a json file. Args: @@ -713,7 +718,7 @@ def save_json(self, file_name: str) -> None: with open(file_name, "w", encoding="utf8") as f: json.dump(asdict(self), f, indent=4) - def load_json(self, file_name: str) -> None: + def load_json(self, file_name: str | os.PathLike[Any]) -> None: """Load a json file and update matching config fields with type checking. Non-matching parameters in the json file are ignored. @@ -734,7 +739,7 @@ def load_json(self, file_name: str) -> None: @classmethod def init_from_argparse( cls, - args: Optional[Union[argparse.Namespace, list[str]]] = None, + args: argparse.Namespace | list[str] | None = None, arg_prefix: str = "coqpit", ) -> "Coqpit": """Create a new Coqpit instance from argparse input. @@ -784,7 +789,7 @@ def init_from_argparse( def parse_args( self, - args: Optional[Union[argparse.Namespace, list[str]]] = None, + args: argparse.Namespace | list[str] | None = None, arg_prefix: str = "coqpit", ) -> None: """Update config values from argparse arguments with some meta-programming ✨. @@ -819,7 +824,7 @@ def parse_args( def parse_known_args( self, - args: Optional[Union[argparse.Namespace, list[str]]] = None, + args: argparse.Namespace | list[str] | None = None, arg_prefix: str = "coqpit", relaxed_parser: bool = False, ) -> list[str]: @@ -892,15 +897,15 @@ def init_argparse( def check_argument( - name, - c, + name: str, + c: dict[str, Any], is_path: bool = False, - prerequest: Optional[str] = None, - enum_list: Optional[list] = None, - max_val: Optional[float] = None, - min_val: Optional[float] = None, + prerequest: str | None = None, + enum_list: list[Any] | None = None, + max_val: float | None = None, + min_val: float | None = None, restricted: bool = False, - alternative: Optional[str] = None, + alternative: str | None = None, allow_none: bool = True, ) -> None: """Simple type and value checking for Coqpit. diff --git a/tests/test_copying.py b/tests/test_copying.py index c3dbf1d..ccd2120 100644 --- a/tests/test_copying.py +++ b/tests/test_copying.py @@ -9,7 +9,7 @@ class SimpleConfig(Coqpit): val_a: int = 10 -def test_copying(): +def test_copying() -> None: config = SimpleConfig() config_new = config.copy() diff --git a/tests/test_init_from_dict.py b/tests/test_init_from_dict.py index f7b9ffc..a18fb67 100644 --- a/tests/test_init_from_dict.py +++ b/tests/test_init_from_dict.py @@ -1,12 +1,13 @@ from dataclasses import dataclass, field +from typing import Optional from coqpit import Coqpit @dataclass class Person(Coqpit): - name: str = None - age: int = None + name: Optional[str] = None + age: Optional[int] = None @dataclass @@ -28,7 +29,7 @@ class WithRequired(Coqpit): name: str -def test_new_from_dict(): +def test_new_from_dict() -> None: ref_config = Reference(name="Fancy", size=3**10, people=[Person(name="Anonymous", age=42)]) new_config = Reference.new_from_dict({"name": "Fancy", "size": 3**10, "people": [{"name": "Anonymous", "age": 42}]}) diff --git a/tests/test_merge_configs.py b/tests/test_merge_configs.py index 8f5a52d..f5e1ae4 100644 --- a/tests/test_merge_configs.py +++ b/tests/test_merge_configs.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional from coqpit.coqpit import Coqpit @@ -6,7 +7,7 @@ @dataclass class CoqpitA(Coqpit): val_a: int = 10 - val_b: int = None + val_b: Optional[int] = None val_c: str = "Coqpit is great!" val_same: float = 10.21 @@ -22,21 +23,21 @@ class CoqpitB(Coqpit): @dataclass class Reference(Coqpit): val_a: int = 10 - val_b: int = None + val_b: Optional[int] = None val_c: str = "Coqpit is great!" val_e: int = 257 val_f: float = -10.21 val_g: str = "Coqpit is really great!" - val_same: int = 10.21 # duplicate fields are override by the merged Coqpit class. + val_same: float = 10.21 # duplicate fields are override by the merged Coqpit class. -def test_config_merge(): +def test_config_merge() -> None: coqpit_ref = Reference() coqpita = CoqpitA() coqpitb = CoqpitB() coqpitb.merge(coqpita) print(coqpitb.val_a) - print(coqpitb.pprint()) + coqpitb.pprint() assert coqpit_ref.val_a == coqpitb.val_a assert coqpit_ref.val_b == coqpitb.val_b diff --git a/tests/test_nested_configs.py b/tests/test_nested_configs.py index ea654a0..8abb381 100644 --- a/tests/test_nested_configs.py +++ b/tests/test_nested_configs.py @@ -1,6 +1,6 @@ import os from dataclasses import asdict, dataclass, field -from typing import Union +from typing import Optional, Union from coqpit import Coqpit, check_argument @@ -8,12 +8,10 @@ @dataclass class SimpleConfig(Coqpit): val_a: int = 10 - val_b: int = None + val_b: Optional[int] = None val_c: str = "Coqpit is great!" - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) @@ -24,15 +22,13 @@ def check_values( @dataclass class NestedConfig(Coqpit): val_d: int = 10 - val_e: int = None + val_e: Optional[int] = None val_f: str = "Coqpit is great!" - sc_list: list[SimpleConfig] = None + sc_list: Optional[list[SimpleConfig]] = None sc: SimpleConfig = field(default_factory=lambda: SimpleConfig()) union_var: Union[list[SimpleConfig], SimpleConfig] = field(default_factory=lambda: [SimpleConfig(), SimpleConfig()]) - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_d", c, restricted=True, min_val=10, max_val=2056) @@ -42,7 +38,7 @@ def check_values( check_argument("sc", c, restricted=True, allow_none=True) -def test_nested(): +def test_nested() -> None: file_path = os.path.dirname(os.path.abspath(__file__)) # init 🐸 dataclass config = NestedConfig() @@ -50,19 +46,19 @@ def test_nested(): # save to a json file config.save_json(os.path.join(file_path, "example_config.json")) # load a json file - config2 = NestedConfig(val_d=None, val_e=500, val_f=None, sc_list=None, sc=None, union_var=None) + config2 = NestedConfig(val_e=500) # update the config with the json file. config2.load_json(os.path.join(file_path, "example_config.json")) # now they should be having the same values. assert config == config2 # pretty print the dataclass - print(config.pprint()) + config.pprint() # export values to a dict config_dict = config.to_dict() # crate a new config with different values than the defaults - config2 = NestedConfig(val_d=None, val_e=500, val_f=None, sc_list=None, sc=None, union_var=None) + config2 = NestedConfig(val_e=500) # update the config with the exported valuess from the previous config. config2.from_dict(config_dict) # now they should be having the same values. diff --git a/tests/test_parse_argparse.py b/tests/test_parse_argparse.py index d8d4756..850e42e 100644 --- a/tests/test_parse_argparse.py +++ b/tests/test_parse_argparse.py @@ -1,36 +1,35 @@ from dataclasses import asdict, dataclass, field +from typing import Optional from coqpit.coqpit import Coqpit, check_argument @dataclass class SimplerConfig(Coqpit): - val_a: int = field(default=None, metadata={"help": "this is val_a"}) + val_a: Optional[int] = field(default=None, metadata={"help": "this is val_a"}) @dataclass class SimpleConfig(Coqpit): val_a: int = field(default=10, metadata={"help": "this is val_a of SimpleConfig"}) - val_b: int = field(default=None, metadata={"help": "this is val_b"}) + val_b: Optional[int] = field(default=None, metadata={"help": "this is val_b"}) val_c: str = "Coqpit is great!" - val_dict: dict = field(default_factory=lambda: {"val_a": 100, "val_b": 200, "val_c": 300}) + val_dict: dict[str, int] = field(default_factory=lambda: {"val_a": 100, "val_b": 200, "val_c": 300}) mylist_with_default: list[SimplerConfig] = field( default_factory=lambda: [SimplerConfig(val_a=100), SimplerConfig(val_a=999)], metadata={"help": "list of SimplerConfig"}, ) int_list: list[int] = field(default_factory=lambda: [1, 2, 3], metadata={"help": "int"}) str_list: list[str] = field(default_factory=lambda: ["veni", "vidi", "vici"], metadata={"help": "str"}) - empty_int_list: list[int] = field(default=None, metadata={"help": "int list without default value"}) - empty_str_list: list[str] = field(default=None, metadata={"help": "str list without default value"}) + empty_int_list: Optional[list[int]] = field(default=None, metadata={"help": "int list without default value"}) + empty_str_list: Optional[list[str]] = field(default=None, metadata={"help": "str list without default value"}) list_with_default_factory: list[str] = field( default_factory=list, metadata={"help": "str list with default factory"} ) # mylist_without_default: List[SimplerConfig] = field(default=None, metadata={'help': 'list of SimplerConfig'}) # NOT SUPPORTED YET! - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) @@ -38,7 +37,7 @@ def check_values( check_argument("val_c", c, restricted=True) -def test_parse_argparse(): +def test_parse_argparse() -> None: args = [] args.extend(["--coqpit.val_a", "222"]) args.extend(["--coqpit.val_b", "999"]) @@ -54,7 +53,7 @@ def test_parse_argparse(): # initial config config = SimpleConfig() - print(config.pprint()) + config.pprint() # reference config that we like to match with the config above config_ref = SimpleConfig( @@ -81,7 +80,7 @@ def test_parse_argparse(): assert config == config_ref -def test_boolean_parse(): +def test_boolean_parse() -> None: @dataclass class Config(Coqpit): boolean_without_default: bool = field() @@ -123,7 +122,7 @@ class ArgparseWithRequiredField(Coqpit): val_a: int -def test_argparse_with_required_field(): +def test_argparse_with_required_field() -> None: args = ["--coqpit.val_a", "10"] try: c = ArgparseWithRequiredField() # pylint: disable=no-value-for-parameter @@ -137,16 +136,16 @@ def test_argparse_with_required_field(): assert c.val_a == 10 -def test_init_argparse_list_and_nested(): +def test_init_argparse_list_and_nested() -> None: @dataclass class SimplerConfig2(Coqpit): - val_a: int = field(default=None, metadata={"help": "this is val_a"}) + val_a: Optional[int] = field(default=None, metadata={"help": "this is val_a"}) @dataclass class SimpleConfig2(Coqpit): val_req: str # required field val_a: int = field(default=10, metadata={"help": "this is val_a of SimpleConfig2"}) - val_b: int = field(default=None, metadata={"help": "this is val_b"}) + val_b: Optional[int] = field(default=None, metadata={"help": "this is val_b"}) nested_config: SimplerConfig2 = field(default_factory=lambda: SimplerConfig2()) mylist_with_default: list[SimplerConfig2] = field( default_factory=lambda: [SimplerConfig2(val_a=100), SimplerConfig2(val_a=999)], @@ -155,9 +154,7 @@ class SimpleConfig2(Coqpit): # mylist_without_default: List[SimplerConfig2] = field(default=None, metadata={'help': 'list of SimplerConfig2'}) # NOT SUPPORTED YET! - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) diff --git a/tests/test_parse_known_argparse.py b/tests/test_parse_known_argparse.py index 4855bda..2208354 100644 --- a/tests/test_parse_known_argparse.py +++ b/tests/test_parse_known_argparse.py @@ -1,26 +1,25 @@ from dataclasses import asdict, dataclass, field +from typing import Optional from coqpit.coqpit import Coqpit, check_argument @dataclass class SimplerConfig(Coqpit): - val_a: int = field(default=None, metadata={"help": "this is val_a"}) + val_a: Optional[int] = field(default=None, metadata={"help": "this is val_a"}) @dataclass class SimpleConfig(Coqpit): val_a: int = field(default=10, metadata={"help": "this is val_a of SimpleConfig"}) - val_b: int = field(default=None, metadata={"help": "this is val_b"}) + val_b: Optional[int] = field(default=None, metadata={"help": "this is val_b"}) val_c: str = "Coqpit is great!" mylist_with_default: list[SimplerConfig] = field( default_factory=lambda: [SimplerConfig(val_a=100), SimplerConfig(val_a=999)], metadata={"help": "list of SimplerConfig"}, ) - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) @@ -28,7 +27,7 @@ def check_values( check_argument("val_c", c, restricted=True) -def test_parse_argparse(): +def test_parse_argparse() -> None: unknown_args = ["--coqpit.arg_does_not_exist", "111"] args = [] args.extend(["--coqpit.val_a", "222"]) @@ -40,7 +39,7 @@ def test_parse_argparse(): # initial config config = SimpleConfig() - print(config.pprint()) + config.pprint() # reference config that we like to match with the config above config_ref = SimpleConfig( @@ -62,7 +61,7 @@ def test_parse_argparse(): assert unknown == unknown_args -def test_parse_edited_argparse(): +def test_parse_edited_argparse() -> None: """calling `parse_known_argparse` after some modifications in the config values. `parse_known_argparse` should keep the modified values if not defined in argv""" @@ -77,7 +76,7 @@ def test_parse_edited_argparse(): config.val_b = 444 config.val_c = "this is different" config.mylist_with_default[0].val_a = 777 - print(config.pprint()) + config.pprint() # reference config that we like to match with the config above config_ref = SimpleConfig( diff --git a/tests/test_relaxed_parse_known_argparse.py b/tests/test_relaxed_parse_known_argparse.py index 93c7ef7..9d096b2 100644 --- a/tests/test_relaxed_parse_known_argparse.py +++ b/tests/test_relaxed_parse_known_argparse.py @@ -1,31 +1,24 @@ from dataclasses import asdict, dataclass, field -from typing import Union +from typing import Any, Optional, Union from coqpit.coqpit import Coqpit, check_argument -@dataclass -class SimplerConfig(Coqpit): - val_a: int = field(default=None, metadata={"help": "this is val_a"}) - - @dataclass class SimpleConfig(Coqpit): val_a: int = field(default=10, metadata={"help": "this is val_a of SimpleConfig"}) - val_b: int = field(default=None, metadata={"help": "this is val_b"}) - val_c: Union[int, str] = None - val_d: list[list] = None + val_b: Optional[int] = field(default=None, metadata={"help": "this is val_b"}) + val_c: Optional[Union[int, str]] = None + val_d: Optional[list[list[Any]]] = None - def check_values( - self, - ): + def check_values(self) -> None: """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) check_argument("val_b", c, restricted=True, min_val=128, max_val=4058, allow_none=True) -def test_parse_argparse(): +def test_parse_argparse() -> None: unknown_args = ["--coqpit.arg_does_not_exist", "111"] args = [] args.extend(["--coqpit.val_a", "222"]) @@ -34,7 +27,7 @@ def test_parse_argparse(): # initial config config = SimpleConfig() - print(config.pprint()) + config.pprint() # reference config that we like to match with the config above config_ref = SimpleConfig(val_a=222, val_b=999, val_c=None, val_d=None) diff --git a/tests/test_simple_config.py b/tests/test_simple_config.py index bd57889..cf45257 100644 --- a/tests/test_simple_config.py +++ b/tests/test_simple_config.py @@ -1,6 +1,6 @@ import os from dataclasses import asdict, dataclass, field -from typing import Union +from typing import Any, Optional, Union from coqpit.coqpit import MISSING, Coqpit, check_argument @@ -8,7 +8,7 @@ @dataclass class SimpleConfig(Coqpit): val_a: int = 10 - val_b: int = None + val_b: Optional[int] = None val_d: float = 10.21 val_c: str = "Coqpit is great!" vol_e: bool = True @@ -16,16 +16,16 @@ class SimpleConfig(Coqpit): # raise an error when accessing the value if it is not changed. It is a way to define val_k: int = MISSING # optional field - val_dict: dict = field(default_factory=lambda: {"val_aa": 10, "val_ss": "This is in a dict."}) + val_dict: dict[str, Any] = field(default_factory=lambda: {"val_aa": 10, "val_ss": "This is in a dict."}) # list of list - val_listoflist: list[list] = field(default_factory=lambda: [[1, 2], [3, 4]]) + val_listoflist: list[list[int]] = field(default_factory=lambda: [[1, 2], [3, 4]]) val_listofunion: list[list[Union[str, int, bool]]] = field( default_factory=lambda: [[1, 3], [1, "Hi!"], [True, False]] ) def check_values( self, - ): # you can define explicit constraints on the fields using `check_argument()` + ) -> None: # you can define explicit constraints on the fields using `check_argument()` """Check config fields""" c = asdict(self) check_argument("val_a", c, restricted=True, min_val=10, max_val=2056) @@ -33,7 +33,7 @@ def check_values( check_argument("val_c", c, restricted=True) -def test_simple_config(): +def test_simple_config() -> None: file_path = os.path.dirname(os.path.abspath(__file__)) config = SimpleConfig() @@ -49,7 +49,7 @@ def test_simple_config(): print(config.to_json()) config.save_json(os.path.join(file_path, "example_config.json")) config.load_json(os.path.join(file_path, "example_config.json")) - print(config.pprint()) + config.pprint() # try `dict` interface print(*config) From fa31c70ca0027cf26c68be085531598367fb737c Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:05:12 +0200 Subject: [PATCH 04/27] refactor: add/improve type aliases --- coqpit/coqpit.py | 44 +++++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 5616860..e0bc8db 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -15,31 +15,37 @@ from types import GenericAlias, NoneType, UnionType from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union, overload -T = TypeVar("T") +from typing_extensions import Self, TypeAlias, TypeGuard, TypeIs + +if TYPE_CHECKING: # pragma: no cover + from dataclasses import _MISSING_TYPE + + from _typeshed import SupportsKeysAndGetItem + +_T = TypeVar("_T") MISSING: Any = "???" -class _NoDefault(Generic[T]): +class _NoDefault(Generic[_T]): pass -NoDefaultVar = Union[_NoDefault[T], T] -no_default: NoDefaultVar = _NoDefault() +NoDefaultVar: TypeAlias = Union[_NoDefault[_T], _T] +no_default: NoDefaultVar[Any] = _NoDefault() +FieldType: TypeAlias = str | type | UnionType -def is_primitive_type(arg_type: Any) -> bool: + +def _is_primitive_type(field_type: FieldType) -> TypeGuard[type]: """Check if the input type is one of `int, float, str, bool`. Args: - arg_type (typing.Any): input type to check. + field_type: input type to check. Returns: bool: True if input type is one of `int, float, str, bool`. """ - try: - return isinstance(arg_type(), (int, float, str, bool)) - except (AttributeError, TypeError): - return False + return field_type is int or field_type is float or field_type is str or field_type is bool def is_list(arg_type: Any) -> bool: @@ -112,7 +118,7 @@ def _coqpit_json_default(obj: Any) -> Any: raise TypeError(msg) -def _default_value(x: Field): +def _default_value(x: Field[_T]) -> _T | Literal[_MISSING_TYPE.MISSING]: """Return the default value of the input Field. Args: @@ -121,9 +127,9 @@ def _default_value(x: Field): Returns: object: default value of the input Field. """ - if x.default not in (MISSING, _MISSING): + if x.default != MISSING and x.default is not _MISSING: return x.default - if x.default_factory not in (MISSING, _MISSING): + if x.default_factory != MISSING and x.default_factory is not _MISSING: return x.default_factory() return x.default @@ -181,7 +187,7 @@ def _deserialize_dict(x: dict[Any, Any]) -> dict[Any, Any]: return out_dict -def _deserialize_list(x: list, field_type: type) -> list: +def _deserialize_list(x: list[_T], field_type: FieldType) -> list[_T]: """Deserialize values for List typed fields. Args: @@ -275,7 +281,7 @@ def _deserialize(x: Any, field_type: Any) -> Any: return _deserialize_union(x, field_type) if issubclass(field_type, Serializable): return field_type.deserialize_immutable(x) - if is_primitive_type(field_type): + if _is_primitive_type(field_type): return _deserialize_primitive_types(x, field_type) msg = f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type." raise ValueError(msg) @@ -479,7 +485,7 @@ def _init_argparse( has_default = True default = field_default_factory() - if not has_default and not is_primitive_type(field_type) and not is_list(field_type): + if not has_default and not _is_primitive_type(field_type) and not is_list(field_type): # aggregate types (fields with a Coqpit subclass as type) are not supported without None return parser arg_prefix = field_name if arg_prefix == "" else f"{arg_prefix}.{field_name}" @@ -508,7 +514,7 @@ def _init_argparse( return parser if not has_default or field_default_factory is list: - if not is_primitive_type(list_field_type) and not relaxed_parser: + if not _is_primitive_type(list_field_type) and not relaxed_parser: msg = " [!] Empty list with non primitive inner type is currently not supported." raise NotImplementedError(msg) @@ -561,7 +567,7 @@ def parse_bool(x: str) -> bool: help=f"Coqpit Field: {help_prefix}", metavar="true/false", ) - elif is_primitive_type(field_type): + elif _is_primitive_type(field_type): parser.add_argument( f"--{arg_prefix}", default=field_default, @@ -774,7 +780,7 @@ def init_from_argparse( has_default = True default = field_default_factory() - if has_default and (not is_primitive_type(field.type) or is_list(field.type)): + if has_default and (not _is_primitive_type(field.type) or is_list(field.type)): args_with_lists_processed[field.name] = default args_dict = vars(args) From 65245118d1b911ca1a677d39ef676fe4c994278f Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:07:34 +0200 Subject: [PATCH 05/27] chore: use Self return type hint --- coqpit/coqpit.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index e0bc8db..d9a326b 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -384,7 +384,7 @@ def serialize(self) -> dict[str, Any]: o[field.name] = value return o - def deserialize(self, data: dict) -> "Serializable": + def deserialize(self, data: dict[str, Any]) -> Self: """Parse input dictionary and deserialize its fields to a dataclass. Returns: @@ -416,7 +416,7 @@ def deserialize(self, data: dict) -> "Serializable": return self @classmethod - def deserialize_immutable(cls, data: dict) -> "Serializable": + def deserialize_immutable(cls, data: dict[str, Any]) -> Self: """Parse input dictionary and deserialize its fields to a dataclass. Returns: @@ -679,7 +679,7 @@ def check_values(self) -> None: def has(self, arg: str) -> bool: return arg in vars(self) - def copy(self): + def copy(self) -> Self: return replace(self) def update(self, new: dict, allow_new: bool = False) -> None: @@ -708,7 +708,7 @@ def from_dict(self, data: dict) -> None: self = self.deserialize(data) # pylint: disable=self-cls-assignment @classmethod - def new_from_dict(cls: Serializable, data: dict) -> "Coqpit": + def new_from_dict(cls, data: dict[str, Any]) -> Self: return cls.deserialize_immutable(data) def to_json(self) -> str: @@ -747,7 +747,7 @@ def init_from_argparse( cls, args: argparse.Namespace | list[str] | None = None, arg_prefix: str = "coqpit", - ) -> "Coqpit": + ) -> Self: """Create a new Coqpit instance from argparse input. Args: From 2ce69f5fd3fbf71c550633f2642d47f8e00f728c Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:11:28 +0200 Subject: [PATCH 06/27] test: correctly handle exceptions in tests --- tests/test_parse_argparse.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_parse_argparse.py b/tests/test_parse_argparse.py index 850e42e..6637e92 100644 --- a/tests/test_parse_argparse.py +++ b/tests/test_parse_argparse.py @@ -112,8 +112,8 @@ class Config(Coqpit): try: config.parse_args(args) - assert False, "should not reach this" # noqa: B011 - except: # noqa: E722 + raise AssertionError("should not reach this") # pragma: no cover + except SystemExit: pass @@ -125,9 +125,8 @@ class ArgparseWithRequiredField(Coqpit): def test_argparse_with_required_field() -> None: args = ["--coqpit.val_a", "10"] try: - c = ArgparseWithRequiredField() # pylint: disable=no-value-for-parameter - c.parse_args(args) - assert False # noqa: B011 + c = ArgparseWithRequiredField() # type: ignore[call-arg] + raise AssertionError("should not reach this") # pragma: no cover except TypeError: # __init__ should fail due to missing val_a pass From fd01bf9402fee38ee51ad8ce9d7cd76eb8214b45 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:40:17 +0200 Subject: [PATCH 07/27] fix: correctly serialize when dumping to json Remove now unused _coqpit_json_default(). --- coqpit/coqpit.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index d9a326b..446b0fc 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -111,13 +111,6 @@ def safe_issubclass(cls, classinfo) -> bool: return r -def _coqpit_json_default(obj: Any) -> Any: - if isinstance(obj, Path): - return str(obj) - msg = f"Can't encode object of type {type(obj).__name__}" - raise TypeError(msg) - - def _default_value(x: Field[_T]) -> _T | Literal[_MISSING_TYPE.MISSING]: """Return the default value of the input Field. @@ -713,7 +706,7 @@ def new_from_dict(cls, data: dict[str, Any]) -> Self: def to_json(self) -> str: """Returns a JSON string representation.""" - return json.dumps(asdict(self), indent=4, default=_coqpit_json_default) + return json.dumps(self.to_dict(), indent=4) def save_json(self, file_name: str | os.PathLike[Any]) -> None: """Save Coqpit to a json file. @@ -722,7 +715,7 @@ def save_json(self, file_name: str | os.PathLike[Any]) -> None: file_name (str): path to the output json file. """ with open(file_name, "w", encoding="utf8") as f: - json.dump(asdict(self), f, indent=4) + json.dump(self.to_dict(), f, indent=4) def load_json(self, file_name: str | os.PathLike[Any]) -> None: """Load a json file and update matching config fields with type checking. From 37299fb9b4c0f567d9e923bde9d25dd5ee946cdc Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:36:49 +0200 Subject: [PATCH 08/27] feat: add _drop_none_type() to better handle optional fields --- coqpit/coqpit.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 446b0fc..f170dd9 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -140,7 +140,24 @@ def _is_optional_field(field) -> bool: return type(None) in field.type.__args__ -def _serialize(x): +def _drop_none_type(field_type: FieldType) -> FieldType: + """Drop None from Union-like types. + + >>> _drop_none_type(str | int | None) + str | int + """ + if not _is_union(field_type): + return field_type + origin = typing.get_origin(field_type) + args = list(typing.get_args(field_type)) + if NoneType in args: + args.remove(NoneType) + if len(args) == 1: + return typing.cast(type, args[0]) + return typing.cast(UnionType, GenericAlias(origin, args)) + + +def _serialize(x: Any) -> Any: """Pick the right serialization for the datatype of the given input. Args: From 5fee09f6e040f4e7875b1a8cfd69c6e5febbc4f7 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:33:54 +0200 Subject: [PATCH 09/27] fix: use get_args and get_origin, update helpers - instead of undocumented __args__ and __origin__ attributes. - remove now unused safe_issubclass() - make helpers private --- coqpit/coqpit.py | 150 +++++++++++++++++++++++------------------------ 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index f170dd9..e6ea39b 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -48,67 +48,61 @@ def _is_primitive_type(field_type: FieldType) -> TypeGuard[type]: return field_type is int or field_type is float or field_type is str or field_type is bool -def is_list(arg_type: Any) -> bool: +def _is_list(field_type: FieldType) -> TypeGuard[type]: """Check if the input type is `list`. Args: - arg_type (typing.Any): input type. + field_type: input type. Returns: bool: True if input type is `list` """ - try: - return arg_type is list or arg_type is list or arg_type.__origin__ is list or arg_type.__origin__ is list - except AttributeError: - return False + return field_type is list or typing.get_origin(field_type) is list -def is_dict(arg_type: Any) -> bool: +def _is_dict(field_type: FieldType) -> TypeGuard[type]: """Check if the input type is `dict`. Args: - arg_type (typing.Any): input type. + field_type: input type. Returns: bool: True if input type is `dict` """ - try: - return arg_type is dict or arg_type is dict or arg_type.__origin__ is dict - except AttributeError: - return False + return field_type is dict or typing.get_origin(field_type) is dict -def is_union(arg_type: Any) -> bool: +def _is_union(field_type: FieldType) -> TypeIs[UnionType]: """Check if the input type is `Union`. Args: - arg_type (typing.Any): input type. + field_type: input type. Returns: bool: True if input type is `Union` """ - try: - return safe_issubclass(arg_type.__origin__, Union) - except AttributeError: - return False + origin = typing.get_origin(field_type) + return origin is Union or origin is UnionType -def safe_issubclass(cls, classinfo) -> bool: - """Check if the input type is a subclass of the given class. +def _is_union_and_not_simple_optional(field_type: FieldType) -> TypeGuard[UnionType]: + """Check if the input type is `Union`. + + Note: `int | None` would be of type Union, but here we don't consider such + cases where the only other accepted type is None. Args: - cls (type): input type. - classinfo (type): parent class. + field_type: input type. Returns: - bool: True if the input type is a subclass of the given class + bool: True if input type is `Union` and not optional type like `int | None` """ - try: - r = issubclass(cls, classinfo) - except Exception: # pylint: disable=broad-except - return cls is classinfo - else: - return r + args = typing.get_args(field_type) + is_python_union = _is_union(field_type) + if is_python_union and len(args) == 2 and NoneType in args: + # This is an Optional type like `int | None` + return False + return is_python_union def _default_value(x: Field[_T]) -> _T | Literal[_MISSING_TYPE.MISSING]: @@ -127,17 +121,16 @@ def _default_value(x: Field[_T]) -> _T | Literal[_MISSING_TYPE.MISSING]: return x.default -def _is_optional_field(field) -> bool: - """Check if the input field is optional. +def _is_optional_field(field_type: FieldType) -> TypeGuard[UnionType]: + """Check if the input field type is optional. Args: - field (Field): input Field to check. + field_type: input Field's type to check. Returns: - bool: True if the input field is optional. + bool: True if the input field type is optional. """ - # return isinstance(field.type, _GenericAlias) and type(None) in getattr(field.type, "__args__") - return type(None) in field.type.__args__ + return NoneType in typing.get_args(field_type) def _drop_none_type(field_type: FieldType) -> FieldType: @@ -210,22 +203,17 @@ def _deserialize_list(x: list[_T], field_type: FieldType) -> list[_T]: Returns: [List]: deserialized list. """ - field_args = None - if hasattr(field_type, "__args__") and field_type.__args__: - field_args = field_type.__args__ - elif hasattr(field_type, "__parameters__") and field_type.__parameters__: - # bandaid for python 3.6 - field_args = field_type.__parameters__ - if field_args: - if len(field_args) > 1: - msg = " [!] Coqpit does not support multi-type hinted 'List'" - raise ValueError(msg) - field_arg = field_args[0] - # if field type is TypeVar set the current type by the value's type. - if isinstance(field_arg, TypeVar): - field_arg = type(x) - return [_deserialize(xi, field_arg) for xi in x] - return x + field_args = typing.get_args(field_type) + if len(field_args) == 0: + return x + elif len(field_args) > 1: + msg = "Coqpit does not support multi-type hinted 'List'" + raise ValueError(msg) + field_arg = field_args[0] + # if field type is TypeVar set the current type by the value's type. + if isinstance(field_arg, TypeVar): + field_arg = type(x) + return [_deserialize(xi, field_arg) for xi in x] def _deserialize_union(x: Any, field_type: UnionType) -> Any: @@ -238,7 +226,7 @@ def _deserialize_union(x: Any, field_type: UnionType) -> Any: Returns: [Any]: deserialized value. """ - for arg in field_type.__args__: + for arg in typing.get_args(field_type): # stop after first matching type in Union try: x = _deserialize(x, arg) @@ -283,15 +271,20 @@ def _deserialize(x: Any, field_type: Any) -> Any: """ # pylint: disable=too-many-return-statements - if is_dict(field_type): + assert not isinstance(field_type, str) + if _is_dict(_drop_none_type(field_type)): return _deserialize_dict(x) - if is_list(field_type): - return _deserialize_list(x, field_type) - if is_union(field_type): + if _is_list(_drop_none_type(field_type)): + return _deserialize_list(x, _drop_none_type(field_type)) + if _is_union_and_not_simple_optional(field_type): return _deserialize_union(x, field_type) - if issubclass(field_type, Serializable): + if not _is_union(field_type) and isinstance(field_type, type) and issubclass(field_type, Serializable): return field_type.deserialize_immutable(x) - if _is_primitive_type(field_type): + if _drop_none_type(field_type) is Path: + if x is None and _is_optional_field(field_type): + return None + return Path(x) + if _is_primitive_type(_drop_none_type(field_type)): return _deserialize_primitive_types(x, field_type) msg = f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type." raise ValueError(msg) @@ -351,7 +344,7 @@ def _validate_contracts(self) -> None: for field in dataclass_fields: value = getattr(self, field.name) - if value is None and not _is_optional_field(field): + if value is None and not _is_optional_field(field.type): msg = f"{field.name} is not optional" raise TypeError(msg) @@ -495,12 +488,17 @@ def _init_argparse( has_default = True default = field_default_factory() - if not has_default and not _is_primitive_type(field_type) and not is_list(field_type): - # aggregate types (fields with a Coqpit subclass as type) are not supported without None + if ( + not has_default + and not _is_primitive_type(_drop_none_type(field_type)) + and not _is_list(_drop_none_type(field_type)) + ): + # aggregate types (fields with a Coqpit subclass as type) are not + # supported without None return parser arg_prefix = field_name if arg_prefix == "" else f"{arg_prefix}.{field_name}" help_prefix = field_help if help_prefix == "" else f"{help_prefix} - {field_help}" - if is_dict(field_type): # pylint: disable=no-else-raise + if _is_dict(field_type): # pylint: disable=no-else-raise # NOTE: accept any string in json format as input to dict field. parser.add_argument( f"--{arg_prefix}", @@ -508,19 +506,19 @@ def _init_argparse( default=json.dumps(field_default) if field_default else None, type=json.loads, ) - elif is_list(field_type): + elif _is_list(_drop_none_type(field_type)): # TODO: We need a more clear help msg for lists. - if hasattr(field_type, "__args__"): # if the list is hinted - if len(field_type.__args__) > 1 and not relaxed_parser: - msg = " [!] Coqpit does not support multi-type hinted 'List'" - raise ValueError(msg) - list_field_type = field_type.__args__[0] - else: - msg = " [!] Coqpit does not support un-hinted 'List'" + field_args = typing.get_args(_drop_none_type(field_type)) + if len(field_args) > 1 and not relaxed_parser: + msg = "Coqpit does not support multi-type hinted 'List'" + raise ValueError(msg) + elif len(field_args) == 0: + msg = "Coqpit does not support un-hinted 'List'" raise ValueError(msg) + list_field_type = field_args[0] # TODO: handle list of lists - if is_list(list_field_type) and relaxed_parser: + if _is_list(list_field_type) and relaxed_parser: return parser if not has_default or field_default_factory is list: @@ -550,7 +548,7 @@ def _init_argparse( arg_prefix=f"{arg_prefix}", relaxed_parser=relaxed_parser, ) - elif is_union(field_type): + elif _is_union_and_not_simple_optional(field_type): # TODO: currently I don't know how to handle Union type on argparse if not relaxed_parser: msg = " [!] Parsing `Union` field from argparse is not yet implemented. Please create an issue." @@ -577,11 +575,13 @@ def parse_bool(x: str) -> bool: help=f"Coqpit Field: {help_prefix}", metavar="true/false", ) - elif _is_primitive_type(field_type): + elif _is_primitive_type(_drop_none_type(field_type)): + base_type = _drop_none_type(field_type) + assert not _is_union(base_type) parser.add_argument( f"--{arg_prefix}", default=field_default, - type=field_type, + type=base_type, help=f"Coqpit Field: {help_prefix}", ) elif not relaxed_parser: @@ -790,7 +790,7 @@ def init_from_argparse( has_default = True default = field_default_factory() - if has_default and (not _is_primitive_type(field.type) or is_list(field.type)): + if has_default and (not _is_primitive_type(field.type) or _is_list(field.type)): args_with_lists_processed[field.name] = default args_dict = vars(args) From c98d2dea436b11f373dd9aa71989705eeb8d5329 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:43:00 +0200 Subject: [PATCH 10/27] chore: do not re-export dataclass --- coqpit/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/coqpit/__init__.py b/coqpit/__init__.py index aa97fe8..0635173 100644 --- a/coqpit/__init__.py +++ b/coqpit/__init__.py @@ -1,8 +1,7 @@ import importlib.metadata -from dataclasses import dataclass from coqpit.coqpit import MISSING, Coqpit, check_argument -__all__ = ["dataclass", "MISSING", "Coqpit", "check_argument"] +__all__ = ["MISSING", "Coqpit", "check_argument"] __version__ = importlib.metadata.version("coqpit") From bcb0f05c97d0dc6f7ea8acff3b84f03bac2f846b Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:43:57 +0200 Subject: [PATCH 11/27] test: add more serialization test cases --- tests/test_serialization.json | 8 +++++++- tests/test_serialization.py | 29 +++++++++++++++++++---------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/tests/test_serialization.json b/tests/test_serialization.json index e90dd3a..81be55d 100644 --- a/tests/test_serialization.json +++ b/tests/test_serialization.json @@ -1,6 +1,7 @@ { "name": "Coqpit", "size": 3, + "path": "a/b", "people": [ { "name": "Eren", @@ -14,5 +15,10 @@ "name": "Ceren", "age": 15 } - ] + ], + "some_dict": { + "a": 1, + "b": 2, + "c": null + } } diff --git a/tests/test_serialization.py b/tests/test_serialization.py index ef96618..f6bef15 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,26 +1,30 @@ -import os from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional from coqpit import Coqpit @dataclass class Person(Coqpit): - name: str = None - age: int = None + name: Optional[str] = None + age: Optional[int] = None @dataclass class Group(Coqpit): - name: str = None - size: int = None - people: list[Person] = None + name: Optional[str] = None + size: Optional[int] = None + path: Optional[Path] = None + people: list[Person] = field(default_factory=list) + some_dict: dict[str, Optional[int]] = field(default_factory=dict) @dataclass class Reference(Coqpit): - name: str = "Coqpit" - size: int = 3 + name: Optional[str] = "Coqpit" + size: Optional[int] = 3 + path: Path = Path("a/b") people: list[Person] = field( default_factory=lambda: [ Person(name="Eren", age=11), @@ -28,10 +32,11 @@ class Reference(Coqpit): Person(name="Ceren", age=15), ] ) + some_dict: dict[str, Optional[int]] = field(default_factory=lambda: {"a": 1, "b": 2, "c": None}) -def test_serizalization(): - file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_serialization.json") +def test_serialization() -> None: + file_path = Path(__file__).resolve().parent / "test_serialization.json" ref_config = Reference() ref_config.save_json(file_path) @@ -50,3 +55,7 @@ def test_serizalization(): assert ref_config.people[0].age == new_config.people[0].age assert ref_config.people[1].age == new_config.people[1].age assert ref_config.people[2].age == new_config.people[2].age + assert ref_config.path == new_config.path + assert ref_config.some_dict["a"] == new_config.some_dict["a"] + assert ref_config.some_dict["b"] == new_config.some_dict["b"] + assert ref_config.some_dict["c"] == new_config.some_dict["c"] From 0894100ea390e262ef4671d66d0479ad29557d8e Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:45:36 +0200 Subject: [PATCH 12/27] fix: do not assign to self --- coqpit/coqpit.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index e6ea39b..57e9ef4 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -714,8 +714,8 @@ def to_dict(self) -> dict[str, Any]: # return asdict(self) return self.serialize() - def from_dict(self, data: dict) -> None: - self = self.deserialize(data) # pylint: disable=self-cls-assignment + def from_dict(self, data: dict[str, Any]) -> None: + self.deserialize(data) @classmethod def new_from_dict(cls, data: dict[str, Any]) -> Self: @@ -748,8 +748,7 @@ def load_json(self, file_name: str | os.PathLike[Any]) -> None: with open(file_name, encoding="utf8") as f: input_str = f.read() dump_dict = json.loads(input_str) - # TODO: this looks stupid 💆 - self = self.deserialize(dump_dict) # pylint: disable=self-cls-assignment + self.deserialize(dump_dict) self.check_values() @classmethod From 013652727a7f327d50996dfd8fbd1a8f5da531f0 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:50:17 +0200 Subject: [PATCH 13/27] refactor: simplify recursive getters and setters --- coqpit/coqpit.py | 65 +++++++++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 57e9ef4..aab7df9 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -2,7 +2,6 @@ import argparse import contextlib -import functools import json import operator import os @@ -290,41 +289,49 @@ def _deserialize(x: Any, field_type: Any) -> Any: raise ValueError(msg) -# Recursive setattr (supports dotted attr names) -def rsetattr(obj, attr, val): - def _setitem(obj, attr, val): - return operator.setitem(obj, int(attr), val) +CoqpitType: TypeAlias = MutableMapping[str, "CoqpitNestedValue"] +CoqpitNestedValue: TypeAlias = Union["CoqpitValue", CoqpitType] +CoqpitValue: TypeAlias = str | int | float | bool | None - pre, _, post = attr.rpartition(".") - setfunc = _setitem if post.isnumeric() else setattr - - return setfunc(rgetattr(obj, pre) if pre else obj, post, val) +# TODO: It should be possible to get rid of the next 3 `type: ignore`. At +# nested levels, the key can be `str | int` as well, not just `str`. +def _rsetattr(obj: CoqpitType, keys: str, val: CoqpitValue) -> None: + """Recursive setattr (supports dotted key names)""" + pre, _, post = keys.rpartition(".") + target = _rgetattr(obj, pre) if pre else obj + if post.isnumeric(): + operator.setitem(target, int(post), val) # type: ignore[misc] + else: + setattr(target, post, val) -# Recursive getattr (supports dotted attr names) -def rgetattr(obj, attr, *args): - def _getitem(obj, attr): - return operator.getitem(obj, int(attr), *args) - def _getattr(obj, attr): - getfunc = _getitem if attr.isnumeric() else getattr - return getfunc(obj, attr, *args) +def _rgetattr(obj: CoqpitType, keys: str) -> CoqpitType: + """Recursive getattr (supports dotted key names).""" + v = obj + for k in keys.split("."): + v = operator.getitem(v, int(k)) if k.isnumeric() else getattr(v, k) # type: ignore[arg-type] + return v - return functools.reduce(_getattr, [obj, *attr.split(".")]) +def _rsetitem(obj: CoqpitType, keys: str, value: CoqpitValue) -> None: + """Recursive setitem (supports dotted key names). -# Recursive setitem (supports dotted attr names) -def rsetitem(obj, attr, val): - pre, _, post = attr.rpartition(".") - return operator.setitem(rgetitem(obj, pre) if pre else obj, post, val) + _rsetitem(a, "b.c", 1) => a["b"]["c"] = 1 + """ + pre, _, post = keys.rpartition(".") + operator.setitem(_rgetitem(obj, pre) if pre else obj, post, value) -# Recursive getitem (supports dotted attr names) -def rgetitem(obj, attr, *args): - def _getitem(obj, attr): - return operator.getitem(obj, int(attr) if attr.isnumeric() else attr, *args) +def _rgetitem(obj: CoqpitType, keys: str) -> CoqpitType: + """Recursive getitem (supports dotted key names). - return functools.reduce(_getitem, [obj, *attr.split(".")]) + _rgetitem(a, "b.c") => a["b"]["c"] + """ + v = obj + for k in keys.split("."): + v = operator.getitem(v, int(k) if k.isnumeric() else k) # type: ignore[arg-type] + return v @dataclass @@ -798,7 +805,7 @@ def init_from_argparse( if k.startswith(f"{arg_prefix}."): k = k[len(f"{arg_prefix}.") :] - rsetitem(args_with_lists_processed, k, v) + _rsetitem(args_with_lists_processed, k, v) return cls(**args_with_lists_processed) @@ -828,12 +835,12 @@ def parse_args( if k.startswith(f"{arg_prefix}."): k = k[len(f"{arg_prefix}.") :] try: - rgetattr(self, k) + _rgetattr(self, k) except (TypeError, AttributeError) as e: msg = f" [!] '{k}' not exist to override from argparse." raise Exception(msg) from e - rsetattr(self, k, v) + _rsetattr(self, k, v) self.check_values() From 99de702547cf1250ae73c29b17ac16f80fa19efb Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:48:29 +0200 Subject: [PATCH 14/27] refactor: simplify code and add type hints --- coqpit/coqpit.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index aab7df9..96f85bc 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -235,7 +235,9 @@ def _deserialize_union(x: Any, field_type: UnionType) -> Any: return x -def _deserialize_primitive_types(x: Union[int, float, str, bool], field_type: type) -> Union[int, float, str, bool]: +def _deserialize_primitive_types( + x: int | float | str | bool | None, field_type: FieldType +) -> int | float | str | bool | None: """Deserialize python primitive types (float, int, str, bool). It handles `inf` values exclusively and keeps them float against int fields since int does not support inf values. @@ -247,18 +249,22 @@ def _deserialize_primitive_types(x: Union[int, float, str, bool], field_type: ty Returns: Union[int, float, str, bool]: deserialized value. """ + if x is None: + return None if isinstance(x, (str, bool)): return x if isinstance(x, (int, float)): + base_type = _drop_none_type(field_type) + if base_type is not float and base_type is not int and base_type is not str and base_type is not bool: + raise TypeError + base_type = typing.cast(type[int | float | str | bool], base_type) if x == float("inf") or x == float("-inf"): # if value type is inf return regardless. return x - return field_type(x) - # TODO: Raise an error when x does not match the types. - return None + return base_type(x) -def _deserialize(x: Any, field_type: Any) -> Any: +def _deserialize(x: Any, field_type: FieldType) -> Any: """Pick the right deserialization for the given object and the corresponding field type. Args: @@ -468,10 +474,9 @@ def deserialize_immutable(cls, data: dict[str, Any]) -> Self: def _get_help(field: Field[Any]) -> str: try: - field_help = field.metadata["help"] + return str(field.metadata["help"]) except KeyError: - field_help = "" - return field_help + return "" def _init_argparse( @@ -485,13 +490,15 @@ def _init_argparse( help_prefix="", *, relaxed_parser: bool = False, -): - has_default = False +) -> argparse.ArgumentParser: + """Add a new argument to the argparse parser, matching the given field.""" + assert not isinstance(field_type, str) default = None + has_default = False if field_default: has_default = True default = field_default - elif field_default_factory not in (None, _MISSING): + elif field_default_factory is not None and field_default_factory is not _MISSING: has_default = True default = field_default_factory() @@ -567,7 +574,7 @@ def _init_argparse( help_prefix=help_prefix, relaxed_parser=relaxed_parser, ) - elif isinstance(field_type(), bool): + elif field_type is bool: def parse_bool(x: str) -> bool: if x not in ("true", "false"): @@ -603,7 +610,7 @@ def parse_bool(x: str) -> bool: @dataclass -class Coqpit(Serializable, MutableMapping): +class Coqpit(Serializable, CoqpitType): """Coqpit base class to be inherited by any Coqpit dataclasses. It overrides Python `dict` interface and provides `dict` compatible API. @@ -782,7 +789,7 @@ def init_from_argparse( # Handle list and object attributes with defaults, which can be modified # directly (eg. --coqpit.list.0.val_a 1), by constructing real objects # from defaults and passing those to `cls.__init__` - args_with_lists_processed = {} + args_with_lists_processed: CoqpitType = {} class_fields = fields(cls) for field in class_fields: has_default = False @@ -971,7 +978,7 @@ def check_argument( if is_path: assert Path(c[name]).exists(), f' [!] path for {name} ("{c[name]}") does not exist.' # skip the rest if the alternative field is defined. - if alternative in c and c[alternative] is not None: + if alternative is not None and alternative in c and c[alternative] is not None: return # check value constraints if name in c: From 2349aff880c07153c124d4864497af4557224562 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 10:54:23 +0200 Subject: [PATCH 15/27] fix: correctly serialize Serializable type --- coqpit/coqpit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 96f85bc..ae60e88 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -167,7 +167,7 @@ def _serialize(x: Any) -> Any: if isinstance(x, Serializable) or issubclass(type(x), Serializable): return x.serialize() if isinstance(x, type) and issubclass(x, Serializable): - return x.serialize(x) + return x.serialize(x()) return x From 52b00ace72893666e5ee74ea2d2e995069b18f4d Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 11:00:01 +0200 Subject: [PATCH 16/27] fix!: make Coqpit.update() consistent with superclass - implement all possible overloads - new values will now be always added, this can't be configured anymore This method doesn't seem to be used in Trainer or TTS, so this change shouldn't have a direct effect. --- coqpit/coqpit.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index ae60e88..b346e68 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -706,19 +706,29 @@ def has(self, arg: str) -> bool: def copy(self) -> Self: return replace(self) - def update(self, new: dict, allow_new: bool = False) -> None: + @overload + def update(self, other: SupportsKeysAndGetItem[str, CoqpitNestedValue], /, **kwargs: CoqpitNestedValue) -> None: ... + @overload + def update(self, other: Iterable[tuple[str, CoqpitNestedValue]], /, **kwargs: CoqpitNestedValue) -> None: ... + @overload + def update(self, /, **kwargs: CoqpitNestedValue) -> None: ... + def update(self, other: Any = (), /, **kwargs: CoqpitNestedValue) -> None: """Update Coqpit fields by the input ```dict```. Args: - new (dict): dictionary with new values. - allow_new (bool, optional): allow new fields to add. Defaults to False. + other (dict): dictionary with new values. """ - for key, value in new.items(): - if allow_new or hasattr(self, key): + if isinstance(other, dict): + for key in other: + setattr(self, key, other[key]) + elif hasattr(other, "keys"): + for key in other.keys(): + setattr(self, key, other[key]) + else: + for key, value in other: setattr(self, key, value) - else: - msg = f" [!] No key - {key}" - raise KeyError(msg) + for key, value in kwargs.items(): + setattr(self, key, value) def pprint(self) -> None: """Print Coqpit fields in a format.""" From 578747b65d25e91d8e1378a5d364f86216eb4f9d Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 11:03:48 +0200 Subject: [PATCH 17/27] fix: deserialize needs to be called as instance method --- coqpit/coqpit.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index b346e68..3917b78 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -370,9 +370,7 @@ def _validate_contracts(self) -> None: def validate(self) -> None: """Validate if object can serialize / deserialize correctly.""" self._validate_contracts() - if self != self.__class__.deserialize( # pylint: disable=no-value-for-parameter - json.loads(json.dumps(self.serialize())), - ): + if self != self.__class__().deserialize(json.loads(json.dumps(self.serialize()))): msg = "could not be deserialized with same value" raise ValueError(msg) @@ -423,7 +421,7 @@ def deserialize(self, data: dict[str, Any]) -> Self: init_kwargs[field.name] = value continue if value == MISSING: - msg = f"deserialized with unknown value for {field.name} in {self.__name__}" + msg = f"deserialized with unknown value for {field.name} in {self.__class__.__name__}" raise ValueError(msg) value = _deserialize(value, field.type) init_kwargs[field.name] = value From db183f2d939f1f5177568bcf4f0c7c18ac40a5e0 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 11:06:22 +0200 Subject: [PATCH 18/27] fix: correctly initialize from argparse Previously, class and instance methods were inconsistently mixed. --- coqpit/coqpit.py | 77 ++++++++++++++++++++++-------------- tests/test_parse_argparse.py | 2 +- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 3917b78..604ec16 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -477,15 +477,15 @@ def _get_help(field: Field[Any]) -> str: return "" -def _init_argparse( - parser, - field_name, - field_type, - field_default, - field_default_factory, - field_help, - arg_prefix="", - help_prefix="", +def _add_argument( + parser: argparse.ArgumentParser, + field_name: str, + field_type: FieldType, + field_default: Any, + field_default_factory: Callable[[], Any] | Literal[_MISSING_TYPE.MISSING], + field_help: str, + arg_prefix: str = "", + help_prefix: str = "", *, relaxed_parser: bool = False, ) -> argparse.ArgumentParser: @@ -548,8 +548,9 @@ def _init_argparse( else: # If a default value is defined, just enable editing the values from argparse # TODO: allow inserting a new value/obj to the end of the list. + assert isinstance(default, list) for idx, fv in enumerate(default): - parser = _init_argparse( + parser = _add_argument( parser, str(idx), list_field_type, @@ -565,9 +566,11 @@ def _init_argparse( if not relaxed_parser: msg = " [!] Parsing `Union` field from argparse is not yet implemented. Please create an issue." raise NotImplementedError(msg) - elif issubclass(field_type, Serializable): + elif not _is_union(field_type) and issubclass(field_type, Coqpit): + assert isinstance(default, Coqpit) return default.init_argparse( - parser, + instance=default, + parser=parser, arg_prefix=arg_prefix, help_prefix=help_prefix, relaxed_parser=relaxed_parser, @@ -787,12 +790,14 @@ def init_from_argparse( """ if not args: # If args was not specified, parse from sys.argv - parser = cls.init_argparse(cls, arg_prefix=arg_prefix) - args = parser.parse_args() # pylint: disable=E1120, E1111 + parser = cls.init_argparse(arg_prefix=arg_prefix) + args = parser.parse_args() if isinstance(args, list): - # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace - parser = cls.init_argparse(cls, arg_prefix=arg_prefix) - args = parser.parse_args(args) # pylint: disable=E1120, E1111 + # If a list was passed in (eg. the second result of + # `parse_known_args`, run that through argparse first to get a + # parsed Namespace + parser = cls.init_argparse(arg_prefix=arg_prefix) + args = parser.parse_args(args) # Handle list and object attributes with defaults, which can be modified # directly (eg. --coqpit.list.0.val_a 1), by constructing real objects @@ -837,11 +842,13 @@ def parse_args( """ if not args: # If args was not specified, parse from sys.argv - parser = self.init_argparse(arg_prefix=arg_prefix) + parser = self.init_argparse(instance=self, arg_prefix=arg_prefix) args = parser.parse_args() if isinstance(args, list): - # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace - parser = self.init_argparse(arg_prefix=arg_prefix) + # If a list was passed in (eg. the second result of + # `parse_known_args`, run that through argparse first + # to get a parsed Namespace + parser = self.init_argparse(instance=self, arg_prefix=arg_prefix) args = parser.parse_args(args) args_dict = vars(args) @@ -879,21 +886,26 @@ def parse_known_args( """ if not args: # If args was not specified, parse from sys.argv - parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) + parser = self.init_argparse(instance=self, arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) args, unknown = parser.parse_known_args() if isinstance(args, list): - # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace - parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) + # If a list was passed in (eg. the second result of + # `parse_known_args`, run that through argparse first to get a + # parsed Namespace + parser = self.init_argparse(instance=self, arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) args, unknown = parser.parse_known_args(args) self.parse_args(args) return unknown + @classmethod def init_argparse( - self, - parser: Optional[argparse.ArgumentParser] = None, - arg_prefix="coqpit", - help_prefix="", + cls, + *, + instance: Self | None = None, + parser: argparse.ArgumentParser | None = None, + arg_prefix: str = "coqpit", + help_prefix: str = "", relaxed_parser: bool = False, ) -> argparse.ArgumentParser: """Create an argparse parser that can parse the Coqpit fields. @@ -901,6 +913,8 @@ def init_argparse( This allows to edit values through command-line. Args: + instance (Coqpit, optional): instance of the given Coqpit class + to initialize any default values. parser (argparse.ArgumentParser, optional): argparse.ArgumentParser instance. If unspecified a new one will be created. arg_prefix (str, optional): Prefix to be used for the argument name. Defaults to 'coqpit'. help_prefix (str, optional): Prefix to be used for the argument description. Defaults to ''. @@ -911,15 +925,18 @@ def init_argparse( """ if not parser: parser = argparse.ArgumentParser() - class_fields = fields(self) + cls_or_instance = cls if instance is None else instance + class_fields = fields(cls_or_instance) for field in class_fields: # use the current value of the field to prevent dropping the current value, # else use the default value of the field - field_default = vars(self).get(field.name, field.default if field.default is not _MISSING else None) + field_default = vars(cls_or_instance).get( + field.name, field.default if field.default is not _MISSING else None + ) field_type = field.type field_default_factory = field.default_factory field_help = _get_help(field) - _init_argparse( + _add_argument( parser, field.name, field_type, diff --git a/tests/test_parse_argparse.py b/tests/test_parse_argparse.py index 6637e92..e9adb46 100644 --- a/tests/test_parse_argparse.py +++ b/tests/test_parse_argparse.py @@ -70,7 +70,7 @@ def test_parse_argparse() -> None: ) # create and init argparser with Coqpit - parser = config.init_argparse() + parser = config.init_argparse(instance=config) parser.print_help() # parse the argsparser From b0090c273dc599c1270d8dd02d5ec289e32b12eb Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 11:36:08 +0200 Subject: [PATCH 19/27] build: switch from setuptools to hatchling --- .gitignore | 2 ++ MANIFEST.in | 9 --------- pyproject.toml | 16 +++++++++++----- setup.py | 5 ----- 4 files changed, 13 insertions(+), 19 deletions(-) delete mode 100644 MANIFEST.in delete mode 100644 setup.py diff --git a/.gitignore b/.gitignore index 33a2f6e..a3bbda1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +uv.lock + WadaSNR/ .idea/ *.pyc diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 8b8dab6..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,9 +0,0 @@ -include README.md -include LICENSE.txt -recursive-include coqpit *.json -recursive-include coqpit *.html -recursive-include coqpit *.png -recursive-include coqpit *.md -recursive-include coqpit *.py -recursive-include coqpit *.pyx -recursive-include images *.png diff --git a/pyproject.toml b/pyproject.toml index 2bba169..7d0f4d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,6 @@ [build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" - -[tool.setuptools.packages.find] -include = ["coqpit*"] +requires = ["hatchling"] +build-backend = "hatchling.build" [project] name = "coqpit" @@ -48,6 +45,15 @@ dev-dependencies = [ Repository = "https://github.com/idiap/coqui-ai-coqpit" Issues = "https://github.com/idiap/coqui-ai-coqpit/issues" +[tool.hatch.build] +exclude = [ + "/.github", + "/.gitignore", + "/.pre-commit-config.yaml", + "/Makefile", + "/tests", +] + [tool.ruff] target-version = "py39" line-length = 120 diff --git a/setup.py b/setup.py deleted file mode 100644 index beda28e..0000000 --- a/setup.py +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env python - -from setuptools import setup - -setup() From 2c6929e5b9824565993082e603b19806920fad83 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 11:37:40 +0200 Subject: [PATCH 20/27] feat: declare typing support, run mypy in pre-commit --- .pre-commit-config.yaml | 5 +++++ coqpit/py.typed | 0 pyproject.toml | 1 + 3 files changed, 6 insertions(+) create mode 100644 coqpit/py.typed diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 727d5cc..3d8409e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,3 +12,8 @@ repos: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.12.0 + hooks: + - id: mypy + args: [--strict] diff --git a/coqpit/py.typed b/coqpit/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 7d0f4d6..ca55f75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ [tool.uv] dev-dependencies = [ "coverage>=7", + "mypy>=1.12.0", "pre-commit>=3", "pytest>=8", "ruff==0.6.9", From 650926cde73027b2c56bf5ef5e17263bbe718a08 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 12:27:48 +0200 Subject: [PATCH 21/27] fix: adapt type hints to Python 3.9 Mostly NoneType and UnionType are only available from 3.10 --- coqpit/coqpit.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 604ec16..0c88ac7 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -5,17 +5,24 @@ import json import operator import os +import sys import typing from collections.abc import ItemsView, Iterable, Iterator, MutableMapping from dataclasses import MISSING as _MISSING from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace from pathlib import Path from pprint import pprint -from types import GenericAlias, NoneType, UnionType +from types import GenericAlias from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union, overload from typing_extensions import Self, TypeAlias, TypeGuard, TypeIs +# TODO: Available from Python 3.10 +if sys.version_info >= (3, 10): + from types import UnionType +else: + UnionType: TypeAlias = Union + if TYPE_CHECKING: # pragma: no cover from dataclasses import _MISSING_TYPE @@ -32,7 +39,7 @@ class _NoDefault(Generic[_T]): NoDefaultVar: TypeAlias = Union[_NoDefault[_T], _T] no_default: NoDefaultVar[Any] = _NoDefault() -FieldType: TypeAlias = str | type | UnionType +FieldType: TypeAlias = Union[str, type, "UnionType"] def _is_primitive_type(field_type: FieldType) -> TypeGuard[type]: @@ -81,7 +88,10 @@ def _is_union(field_type: FieldType) -> TypeIs[UnionType]: bool: True if input type is `Union` """ origin = typing.get_origin(field_type) - return origin is Union or origin is UnionType + is_union = origin is Union + if sys.version_info >= (3, 10): + is_union = is_union or origin is UnionType + return is_union def _is_union_and_not_simple_optional(field_type: FieldType) -> TypeGuard[UnionType]: @@ -98,7 +108,7 @@ def _is_union_and_not_simple_optional(field_type: FieldType) -> TypeGuard[UnionT """ args = typing.get_args(field_type) is_python_union = _is_union(field_type) - if is_python_union and len(args) == 2 and NoneType in args: + if is_python_union and len(args) == 2 and type(None) in args: # This is an Optional type like `int | None` return False return is_python_union @@ -129,7 +139,7 @@ def _is_optional_field(field_type: FieldType) -> TypeGuard[UnionType]: Returns: bool: True if the input field type is optional. """ - return NoneType in typing.get_args(field_type) + return type(None) in typing.get_args(field_type) def _drop_none_type(field_type: FieldType) -> FieldType: @@ -142,11 +152,11 @@ def _drop_none_type(field_type: FieldType) -> FieldType: return field_type origin = typing.get_origin(field_type) args = list(typing.get_args(field_type)) - if NoneType in args: - args.remove(NoneType) + if type(None) in args: + args.remove(type(None)) if len(args) == 1: return typing.cast(type, args[0]) - return typing.cast(UnionType, GenericAlias(origin, args)) + return typing.cast("UnionType", GenericAlias(origin, args)) def _serialize(x: Any) -> Any: @@ -257,7 +267,7 @@ def _deserialize_primitive_types( base_type = _drop_none_type(field_type) if base_type is not float and base_type is not int and base_type is not str and base_type is not bool: raise TypeError - base_type = typing.cast(type[int | float | str | bool], base_type) + base_type = typing.cast(type[Union[int, float, str, bool]], base_type) if x == float("inf") or x == float("-inf"): # if value type is inf return regardless. return x @@ -297,7 +307,7 @@ def _deserialize(x: Any, field_type: FieldType) -> Any: CoqpitType: TypeAlias = MutableMapping[str, "CoqpitNestedValue"] CoqpitNestedValue: TypeAlias = Union["CoqpitValue", CoqpitType] -CoqpitValue: TypeAlias = str | int | float | bool | None +CoqpitValue: TypeAlias = Union[str, int, float, bool, None] # TODO: It should be possible to get rid of the next 3 `type: ignore`. At From 98912891dc1cf650ddfcdbebd91cc194c9a7c11e Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 18:06:31 +0200 Subject: [PATCH 22/27] fix: replace asserts with exceptions --- coqpit/coqpit.py | 65 ++++++++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 0c88ac7..deb1166 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -285,8 +285,9 @@ def _deserialize(x: Any, field_type: FieldType) -> Any: object: deserialized object """ - # pylint: disable=too-many-return-statements - assert not isinstance(field_type, str) + if isinstance(field_type, str): + msg = "Strings as type hints are not supported." + raise NotImplementedError(msg) if _is_dict(_drop_none_type(field_type)): return _deserialize_dict(x) if _is_list(_drop_none_type(field_type)): @@ -500,7 +501,9 @@ def _add_argument( relaxed_parser: bool = False, ) -> argparse.ArgumentParser: """Add a new argument to the argparse parser, matching the given field.""" - assert not isinstance(field_type, str) + if isinstance(field_type, str): + msg = "Strings as type hints are not supported." + raise NotImplementedError(msg) default = None has_default = False if field_default: @@ -558,7 +561,9 @@ def _add_argument( else: # If a default value is defined, just enable editing the values from argparse # TODO: allow inserting a new value/obj to the end of the list. - assert isinstance(default, list) + if not isinstance(default, list): + msg = f"Default value must be a list, got {default}" + raise TypeError(msg) for idx, fv in enumerate(default): parser = _add_argument( parser, @@ -577,7 +582,9 @@ def _add_argument( msg = " [!] Parsing `Union` field from argparse is not yet implemented. Please create an issue." raise NotImplementedError(msg) elif not _is_union(field_type) and issubclass(field_type, Coqpit): - assert isinstance(default, Coqpit) + if not isinstance(default, Coqpit): + msg = f"Default value must be a Coqpit instance, got {default}" + raise TypeError(msg) return default.init_argparse( instance=default, parser=parser, @@ -602,7 +609,8 @@ def parse_bool(x: str) -> bool: ) elif _is_primitive_type(_drop_none_type(field_type)): base_type = _drop_none_type(field_type) - assert not _is_union(base_type) + if _is_union(base_type): + raise TypeError parser.add_argument( f"--{arg_prefix}", default=field_default, @@ -870,7 +878,7 @@ def parse_args( _rgetattr(self, k) except (TypeError, AttributeError) as e: msg = f" [!] '{k}' not exist to override from argparse." - raise Exception(msg) from e + raise TypeError(msg) from e _rsetattr(self, k, v) @@ -997,29 +1005,38 @@ def check_argument( >>> check_argument('fft_size', c, restricted=True, min_val=128, max_val=4058) """ # check if None allowed - if allow_none and c[name] is None: - return - if not allow_none: - assert c[name] is not None, f" [!] None value is not allowed for {name}." + if c[name] is None: + if allow_none: + return + msg = f" [!] None value is not allowed for {name}." + raise TypeError(msg) # check if restricted and it it is check if it exists - if isinstance(restricted, bool) and restricted: - assert name in c, f" [!] {name} not defined in config.json" + if isinstance(restricted, bool) and restricted and name not in c: + msg = f" [!] {name} not defined in config.json" + raise KeyError(msg) # check prerequest fields are defined if isinstance(prerequest, list): - assert any(f not in c for f in prerequest), f" [!] prequested fields {prerequest} for {name} are not defined." - else: - assert prerequest is None or prerequest in c, f" [!] prequested fields {prerequest} for {name} are not defined." + if any(f not in c for f in prerequest): + msg = f" [!] prequested fields {prerequest} for {name} are not defined." + raise KeyError(msg) + elif prerequest is not None and prerequest not in c: + msg = f" [!] prequested field {prerequest} for {name} is not defined." + raise KeyError(msg) # check if the path exists - if is_path: - assert Path(c[name]).exists(), f' [!] path for {name} ("{c[name]}") does not exist.' + if is_path and not Path(c[name]).exists(): + msg = f' [!] path for {name} ("{c[name]}") does not exist.' + raise FileNotFoundError(msg) # skip the rest if the alternative field is defined. if alternative is not None and alternative in c and c[alternative] is not None: return # check value constraints if name in c: - if max_val is not None: - assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}" - if min_val is not None: - assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}" - if enum_list is not None: - assert c[name].lower() in enum_list, f" [!] {name} is not a valid value" + if max_val is not None and c[name] > max_val: + msg = f" [!] {name} is larger than max value {max_val}" + raise ValueError + if min_val is not None and c[name] < min_val: + msg = f" [!] {name} is smaller than min value {min_val}" + raise ValueError + if enum_list is not None and c[name].lower() not in enum_list: + msg = f" [!] {name} is not a valid value" + raise ValueError From 0563504e948dc7960a700609ff3ee2b51cf21172 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 18:07:28 +0200 Subject: [PATCH 23/27] refactor: enable (almost) all lint rules --- coqpit/coqpit.py | 131 +++++++++++++++++++---------- pyproject.toml | 17 +++- tests/test_init_from_dict.py | 8 +- tests/test_nested_configs.py | 8 +- tests/test_parse_argparse.py | 13 +-- tests/test_parse_known_argparse.py | 6 +- tests/test_serialization.py | 2 +- tests/test_simple_config.py | 10 +-- 8 files changed, 127 insertions(+), 68 deletions(-) diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index deb1166..2a627e5 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -1,10 +1,11 @@ +"""Simple, light-weight config handling through Python data classes.""" + from __future__ import annotations import argparse import contextlib import json import operator -import os import sys import typing from collections.abc import ItemsView, Iterable, Iterator, MutableMapping @@ -24,6 +25,7 @@ UnionType: TypeAlias = Union if TYPE_CHECKING: # pragma: no cover + import os from dataclasses import _MISSING_TYPE from _typeshed import SupportsKeysAndGetItem @@ -108,7 +110,7 @@ def _is_union_and_not_simple_optional(field_type: FieldType) -> TypeGuard[UnionT """ args = typing.get_args(field_type) is_python_union = _is_union(field_type) - if is_python_union and len(args) == 2 and type(None) in args: + if is_python_union and len(args) == 2 and type(None) in args: # noqa: PLR2004 # This is an Optional type like `int | None` return False return is_python_union @@ -215,7 +217,7 @@ def _deserialize_list(x: list[_T], field_type: FieldType) -> list[_T]: field_args = typing.get_args(field_type) if len(field_args) == 0: return x - elif len(field_args) > 1: + if len(field_args) > 1: msg = "Coqpit does not support multi-type hinted 'List'" raise ValueError(msg) field_arg = field_args[0] @@ -246,7 +248,8 @@ def _deserialize_union(x: Any, field_type: UnionType) -> Any: def _deserialize_primitive_types( - x: int | float | str | bool | None, field_type: FieldType + x: int | float | str | bool | None, # noqa: PYI041 + field_type: FieldType, ) -> int | float | str | bool | None: """Deserialize python primitive types (float, int, str, bool). @@ -259,8 +262,6 @@ def _deserialize_primitive_types( Returns: Union[int, float, str, bool]: deserialized value. """ - if x is None: - return None if isinstance(x, (str, bool)): return x if isinstance(x, (int, float)): @@ -272,6 +273,14 @@ def _deserialize_primitive_types( # if value type is inf return regardless. return x return base_type(x) + return None + + +def _deserialize_path(x: Any, field_type: FieldType) -> Path | None: + """Deserialize to a Path.""" + if x is None and _is_optional_field(field_type): + return None + return Path(x) def _deserialize(x: Any, field_type: FieldType) -> Any: @@ -297,9 +306,7 @@ def _deserialize(x: Any, field_type: FieldType) -> Any: if not _is_union(field_type) and isinstance(field_type, type) and issubclass(field_type, Serializable): return field_type.deserialize_immutable(x) if _drop_none_type(field_type) is Path: - if x is None and _is_optional_field(field_type): - return None - return Path(x) + return _deserialize_path(x, field_type) if _is_primitive_type(_drop_none_type(field_type)): return _deserialize_primitive_types(x, field_type) msg = f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type." @@ -314,7 +321,7 @@ def _deserialize(x: Any, field_type: FieldType) -> Any: # TODO: It should be possible to get rid of the next 3 `type: ignore`. At # nested levels, the key can be `str | int` as well, not just `str`. def _rsetattr(obj: CoqpitType, keys: str, val: CoqpitValue) -> None: - """Recursive setattr (supports dotted key names)""" + """Recursive setattr (supports dotted key names).""" pre, _, post = keys.rpartition(".") target = _rgetattr(obj, pre) if pre else obj if post.isnumeric(): @@ -356,6 +363,7 @@ class Serializable: """Gives serialization ability to any inheriting dataclass.""" def __post_init__(self) -> None: + """Validate contracts and check required arguments are specified.""" self._validate_contracts() for key, value in self.__dict__.items(): if value is no_default: @@ -363,6 +371,7 @@ def __post_init__(self) -> None: raise TypeError(msg) def _validate_contracts(self) -> None: + """Validate contracts specified in the dataclass.""" dataclass_fields = fields(self) for field in dataclass_fields: @@ -488,7 +497,7 @@ def _get_help(field: Field[Any]) -> str: return "" -def _add_argument( +def _add_argument( # noqa: C901, PLR0913, PLR0912, PLR0915 parser: argparse.ArgumentParser, field_name: str, field_type: FieldType, @@ -523,7 +532,7 @@ def _add_argument( return parser arg_prefix = field_name if arg_prefix == "" else f"{arg_prefix}.{field_name}" help_prefix = field_help if help_prefix == "" else f"{help_prefix} - {field_help}" - if _is_dict(field_type): # pylint: disable=no-else-raise + if _is_dict(field_type): # NOTE: accept any string in json format as input to dict field. parser.add_argument( f"--{arg_prefix}", @@ -537,7 +546,7 @@ def _add_argument( if len(field_args) > 1 and not relaxed_parser: msg = "Coqpit does not support multi-type hinted 'List'" raise ValueError(msg) - elif len(field_args) == 0: + if len(field_args) == 0: msg = "Coqpit does not support un-hinted 'List'" raise ValueError(msg) list_field_type = field_args[0] @@ -633,7 +642,9 @@ class Coqpit(Serializable, CoqpitType): """Coqpit base class to be inherited by any Coqpit dataclasses. It overrides Python `dict` interface and provides `dict` compatible API. - It also enables serializing/deserializing a dataclass to/from a json file, plus some semi-dynamic type and value check. + It also enables serializing/deserializing a dataclass to/from a json file, + plus some semi-dynamic type and value check. + Note that it does not support all datatypes and likely to fail in some cases. """ @@ -648,6 +659,7 @@ def _is_initialized(self) -> bool: return "_initialized" in vars(self) and self._initialized def __post_init__(self) -> None: + """Check values if a check_values() method is defined.""" self._initialized = True with contextlib.suppress(AttributeError): self.check_values() @@ -655,12 +667,15 @@ def __post_init__(self) -> None: ## `dict` API functions def __iter__(self) -> Iterator[str]: + """Return iterator over the Coqpit.""" return iter(asdict(self)) def __len__(self) -> int: + """Return the number of fields in the Coqpit.""" return len(fields(self)) def __setitem__(self, arg: str, value: Any) -> None: + """Set the value for the given attribute.""" setattr(self, arg, value) def __getitem__(self, arg: str) -> Any: @@ -668,14 +683,15 @@ def __getitem__(self, arg: str) -> Any: return self.__dict__[arg] def __delitem__(self, arg: str) -> None: + """Remove an attribute.""" delattr(self, arg) - def _keytransform(self, key: str) -> str: # pylint: disable=no-self-use + def _keytransform(self, key: str) -> str: return key ## end `dict` API functions - def __getattribute__(self, arg: str) -> Any: # pylint: disable=no-self-use + def __getattribute__(self, arg: str) -> Any: """Check if the mandatory field is defined when accessing it.""" value = super().__getattribute__(arg) if isinstance(value, str) and value == "???": @@ -684,14 +700,17 @@ def __getattribute__(self, arg: str) -> Any: # pylint: disable=no-self-use return value def __contains__(self, arg: object) -> bool: + """Check whether the Coqpit contains the given attribute.""" return arg in self.to_dict() def get(self, key: str, default: Any = None) -> Any: + """Return value of the given attribute if present, otherwise the default.""" if self.has(key): return asdict(self)[key] return default def items(self) -> ItemsView[str, Any]: + """Return (key, value) items of the Coqpit.""" return asdict(self).items() def merge(self, coqpits: Coqpit | list[Coqpit]) -> None: @@ -717,12 +736,17 @@ def _merge(coqpit: Coqpit) -> None: _merge(coqpits) def check_values(self) -> None: - pass + """Perform data validation after initialization. + + Can be implemented in subclasses. + """ def has(self, arg: str) -> bool: + """Check whether the Coqpit has the given attribute.""" return arg in vars(self) def copy(self) -> Self: + """Return a copy of the Coqpit.""" return replace(self) @overload @@ -735,13 +759,14 @@ def update(self, other: Any = (), /, **kwargs: CoqpitNestedValue) -> None: """Update Coqpit fields by the input ```dict```. Args: - other (dict): dictionary with new values. + other: dictionary or iterable with new values. + **kwargs: alternative way to pass new keys and values. """ if isinstance(other, dict): for key in other: setattr(self, key, other[key]) elif hasattr(other, "keys"): - for key in other.keys(): + for key in other.keys(): # noqa: SIM118 setattr(self, key, other[key]) else: for key, value in other: @@ -751,21 +776,23 @@ def update(self, other: Any = (), /, **kwargs: CoqpitNestedValue) -> None: def pprint(self) -> None: """Print Coqpit fields in a format.""" - pprint(asdict(self)) + pprint(asdict(self)) # noqa: T203 def to_dict(self) -> dict[str, Any]: - # return asdict(self) + """Convert the Coqpit to a dictionary, serializing any values.""" return self.serialize() def from_dict(self, data: dict[str, Any]) -> None: + """Update Coqpit from the dictionary.""" self.deserialize(data) @classmethod def new_from_dict(cls, data: dict[str, Any]) -> Self: + """Create a new Coqpit from a dictionary.""" return cls.deserialize_immutable(data) def to_json(self) -> str: - """Returns a JSON string representation.""" + """Return a JSON string representation.""" return json.dumps(self.to_dict(), indent=4) def save_json(self, file_name: str | os.PathLike[Any]) -> None: @@ -774,7 +801,7 @@ def save_json(self, file_name: str | os.PathLike[Any]) -> None: Args: file_name (str): path to the output json file. """ - with open(file_name, "w", encoding="utf8") as f: + with Path(file_name).open("w", encoding="utf8") as f: json.dump(self.to_dict(), f, indent=4) def load_json(self, file_name: str | os.PathLike[Any]) -> None: @@ -788,7 +815,7 @@ def load_json(self, file_name: str | os.PathLike[Any]) -> None: Returns: Coqpit: new Coqpit with updated config fields. """ - with open(file_name, encoding="utf8") as f: + with Path(file_name).open(encoding="utf8") as f: input_str = f.read() dump_dict = json.loads(input_str) self.deserialize(dump_dict) @@ -803,8 +830,11 @@ def init_from_argparse( """Create a new Coqpit instance from argparse input. Args: - args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. - arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. + args (namespace or list of str, optional): parsed argparse.Namespace + or list of command line parameters. If unspecified will use a + newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to + ```init_argparse``` when ```args``` is not passed. """ if not args: # If args was not specified, parse from sys.argv @@ -838,11 +868,9 @@ def init_from_argparse( args_with_lists_processed[field.name] = default args_dict = vars(args) - for k, v in args_dict.items(): + for key, v in args_dict.items(): # Remove argparse prefix (eg. "--coqpit." if present) - if k.startswith(f"{arg_prefix}."): - k = k[len(f"{arg_prefix}.") :] - + k = key.removeprefix(f"{arg_prefix}.") _rsetitem(args_with_lists_processed, k, v) return cls(**args_with_lists_processed) @@ -855,8 +883,11 @@ def parse_args( """Update config values from argparse arguments with some meta-programming ✨. Args: - args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. - arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. + args (namespace or list of str, optional): parsed argparse.Namespace + or list of command line parameters. If unspecified will use a + newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to + ```init_argparse``` when ```args``` is not passed. """ if not args: # If args was not specified, parse from sys.argv @@ -871,9 +902,8 @@ def parse_args( args_dict = vars(args) - for k, v in args_dict.items(): - if k.startswith(f"{arg_prefix}."): - k = k[len(f"{arg_prefix}.") :] + for key, v in args_dict.items(): + k = key.removeprefix(f"{arg_prefix}.") try: _rgetattr(self, k) except (TypeError, AttributeError) as e: @@ -888,6 +918,7 @@ def parse_known_args( self, args: argparse.Namespace | list[str] | None = None, arg_prefix: str = "coqpit", + *, relaxed_parser: bool = False, ) -> list[str]: """Update config values from argparse arguments. Ignore unknown arguments. @@ -895,9 +926,13 @@ def parse_known_args( This is analog to argparse.ArgumentParser.parse_known_args (vs parse_args). Args: - args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. - arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. - relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. Defaults to False. + args (namespace or list of str, optional): parsed argparse.Namespace + or list of command line parameters. If unspecified will use a + newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to + ```init_argparse``` when ```args``` is not passed. + relaxed_parser (bool, optional): If True, do not force all the fields + to have compatible types with the argparser. Defaults to False. Returns: List of unknown parameters. @@ -933,10 +968,14 @@ def init_argparse( Args: instance (Coqpit, optional): instance of the given Coqpit class to initialize any default values. - parser (argparse.ArgumentParser, optional): argparse.ArgumentParser instance. If unspecified a new one will be created. - arg_prefix (str, optional): Prefix to be used for the argument name. Defaults to 'coqpit'. - help_prefix (str, optional): Prefix to be used for the argument description. Defaults to ''. - relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. Defaults to False. + parser (argparse.ArgumentParser, optional): argparse.ArgumentParser + instance. If unspecified a new one will be created. + arg_prefix (str, optional): Prefix to be used for the argument name. + Defaults to 'coqpit'. + help_prefix (str, optional): Prefix to be used for the argument + description. Defaults to ''. + relaxed_parser (bool, optional): If True, do not force all the fields + to have compatible types with the argparser. Defaults to False. Returns: argparse.ArgumentParser: parser instance with the new arguments. @@ -949,7 +988,8 @@ def init_argparse( # use the current value of the field to prevent dropping the current value, # else use the default value of the field field_default = vars(cls_or_instance).get( - field.name, field.default if field.default is not _MISSING else None + field.name, + field.default if field.default is not _MISSING else None, ) field_type = field.type field_default_factory = field.default_factory @@ -968,11 +1008,12 @@ def init_argparse( return parser -def check_argument( +def check_argument( # noqa: C901, PLR0913 name: str, c: dict[str, Any], + *, is_path: bool = False, - prerequest: str | None = None, + prerequest: list[str] | str | None = None, enum_list: list[Any] | None = None, max_val: float | None = None, min_val: float | None = None, diff --git a/pyproject.toml b/pyproject.toml index ca55f75..709c60c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,22 @@ exclude = [ [tool.ruff] target-version = "py39" line-length = 120 -lint.extend-select = ["I", "UP", "B", "W", "A", "PLC", "PLE"] +lint.select = ["ALL"] +lint.ignore = [ + "ANN401", + "D104", + "FIX", + "TD", +] [tool.ruff.lint.pydocstyle] convention = "google" + +[tool.ruff.lint.per-file-ignores] +"tests/**" = [ + "D", + "FA100", + "PLR2004", + "S101", + "T201", +] diff --git a/tests/test_init_from_dict.py b/tests/test_init_from_dict.py index a18fb67..616d198 100644 --- a/tests/test_init_from_dict.py +++ b/tests/test_init_from_dict.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field from typing import Optional +import pytest # type: ignore[import-not-found] + from coqpit import Coqpit @@ -19,7 +21,7 @@ class Reference(Coqpit): Person(name="Eren", age=11), Person(name="Geren", age=12), Person(name="Ceren", age=15), - ] + ], ) people_ids: list[int] = field(default_factory=lambda: [1, 2, 3]) @@ -41,7 +43,5 @@ def test_new_from_dict() -> None: assert ref_config.people[0].name == new_config.people[0].name assert ref_config.people[0].age == new_config.people[0].age - try: + with pytest.raises(ValueError, match="Missing required field"): WithRequired.new_from_dict({}) - except ValueError as e: - assert "Missing required field" in e.args[0] diff --git a/tests/test_nested_configs.py b/tests/test_nested_configs.py index 8abb381..4117fe6 100644 --- a/tests/test_nested_configs.py +++ b/tests/test_nested_configs.py @@ -1,5 +1,5 @@ -import os from dataclasses import asdict, dataclass, field +from pathlib import Path from typing import Optional, Union from coqpit import Coqpit, check_argument @@ -39,16 +39,16 @@ def check_values(self) -> None: def test_nested() -> None: - file_path = os.path.dirname(os.path.abspath(__file__)) + file_path = Path(__file__).resolve().parent / "example_config.json" # init 🐸 dataclass config = NestedConfig() # save to a json file - config.save_json(os.path.join(file_path, "example_config.json")) + config.save_json(file_path) # load a json file config2 = NestedConfig(val_e=500) # update the config with the json file. - config2.load_json(os.path.join(file_path, "example_config.json")) + config2.load_json(file_path) # now they should be having the same values. assert config == config2 diff --git a/tests/test_parse_argparse.py b/tests/test_parse_argparse.py index e9adb46..ec7d137 100644 --- a/tests/test_parse_argparse.py +++ b/tests/test_parse_argparse.py @@ -24,10 +24,12 @@ class SimpleConfig(Coqpit): empty_int_list: Optional[list[int]] = field(default=None, metadata={"help": "int list without default value"}) empty_str_list: Optional[list[str]] = field(default=None, metadata={"help": "str list without default value"}) list_with_default_factory: list[str] = field( - default_factory=list, metadata={"help": "str list with default factory"} + default_factory=list, + metadata={"help": "str list with default factory"}, ) - # mylist_without_default: List[SimplerConfig] = field(default=None, metadata={'help': 'list of SimplerConfig'}) # NOT SUPPORTED YET! + # TODO: not supported yet + # mylist_without_default: List[SimplerConfig] = field(default=None) noqa: ERA001 def check_values(self) -> None: """Check config fields""" @@ -112,7 +114,7 @@ class Config(Coqpit): try: config.parse_args(args) - raise AssertionError("should not reach this") # pragma: no cover + raise AssertionError # pragma: no cover, should not reach this except SystemExit: pass @@ -126,7 +128,7 @@ def test_argparse_with_required_field() -> None: args = ["--coqpit.val_a", "10"] try: c = ArgparseWithRequiredField() # type: ignore[call-arg] - raise AssertionError("should not reach this") # pragma: no cover + raise AssertionError # pragma: no cover, should not reach this except TypeError: # __init__ should fail due to missing val_a pass @@ -151,7 +153,8 @@ class SimpleConfig2(Coqpit): metadata={"help": "list of SimplerConfig2"}, ) - # mylist_without_default: List[SimplerConfig2] = field(default=None, metadata={'help': 'list of SimplerConfig2'}) # NOT SUPPORTED YET! + # TODO: not supported yet + # mylist_without_default: List[SimplerConfig2] = field(default=None) # noqa: ERA001 def check_values(self) -> None: """Check config fields""" diff --git a/tests/test_parse_known_argparse.py b/tests/test_parse_known_argparse.py index 2208354..a1b3218 100644 --- a/tests/test_parse_known_argparse.py +++ b/tests/test_parse_known_argparse.py @@ -62,9 +62,9 @@ def test_parse_argparse() -> None: def test_parse_edited_argparse() -> None: - """calling `parse_known_argparse` after some modifications in the config values. - `parse_known_argparse` should keep the modified values if not defined in argv""" - + """Calling `parse_known_argparse` after some modifications in the config values. + `parse_known_argparse` should keep the modified values if not defined in argv + """ unknown_args = ["--coqpit.arg_does_not_exist", "111"] args = [] args.extend(["--coqpit.mylist_with_default.1.val_a", "111"]) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index f6bef15..15a5041 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -30,7 +30,7 @@ class Reference(Coqpit): Person(name="Eren", age=11), Person(name="Geren", age=12), Person(name="Ceren", age=15), - ] + ], ) some_dict: dict[str, Optional[int]] = field(default_factory=lambda: {"a": 1, "b": 2, "c": None}) diff --git a/tests/test_simple_config.py b/tests/test_simple_config.py index cf45257..ec84ec7 100644 --- a/tests/test_simple_config.py +++ b/tests/test_simple_config.py @@ -1,5 +1,5 @@ -import os from dataclasses import asdict, dataclass, field +from pathlib import Path from typing import Any, Optional, Union from coqpit.coqpit import MISSING, Coqpit, check_argument @@ -20,7 +20,7 @@ class SimpleConfig(Coqpit): # list of list val_listoflist: list[list[int]] = field(default_factory=lambda: [[1, 2], [3, 4]]) val_listofunion: list[list[Union[str, int, bool]]] = field( - default_factory=lambda: [[1, 3], [1, "Hi!"], [True, False]] + default_factory=lambda: [[1, 3], [1, "Hi!"], [True, False]], ) def check_values( @@ -34,7 +34,7 @@ def check_values( def test_simple_config() -> None: - file_path = os.path.dirname(os.path.abspath(__file__)) + file_path = Path(__file__).resolve().parent / "example_config.json" config = SimpleConfig() # try MISSING class argument @@ -47,8 +47,8 @@ def test_simple_config() -> None: # try serialization and deserialization print(config.serialize()) print(config.to_json()) - config.save_json(os.path.join(file_path, "example_config.json")) - config.load_json(os.path.join(file_path, "example_config.json")) + config.save_json(file_path) + config.load_json(file_path) config.pprint() # try `dict` interface From 44128eca7921a02d4cb374f6633d709a6a0f5b83 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 18:49:58 +0200 Subject: [PATCH 24/27] test: improve coverage --- .pre-commit-config.yaml | 2 ++ coqpit/coqpit.py | 2 +- pyproject.toml | 1 + tests/test_init_from_dict.py | 2 +- tests/test_serialization.py | 9 ++++++++- tests/test_simple_config.py | 5 +++++ 6 files changed, 18 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3d8409e..5d5f10e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,3 +17,5 @@ repos: hooks: - id: mypy args: [--strict] + additional_dependencies: + - "pytest" diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 2a627e5..73d34f6 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -201,7 +201,7 @@ def _deserialize_dict(x: dict[Any, Any]) -> dict[Any, Any]: return out_dict -def _deserialize_list(x: list[_T], field_type: FieldType) -> list[_T]: +def _deserialize_list(x: list[Any], field_type: FieldType) -> list[Any]: """Deserialize values for List typed fields. Args: diff --git a/pyproject.toml b/pyproject.toml index 709c60c..a29a727 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,5 +75,6 @@ convention = "google" "FA100", "PLR2004", "S101", + "SLF001", "T201", ] diff --git a/tests/test_init_from_dict.py b/tests/test_init_from_dict.py index 616d198..1e0a735 100644 --- a/tests/test_init_from_dict.py +++ b/tests/test_init_from_dict.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import Optional -import pytest # type: ignore[import-not-found] +import pytest from coqpit import Coqpit diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 15a5041..557d261 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Optional -from coqpit import Coqpit +from coqpit.coqpit import Coqpit, _deserialize_list @dataclass @@ -59,3 +59,10 @@ def test_serialization() -> None: assert ref_config.some_dict["a"] == new_config.some_dict["a"] assert ref_config.some_dict["b"] == new_config.some_dict["b"] assert ref_config.some_dict["c"] == new_config.some_dict["c"] + + +def test_deserialize_list() -> None: + assert _deserialize_list([1, 2, 3], list) == [1, 2, 3] + assert _deserialize_list([1, 2, 3], list[int]) == [1, 2, 3] + assert _deserialize_list([1, 2, 3], list[float]) == [1.0, 2.0, 3.0] + assert _deserialize_list([1, 2, 3], list[str]) == ["1", "2", "3"] diff --git a/tests/test_simple_config.py b/tests/test_simple_config.py index ec84ec7..cc2d17a 100644 --- a/tests/test_simple_config.py +++ b/tests/test_simple_config.py @@ -37,6 +37,8 @@ def test_simple_config() -> None: file_path = Path(__file__).resolve().parent / "example_config.json" config = SimpleConfig() + assert config._is_initialized() + # try MISSING class argument try: _ = config.val_k @@ -44,6 +46,9 @@ def test_simple_config() -> None: print(" val_k needs a different value before accessing it.") config.val_k = 1000 + assert "val_a" in config + assert config.has("val_a") + # try serialization and deserialization print(config.serialize()) print(config.to_json()) From 61066c9e48afbbd322b7b2ba2ed396595550366f Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 18:58:40 +0200 Subject: [PATCH 25/27] chore: update pypi package name --- .github/workflows/pypi-release.yml | 2 +- README.md | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index cd46b55..8b45c1b 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -43,7 +43,7 @@ jobs: needs: [build] environment: name: release - url: https://pypi.org/p/coqui-tts-coqpit + url: https://pypi.org/p/coqpit-config permissions: id-token: write steps: diff --git a/README.md b/README.md index 762ab5f..4b0dea8 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,12 @@ [![CI](https://github.com/idiap/coqui-ai-coqpit/actions/workflows/main.yml/badge.svg?branch=main)](https://github.com/idiap/coqui-ai-coqpit/actions/workflows/main.yml) -Simple, light-weight and no dependency config handling through python data classes with to/from JSON serialization/deserialization. +Simple, light-weight and no dependency config handling through python data +classes with to/from JSON serialization/deserialization. -Currently it is being used by [🐸TTS](https://github.com/idiap/coqui-ai-TTS). +Fork of the [original, unmaintained repository](https://github.com/coqui-ai/coqpit). New PyPI package: [coqpit-config](https://pypi.org/project/coqpit-config) + +Currently it is being used by [coqui-tts](https://github.com/idiap/coqui-ai-TTS). ## ❔ Why I need this What I need from a ML configuration library... From 39adee3433aa1c2ef5a91551175dbf07e07521d5 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 19 Oct 2024 18:58:54 +0200 Subject: [PATCH 26/27] chore: bump version to 0.1.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a29a727..5bd45d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "coqpit" -version = "0.0.17" +version = "0.1.0" description = "Simple (maybe too simple), light-weight config management through python data-classes." readme = "README.md" requires-python = ">=3.9" From 8fbbc9ddd8718ced5bb20bcd5d7afcbd2d655d8b Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 25 Oct 2024 21:53:32 +0200 Subject: [PATCH 27/27] ci: update uv version, switch to pep 735 standard --- .github/workflows/main.yml | 2 +- .github/workflows/pypi-release.yml | 2 +- pyproject.toml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3614ed4..ed1e0ee 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -19,7 +19,7 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v3 with: - version: "0.4.21" + version: "0.4.27" enable-cache: true cache-dependency-glob: "**/pyproject.toml" - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index 8b45c1b..29e9038 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -22,7 +22,7 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v3 with: - version: "0.4.21" + version: "0.4.27" enable-cache: true cache-dependency-glob: "**/pyproject.toml" - name: Set up Python diff --git a/pyproject.toml b/pyproject.toml index 5bd45d4..3cadc39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,8 +33,8 @@ dependencies = [ "typing_extensions>=4.10", ] -[tool.uv] -dev-dependencies = [ +[dependency-groups] +dev = [ "coverage>=7", "mypy>=1.12.0", "pre-commit>=3",