Skip to content

Commit

Permalink
fixup! fix _generic_type_add_any
Browse files Browse the repository at this point in the history
  • Loading branch information
anis-campos committed Sep 11, 2023
1 parent 1b4fb02 commit 9cb5cae
Showing 1 changed file with 35 additions and 21 deletions.
56 changes: 35 additions & 21 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class User:

from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute


if sys.version_info >= (3, 11):
from typing import dataclass_transform
elif sys.version_info >= (3, 7):
Expand Down Expand Up @@ -480,20 +479,35 @@ def _field_by_supertype(
)


COLLECTIONS_TYPES: Dict[str, Set[Type]] = {
"list": {list, List},
"dict": {dict, Dict},
"mapping": {Mapping, collections.abc.Mapping},
"sequence": {Sequence, collections.abc.Sequence},
"set": {set, Set, collections.abc.Set},
"frozenset": {frozenset, FrozenSet},
"tuple": {tuple, Tuple},
}


def _generic_type_add_any(typ: type) -> type:
"""if typ is generic type without arguments, replace them by Any."""
if typ in (list, List):
typ = List[Any]
elif typ in (dict, Dict):
typ = Dict[Any, Any]
elif typ in (Mapping, collections.abc.Mapping):
typ = Mapping[Any, Any]
elif typ in (Sequence, collections.abc.Sequence):
typ = Sequence[Any]
elif typ in (set, Set, collections.abc.Set):
typ = Set[Any]
elif typ in (frozenset, FrozenSet):
typ = FrozenSet[Any]
args = typing_inspect.get_args(typ)
if not args or any(typing_inspect.is_typevar(arg) for arg in args):
if typ in COLLECTIONS_TYPES["list"]:
typ = List[Any]
elif typ in COLLECTIONS_TYPES["dict"]:
typ = Dict[Any, Any]
elif typ in COLLECTIONS_TYPES["mapping"]:
typ = Mapping[Any, Any]
elif typ in COLLECTIONS_TYPES["sequence"]:
typ = Sequence[Any]
elif typ in COLLECTIONS_TYPES["set"]:
typ = Set[Any]
elif typ in COLLECTIONS_TYPES["frozenset"]:
typ = FrozenSet[Any]
elif typ in COLLECTIONS_TYPES["tuple"]:
typ = Tuple[Any, ...]
return typ


Expand All @@ -512,7 +526,7 @@ def _field_for_generic_type(
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}

if origin in (list, List):
if origin in COLLECTIONS_TYPES["list"]:
child_type = field_for_schema(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
)
Expand All @@ -521,8 +535,8 @@ def _field_for_generic_type(
type_mapping.get(List, marshmallow.fields.List),
)
return list_type(child_type, **metadata)
if origin in (collections.abc.Sequence, Sequence) or (
origin in (tuple, Tuple)
if origin in COLLECTIONS_TYPES["sequence"] or (
origin in COLLECTIONS_TYPES["tuple"]
and len(arguments) == 2
and arguments[1] is Ellipsis
):
Expand All @@ -532,7 +546,7 @@ def _field_for_generic_type(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
)
return collection_field.Sequence(cls_or_instance=child_type, **metadata)
if origin in (set, Set):
if origin in COLLECTIONS_TYPES["set"]:
from . import collection_field

child_type = field_for_schema(
Expand All @@ -541,7 +555,7 @@ def _field_for_generic_type(
return collection_field.Set(
cls_or_instance=child_type, frozen=False, **metadata
)
if origin in (frozenset, FrozenSet):
if origin in COLLECTIONS_TYPES["frozenset"]:
from . import collection_field

child_type = field_for_schema(
Expand All @@ -550,7 +564,7 @@ def _field_for_generic_type(
return collection_field.Set(
cls_or_instance=child_type, frozen=True, **metadata
)
if origin in (tuple, Tuple):
if origin in COLLECTIONS_TYPES["tuple"]:
children = tuple(
field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame)
for arg in arguments
Expand All @@ -562,7 +576,7 @@ def _field_for_generic_type(
),
)
return tuple_type(children, **metadata)
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
elif origin in COLLECTIONS_TYPES["dict"] | COLLECTIONS_TYPES["mapping"]:
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
return dict_type(
keys=field_for_schema(
Expand Down Expand Up @@ -727,7 +741,7 @@ def field_for_schema(
)

# enumerations
if issubclass(typ, Enum):
if isinstance(typ, type) and issubclass(typ, Enum):
try:
return marshmallow.fields.Enum(typ, **metadata)
except AttributeError:
Expand Down

0 comments on commit 9cb5cae

Please sign in to comment.