From 9cb5caec5ea7bebc5276426d2434cb6cb8d0f63d Mon Sep 17 00:00:00 2001 From: Anis Da Silva Campos Date: Wed, 6 Sep 2023 09:25:23 +0200 Subject: [PATCH] fixup! fix _generic_type_add_any --- marshmallow_dataclass/__init__.py | 56 +++++++++++++++++++------------ 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 7ad8da5..ac3d251 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -68,7 +68,6 @@ class User: from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute - if sys.version_info >= (3, 11): from typing import dataclass_transform elif sys.version_info >= (3, 7): @@ -480,20 +479,35 @@ def _field_by_supertype( ) +COLLECTIONS_TYPES: Dict[str, Set[Type]] = { + "list": {list, List}, + "dict": {dict, Dict}, + "mapping": {Mapping, collections.abc.Mapping}, + "sequence": {Sequence, collections.abc.Sequence}, + "set": {set, Set, collections.abc.Set}, + "frozenset": {frozenset, FrozenSet}, + "tuple": {tuple, Tuple}, +} + + def _generic_type_add_any(typ: type) -> type: """if typ is generic type without arguments, replace them by Any.""" - if typ in (list, List): - typ = List[Any] - elif typ in (dict, Dict): - typ = Dict[Any, Any] - elif typ in (Mapping, collections.abc.Mapping): - typ = Mapping[Any, Any] - elif typ in (Sequence, collections.abc.Sequence): - typ = Sequence[Any] - elif typ in (set, Set, collections.abc.Set): - typ = Set[Any] - elif typ in (frozenset, FrozenSet): - typ = FrozenSet[Any] + args = typing_inspect.get_args(typ) + if not args or any(typing_inspect.is_typevar(arg) for arg in args): + if typ in COLLECTIONS_TYPES["list"]: + typ = List[Any] + elif typ in COLLECTIONS_TYPES["dict"]: + typ = Dict[Any, Any] + elif typ in COLLECTIONS_TYPES["mapping"]: + typ = Mapping[Any, Any] + elif typ in COLLECTIONS_TYPES["sequence"]: + typ = Sequence[Any] + elif typ in COLLECTIONS_TYPES["set"]: + typ = Set[Any] + elif typ in COLLECTIONS_TYPES["frozenset"]: + typ = FrozenSet[Any] + elif typ in COLLECTIONS_TYPES["tuple"]: + typ = Tuple[Any, ...] return typ @@ -512,7 +526,7 @@ def _field_for_generic_type( # Override base_schema.TYPE_MAPPING to change the class used for generic types below type_mapping = base_schema.TYPE_MAPPING if base_schema else {} - if origin in (list, List): + if origin in COLLECTIONS_TYPES["list"]: child_type = field_for_schema( arguments[0], base_schema=base_schema, typ_frame=typ_frame ) @@ -521,8 +535,8 @@ def _field_for_generic_type( type_mapping.get(List, marshmallow.fields.List), ) return list_type(child_type, **metadata) - if origin in (collections.abc.Sequence, Sequence) or ( - origin in (tuple, Tuple) + if origin in COLLECTIONS_TYPES["sequence"] or ( + origin in COLLECTIONS_TYPES["tuple"] and len(arguments) == 2 and arguments[1] is Ellipsis ): @@ -532,7 +546,7 @@ def _field_for_generic_type( arguments[0], base_schema=base_schema, typ_frame=typ_frame ) return collection_field.Sequence(cls_or_instance=child_type, **metadata) - if origin in (set, Set): + if origin in COLLECTIONS_TYPES["set"]: from . import collection_field child_type = field_for_schema( @@ -541,7 +555,7 @@ def _field_for_generic_type( return collection_field.Set( cls_or_instance=child_type, frozen=False, **metadata ) - if origin in (frozenset, FrozenSet): + if origin in COLLECTIONS_TYPES["frozenset"]: from . import collection_field child_type = field_for_schema( @@ -550,7 +564,7 @@ def _field_for_generic_type( return collection_field.Set( cls_or_instance=child_type, frozen=True, **metadata ) - if origin in (tuple, Tuple): + if origin in COLLECTIONS_TYPES["tuple"]: children = tuple( field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame) for arg in arguments @@ -562,7 +576,7 @@ def _field_for_generic_type( ), ) return tuple_type(children, **metadata) - elif origin in (dict, Dict, collections.abc.Mapping, Mapping): + elif origin in COLLECTIONS_TYPES["dict"] | COLLECTIONS_TYPES["mapping"]: dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( keys=field_for_schema( @@ -727,7 +741,7 @@ def field_for_schema( ) # enumerations - if issubclass(typ, Enum): + if isinstance(typ, type) and issubclass(typ, Enum): try: return marshmallow.fields.Enum(typ, **metadata) except AttributeError: