Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Annotated support #257

Merged
merged 12 commits into from
Jun 23, 2024
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ venv.bak/
# Rope project settings
.ropeproject

# VSCode project settings
.vscode

# mkdocs documentation
/site

Expand Down
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ repos:
rev: v3.3.1
hooks:
- id: pyupgrade
args: ["--py36-plus"]
# I've kept it on py3.7 so that it doesn't replace `Dict` with `dict`
args: ["--py37-plus"]
- repo: https://github.com/python/black
rev: 23.1.0
hooks:
Expand All @@ -19,7 +20,7 @@ repos:
rev: v1.1.1
hooks:
- id: mypy
additional_dependencies: [marshmallow-enum,typeguard,marshmallow]
additional_dependencies: [typeguard,marshmallow]
args: [--show-error-codes]
- repo: https://github.com/asottile/blacken-docs
rev: 1.13.0
Expand Down
1 change: 1 addition & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ Every commit is checked with pre-commit hooks for :
- type safety with [mypy](http://mypy-lang.org/)
- test conformance by running [tests](./tests) with [pytest](https://docs.pytest.org/en/latest/)
- You can run `pytest` from the command line.

48 changes: 41 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,47 @@ class Sample:

See [marshmallow's documentation about extending `Schema`](https://marshmallow.readthedocs.io/en/stable/extending.html).

### Custom NewType declarations
mvanderlee marked this conversation as resolved.
Show resolved Hide resolved
### Custom type aliases

This library allows you to specify [customized marshmallow fields](https://marshmallow.readthedocs.io/en/stable/custom_fields.html#creating-a-field-class) using python's Annoted type [PEP-593](https://peps.python.org/pep-0593/).

```python
from typing import Annotated
mvanderlee marked this conversation as resolved.
Show resolved Hide resolved
import marshmallow.fields as mf
import marshmallow.validate as mv

IPv4 = Annotated[str, mf.String(validate=mv.Regexp(r"^([0-9]{1,3}\\.){3}[0-9]{1,3}$"))]
```

You can also pass a marshmallow field class.

```python
from typing import Annotated
import marshmallow
from marshmallow_dataclass import NewType

Email = Annotated[str, marshmallow.fields.Email]
```

For convenience, some custom types are provided:

```python
from marshmallow_dataclass.typing import Email, Url
```

When using Python 3.8, you must import `Annotated` from the typing_extensions package

```python
# Version agnostic import code:
if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
```

### Custom NewType declarations [__deprecated__]

> NewType is deprecated in favor or type aliases using Annotated, as described above.

This library exports a `NewType` function to create types that generate [customized marshmallow fields](https://marshmallow.readthedocs.io/en/stable/custom_fields.html#creating-a-field-class).

Expand All @@ -266,12 +306,6 @@ from marshmallow_dataclass import NewType
Email = NewType("Email", str, field=marshmallow.fields.Email)
```

For convenience, some custom types are provided:

```python
from marshmallow_dataclass.typing import Email, Url
```

Note: if you are using `mypy`, you will notice that `mypy` throws an error if a variable defined with
`NewType` is used in a type annotation. To resolve this, add the `marshmallow_dataclass.mypy` plugin
to your `mypy` configuration, e.g.:
Expand Down
108 changes: 86 additions & 22 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class User:
})
Schema: ClassVar[Type[Schema]] = Schema # For the type checker
"""

import collections.abc
import dataclasses
import inspect
Expand All @@ -47,11 +48,13 @@ class User:
Any,
Callable,
Dict,
FrozenSet,
Generic,
List,
Mapping,
NewType as typing_NewType,
Optional,
Sequence,
Set,
Tuple,
Type,
Expand All @@ -60,24 +63,23 @@ class User:
cast,
get_type_hints,
overload,
Sequence,
FrozenSet,
)

import marshmallow
import typing_extensions
import typing_inspect

from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated

if sys.version_info >= (3, 11):
from typing import dataclass_transform
elif sys.version_info >= (3, 7):
from typing_extensions import dataclass_transform
else:
# @dataclass_transform() only helps us with mypy>=1.1 which is only available for python>=3.7
def dataclass_transform(**kwargs):
return lambda cls: cls
from typing_extensions import dataclass_transform


__all__ = ["dataclass", "add_schema", "class_schema", "field_for_schema", "NewType"]
Expand Down Expand Up @@ -511,7 +513,15 @@ def _internal_class_schema(
base_schema: Optional[Type[marshmallow.Schema]] = None,
) -> Type[marshmallow.Schema]:
schema_ctx = _schema_ctx_stack.top
schema_ctx.seen_classes[clazz] = clazz.__name__

if typing_extensions.get_origin(clazz) is Annotated and sys.version_info < (3, 10):
# https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977
class_name = clazz._name or clazz.__origin__.__name__ # type: ignore[attr-defined]
else:
class_name = clazz.__name__

schema_ctx.seen_classes[clazz] = class_name

try:
# noinspection PyDataclass
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
Expand Down Expand Up @@ -546,9 +556,18 @@ def _internal_class_schema(
include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False)

# Update the schema members to contain marshmallow fields instead of dataclass fields
type_hints = get_type_hints(
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
)

if sys.version_info >= (3, 9):
type_hints = get_type_hints(
clazz,
globalns=schema_ctx.globalns,
localns=schema_ctx.localns,
include_extras=True,
)
else:
type_hints = get_type_hints(
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
)
attributes.update(
(
field.name,
Expand Down Expand Up @@ -639,8 +658,8 @@ def _field_for_generic_type(
"""
If the type is a generic interface, resolve the arguments and construct the appropriate Field.
"""
origin = typing_inspect.get_origin(typ)
arguments = typing_inspect.get_args(typ, True)
origin = typing_extensions.get_origin(typ)
arguments = typing_extensions.get_args(typ)
if origin:
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
Expand Down Expand Up @@ -694,6 +713,46 @@ def _field_for_generic_type(
**metadata,
)

return None


def _field_for_annotated_type(
typ: type,
**metadata: Any,
) -> Optional[marshmallow.fields.Field]:
"""
If the type is an Annotated interface, resolve the arguments and construct the appropriate Field.
"""
origin = typing_extensions.get_origin(typ)
arguments = typing_extensions.get_args(typ)
if origin and origin is Annotated:
marshmallow_annotations = [
arg
for arg in arguments[1:]
if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field))
or isinstance(arg, marshmallow.fields.Field)
]
if marshmallow_annotations:
if len(marshmallow_annotations) > 1:
warnings.warn(
"Multiple marshmallow Field annotations found. Using the last one."
)

field = marshmallow_annotations[-1]
# Got a field instance, return as is. User must know what they're doing
if isinstance(field, marshmallow.fields.Field):
return field

return field(**metadata)
return None


def _field_for_union_type(
typ: type,
base_schema: Optional[Type[marshmallow.Schema]],
**metadata: Any,
) -> Optional[marshmallow.fields.Field]:
arguments = typing_extensions.get_args(typ)
if typing_inspect.is_union_type(typ):
if typing_inspect.is_optional_type(typ):
metadata["allow_none"] = metadata.get("allow_none", True)
Expand Down Expand Up @@ -806,6 +865,7 @@ def _field_for_schema(
metadata.setdefault("allow_none", True)
return marshmallow.fields.Raw(**metadata)

# i.e.: Literal['abc']
if typing_inspect.is_literal_type(typ):
arguments = typing_inspect.get_args(typ)
return marshmallow.fields.Raw(
Expand All @@ -817,6 +877,7 @@ def _field_for_schema(
**metadata,
)

# i.e.: Final[str] = 'abc'
if typing_inspect.is_final_type(typ):
arguments = typing_inspect.get_args(typ)
if arguments:
Expand Down Expand Up @@ -851,6 +912,14 @@ def _field_for_schema(
subtyp = Any
return _field_for_schema(subtyp, default, metadata, base_schema)

annotated_field = _field_for_annotated_type(typ, **metadata)
if annotated_field:
return annotated_field

union_field = _field_for_union_type(typ, base_schema, **metadata)
if union_field:
return union_field

# Generic types
generic_field = _field_for_generic_type(typ, base_schema, **metadata)
if generic_field:
Expand All @@ -869,14 +938,8 @@ def _field_for_schema(
)

# enumerations
if issubclass(typ, Enum):
try:
return marshmallow.fields.Enum(typ, **metadata)
except AttributeError:
# Remove this once support for python 3.6 is dropped.
import marshmallow_enum

return marshmallow_enum.EnumField(typ, **metadata)
if inspect.isclass(typ) and issubclass(typ, Enum):
return marshmallow.fields.Enum(typ, **metadata)

# Nested marshmallow dataclass
# it would be just a class name instead of actual schema util the schema is not ready yet
Expand Down Expand Up @@ -939,7 +1002,8 @@ def NewType(
field: Optional[Type[marshmallow.fields.Field]] = None,
**kwargs,
) -> Callable[[_U], _U]:
"""NewType creates simple unique types
"""DEPRECATED: Use typing.Annotated instead.
NewType creates simple unique types
to which you can attach custom marshmallow attributes.
All the keyword arguments passed to this function will be transmitted
to the marshmallow field constructor.
Expand Down
12 changes: 9 additions & 3 deletions marshmallow_dataclass/typing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import sys

import marshmallow.fields
from . import NewType

Url = NewType("Url", str, field=marshmallow.fields.Url)
Email = NewType("Email", str, field=marshmallow.fields.Email)
if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated

Url = Annotated[str, marshmallow.fields.Url]
Email = Annotated[str, marshmallow.fields.Email]

# Aliases
URL = Url
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,9 @@ target-version = ['py36', 'py37', 'py38', 'py39', 'py310', 'py310']
filterwarnings = [
"error:::marshmallow_dataclass|test",
]

[tool.coverage.report]
exclude_also = [
'^\s*\.\.\.\s*$',
'^\s*pass\s*$',
]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from setuptools import setup, find_packages
from setuptools import find_packages, setup

VERSION = "9.0.0"

Expand Down
37 changes: 37 additions & 0 deletions tests/test_annotated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import sys
import unittest
from typing import Optional

import marshmallow
import marshmallow.fields

from marshmallow_dataclass import dataclass

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated


class TestAnnotatedField(unittest.TestCase):
def test_annotated_field(self):
@dataclass
class AnnotatedValue:
value: Annotated[str, marshmallow.fields.Email]
default_string: Annotated[
Optional[str], marshmallow.fields.String(load_default="Default String")
] = None

schema = AnnotatedValue.Schema()

self.assertEqual(
schema.load({"value": "test@test.com"}),
AnnotatedValue(value="test@test.com", default_string="Default String"),
)
self.assertEqual(
schema.load({"value": "test@test.com", "default_string": "override"}),
AnnotatedValue(value="test@test.com", default_string="override"),
)

with self.assertRaises(marshmallow.exceptions.ValidationError):
schema.load({"value": "notavalidemail"})
Loading