diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index e6cc862..90b6f7e 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -47,4 +47,4 @@ jobs: if: ${{ matrix.python_version != 'pypy3' && matrix.python_version != '3.6' }} run: pre-commit run --all-files - name: Test with pytest - run: pytest + run: pytest -ra diff --git a/tests/test_field_for_schema.py b/tests/test_field_for_schema.py index 22fe440..9d47786 100644 --- a/tests/test_field_for_schema.py +++ b/tests/test_field_for_schema.py @@ -1,3 +1,4 @@ +import copy import inspect import sys import typing @@ -26,18 +27,11 @@ class TestFieldForSchema(unittest.TestCase): def assertFieldsEqual(self, a: fields.Field, b: fields.Field): self.assertEqual(a.__class__, b.__class__, "field class") - def canonical(k, v): - if k == "union_fields": - # See https://github.com/lovasoa/marshmallow_dataclass/pull/246#issuecomment-1722291806 - return k, sorted(map(repr, v)) - elif inspect.isclass(v): - return k, f"{v!r} ({v.__mro__!r})" - else: - return k, repr(v) - def attrs(x): return sorted( - canonical(k, v) for k, v in vars(x).items() if not k.startswith("_") + (k, f"{v!r} ({v.__mro__!r})" if inspect.isclass(v) else repr(v)) + for k, v in x.__dict__.items() + if not k.startswith("_") ) self.assertEqual(attrs(a), attrs(b)) @@ -189,7 +183,11 @@ def test_union_multiple_types_with_none(self): ), ) + @unittest.expectedFailure def test_optional_multiple_types(self): + # excercise bug (see #247) + Optional[Union[str, int]] + self.assertFieldsEqual( field_for_schema(Optional[Union[int, str]]), union_field.Union( @@ -203,6 +201,26 @@ def test_optional_multiple_types(self): ), ) + def test_optional_multiple_types_ignoring_union_field_order(self): + # see https://github.com/lovasoa/marshmallow_dataclass/pull/246#issuecomment-1722204048 + result = field_for_schema(Optional[Union[int, str]]) + expected = union_field.Union( + [ + (int, fields.Integer(required=True)), + (str, fields.String(required=True)), + ], + required=False, + dump_default=None, + load_default=None, + ) + + def sort_union_fields(field): + rv = copy.copy(field) + rv.union_fields = sorted(field.union_fields, key=repr) + return rv + + self.assertFieldsEqual(sort_union_fields(result), sort_union_fields(expected)) + def test_newtype(self): self.assertFieldsEqual( field_for_schema(typing.NewType("UserId", int), default=0),