diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index aa47ab0..e6b6633 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -67,6 +67,7 @@ class User: import typing_inspect from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute +from typing_extensions import get_args if sys.version_info >= (3, 11): @@ -482,18 +483,36 @@ def _field_by_supertype( def _generic_type_add_any(typ: type) -> type: """if typ is generic type without arguments, replace them by Any.""" - if typ is list or typ is List: - typ = List[Any] - elif typ is dict or typ is Dict: - typ = Dict[Any, Any] - elif typ is Mapping: - typ = Mapping[Any, Any] - elif typ is Sequence: - typ = Sequence[Any] - elif typ is set or typ is Set: - typ = Set[Any] - elif typ is frozenset or typ is FrozenSet: - typ = FrozenSet[Any] + if get_args(typ): + # there is arguments in the generic type, no need to check it + return typ + if sys.version_info >= (3, 9): + # supports the collections.abc generic type + if issubclass(typ, List): + typ = List[Any] + elif issubclass(typ, Dict): + typ = Dict[Any, Any] + elif issubclass(typ, Mapping): + typ = Mapping[Any, Any] + elif issubclass(typ, Sequence): + typ = Sequence[Any] + elif issubclass(typ, Set): + typ = Set[Any] + elif issubclass(typ, FrozenSet): + typ = FrozenSet[Any] + else: + if typ is list or typ is List: + typ = List[Any] + elif typ is dict or typ is Dict: + typ = Dict[Any, Any] + elif typ is Mapping: + typ = Mapping[Any, Any] + elif typ is Sequence: + typ = Sequence[Any] + elif typ is set or typ is Set: + typ = Set[Any] + elif typ is frozenset or typ is FrozenSet: + typ = FrozenSet[Any] return typ diff --git a/tests/test_collection.py b/tests/test_collection.py index 594e8cb..1b876e6 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -322,6 +322,39 @@ class AnyMapping: ) self.assertEqual(schema.dump(loaded), data_in) + @unittest.skipIf(sys.version_info < (3, 9), "PEP 585 unsupported") + def test_collections_mapping_no_arg(self): + import collections.abc + + @dataclass + class AnyMapping: + value: collections.abc.Mapping + + schema = AnyMapping.Schema() + + # can load a sequence of mixed kind + data_in = { + "value": { + 1: "a number key a str value", + "a str key a number value": 2, + None: "this is still valid", + "even this": None, + } + } + loaded = schema.load(data_in) + self.assertEqual( + loaded, + AnyMapping( + value={ + 1: "a number key a str value", + "a str key a number value": 2, + None: "this is still valid", + "even this": None, + } + ), + ) + self.assertEqual(schema.dump(loaded), data_in) + def test_mapping_of_frozen_dataclass(self): @dataclass(frozen=True) class Elm: