Skip to content

Commit

Permalink
fix: adapt type hints to Python 3.9
Browse files Browse the repository at this point in the history
Mostly NoneType and UnionType are only available from 3.10
  • Loading branch information
eginhard committed Oct 19, 2024
1 parent 2c6929e commit c9ef025
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions coqpit/coqpit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace
from pathlib import Path
from pprint import pprint
from types import GenericAlias, NoneType, UnionType
from types import GenericAlias
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union, overload

from typing_extensions import Self, TypeAlias, TypeGuard, TypeIs

# TODO: Available from Python 3.10
with contextlib.suppress(ImportError):
from types import UnionType

if TYPE_CHECKING: # pragma: no cover
from dataclasses import _MISSING_TYPE

Expand All @@ -32,7 +36,7 @@ class _NoDefault(Generic[_T]):
NoDefaultVar: TypeAlias = Union[_NoDefault[_T], _T]
no_default: NoDefaultVar[Any] = _NoDefault()

FieldType: TypeAlias = str | type | UnionType
FieldType: TypeAlias = Union[str, type, "UnionType"]


def _is_primitive_type(field_type: FieldType) -> TypeGuard[type]:
Expand Down Expand Up @@ -81,7 +85,10 @@ def _is_union(field_type: FieldType) -> TypeIs[UnionType]:
bool: True if input type is `Union`
"""
origin = typing.get_origin(field_type)
return origin is Union or origin is UnionType
is_union = origin is Union
with contextlib.suppress(NameError):
is_union = is_union or origin is UnionType
return is_union


def _is_union_and_not_simple_optional(field_type: FieldType) -> TypeGuard[UnionType]:
Expand All @@ -98,7 +105,7 @@ def _is_union_and_not_simple_optional(field_type: FieldType) -> TypeGuard[UnionT
"""
args = typing.get_args(field_type)
is_python_union = _is_union(field_type)
if is_python_union and len(args) == 2 and NoneType in args:
if is_python_union and len(args) == 2 and type(None) in args:
# This is an Optional type like `int | None`
return False
return is_python_union
Expand Down Expand Up @@ -129,7 +136,7 @@ def _is_optional_field(field_type: FieldType) -> TypeGuard[UnionType]:
Returns:
bool: True if the input field type is optional.
"""
return NoneType in typing.get_args(field_type)
return type(None) in typing.get_args(field_type)


def _drop_none_type(field_type: FieldType) -> FieldType:
Expand All @@ -142,11 +149,11 @@ def _drop_none_type(field_type: FieldType) -> FieldType:
return field_type
origin = typing.get_origin(field_type)
args = list(typing.get_args(field_type))
if NoneType in args:
args.remove(NoneType)
if type(None) in args:
args.remove(type(None))
if len(args) == 1:
return typing.cast(type, args[0])
return typing.cast(UnionType, GenericAlias(origin, args))
return typing.cast("UnionType", GenericAlias(origin, args))


def _serialize(x: Any) -> Any:
Expand Down Expand Up @@ -257,7 +264,7 @@ def _deserialize_primitive_types(
base_type = _drop_none_type(field_type)
if base_type is not float and base_type is not int and base_type is not str and base_type is not bool:
raise TypeError
base_type = typing.cast(type[int | float | str | bool], base_type)
base_type = typing.cast(type[Union[int, float, str, bool]], base_type)
if x == float("inf") or x == float("-inf"):
# if value type is inf return regardless.
return x
Expand Down Expand Up @@ -297,7 +304,7 @@ def _deserialize(x: Any, field_type: FieldType) -> Any:

CoqpitType: TypeAlias = MutableMapping[str, "CoqpitNestedValue"]
CoqpitNestedValue: TypeAlias = Union["CoqpitValue", CoqpitType]
CoqpitValue: TypeAlias = str | int | float | bool | None
CoqpitValue: TypeAlias = Union[str, int, float, bool, None]


# TODO: It should be possible to get rid of the next 3 `type: ignore`. At
Expand Down

0 comments on commit c9ef025

Please sign in to comment.