diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 604ec16..d8258ce 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -11,11 +11,15 @@ from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace from pathlib import Path from pprint import pprint -from types import GenericAlias, NoneType, UnionType +from types import GenericAlias from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union, overload from typing_extensions import Self, TypeAlias, TypeGuard, TypeIs +# TODO: Available from Python 3.10 +with contextlib.suppress(ImportError): + from types import UnionType + if TYPE_CHECKING: # pragma: no cover from dataclasses import _MISSING_TYPE @@ -32,7 +36,7 @@ class _NoDefault(Generic[_T]): NoDefaultVar: TypeAlias = Union[_NoDefault[_T], _T] no_default: NoDefaultVar[Any] = _NoDefault() -FieldType: TypeAlias = str | type | UnionType +FieldType: TypeAlias = Union[str, type, "UnionType"] def _is_primitive_type(field_type: FieldType) -> TypeGuard[type]: @@ -81,7 +85,10 @@ def _is_union(field_type: FieldType) -> TypeIs[UnionType]: bool: True if input type is `Union` """ origin = typing.get_origin(field_type) - return origin is Union or origin is UnionType + is_union = origin is Union + with contextlib.suppress(NameError): + is_union = is_union or origin is UnionType + return is_union def _is_union_and_not_simple_optional(field_type: FieldType) -> TypeGuard[UnionType]: @@ -98,7 +105,7 @@ def _is_union_and_not_simple_optional(field_type: FieldType) -> TypeGuard[UnionT """ args = typing.get_args(field_type) is_python_union = _is_union(field_type) - if is_python_union and len(args) == 2 and NoneType in args: + if is_python_union and len(args) == 2 and type(None) in args: # This is an Optional type like `int | None` return False return is_python_union @@ -129,7 +136,7 @@ def _is_optional_field(field_type: FieldType) -> TypeGuard[UnionType]: Returns: bool: True if the input field type is optional. """ - return NoneType in typing.get_args(field_type) + return type(None) in typing.get_args(field_type) def _drop_none_type(field_type: FieldType) -> FieldType: @@ -142,11 +149,11 @@ def _drop_none_type(field_type: FieldType) -> FieldType: return field_type origin = typing.get_origin(field_type) args = list(typing.get_args(field_type)) - if NoneType in args: - args.remove(NoneType) + if type(None) in args: + args.remove(type(None)) if len(args) == 1: return typing.cast(type, args[0]) - return typing.cast(UnionType, GenericAlias(origin, args)) + return typing.cast("UnionType", GenericAlias(origin, args)) def _serialize(x: Any) -> Any: @@ -257,7 +264,7 @@ def _deserialize_primitive_types( base_type = _drop_none_type(field_type) if base_type is not float and base_type is not int and base_type is not str and base_type is not bool: raise TypeError - base_type = typing.cast(type[int | float | str | bool], base_type) + base_type = typing.cast(type[Union[int, float, str, bool]], base_type) if x == float("inf") or x == float("-inf"): # if value type is inf return regardless. return x @@ -297,7 +304,7 @@ def _deserialize(x: Any, field_type: FieldType) -> Any: CoqpitType: TypeAlias = MutableMapping[str, "CoqpitNestedValue"] CoqpitNestedValue: TypeAlias = Union["CoqpitValue", CoqpitType] -CoqpitValue: TypeAlias = str | int | float | bool | None +CoqpitValue: TypeAlias = Union[str, int, float, bool, None] # TODO: It should be possible to get rid of the next 3 `type: ignore`. At