diff --git a/overrides/enforce.py b/overrides/enforce.py index e40fd5b..0dbb5af 100644 --- a/overrides/enforce.py +++ b/overrides/enforce.py @@ -80,7 +80,7 @@ def is_param_defined_in_sub(name, sub_has_var_args, sub_has_var_kwargs, sub_sig, name in sub_sig.parameters or (super_param.kind == Parameter.VAR_POSITIONAL and sub_has_var_args) or (super_param.kind == Parameter.VAR_KEYWORD and sub_has_var_kwargs) - or (super_param.kind == Parameter.POSITIONAL_ONLY and not sub_has_var_args) + or (super_param.kind == Parameter.POSITIONAL_ONLY and sub_has_var_args) or ( super_param.kind == Parameter.POSITIONAL_OR_KEYWORD and sub_has_var_args diff --git a/tests/test_enforce.py b/tests/test_enforce.py index 48a1d3b..886d63f 100644 --- a/tests/test_enforce.py +++ b/tests/test_enforce.py @@ -1,4 +1,3 @@ -import sys import unittest from typing import Union, Optional @@ -116,15 +115,14 @@ class SubClass(MetaClassMethodOverrider): def register(self): pass - if sys.version_info[0] > 3 and sys.version_info[1] > 7: - def test_ensure_compatible_when_compatible(self): - def sup(a, /, b: str, c: int, *, d, e, **kwargs) -> object: - pass + def test_ensure_compatible_when_compatible(self): + def sup(a, /, b: str, c: int, *, d, e, **kwargs) -> object: + pass - def sub(a, b: object, c, d, f: str = "foo", *args, g: str = "bar", e, **kwargs) -> str: - pass + def sub(a, b: object, c, d, f: str = "foo", *args, g: str = "bar", e, **kwargs) -> str: + pass - ensure_compatible(sup, sub) + ensure_compatible(sup, sub) def test_ensure_compatible_when_return_types_are_incompatible(self): def sup(x) -> int: @@ -219,35 +217,75 @@ def generic_method(*args, **kwargs): ensure_compatible(generic_method, better_typed_method) def test_if_super_has_args_then_sub_must_have(self): - def sub1(x, y, z, /): + def sub1(x=2, y=3, z=4, /): pass - def subbest(x, /, *burgs): + def subbest(x=1, /, *burgs): pass def supah(*args): pass + # supah() => subbest() + # supah(2) => subbest(2) + # supah(2,3) => subbest(2,3) + # supah(*args) => subbest(*args) ensure_compatible(supah, subbest) + + # sub1(1,2,3) => subbest(1,2,3) + # sub1() => subbest() ensure_compatible(sub1, subbest) + with self.assertRaises(TypeError): + # supah() => sub1() ok + # supah(2) => sub1(2) ok + # supah(1,2,3,4) => sub1() takes from 0 to 3 positional arguments but 4 were given ensure_compatible(supah, sub1) + with self.assertRaises(TypeError): + # subbest() => sub1() ok + # subbest(1,2,3) => sub1(1,2,3) ok + # subbest(1,2,3,4) => sub1() takes from 0 to 3 positional arguments but 4 were given ensure_compatible(subbest, sub1) def test_if_super_has_kwargs_then_sub_must_have(self): def sub1(*, x=3, y=3, z=4): pass - def subbest(*, x=3, **kwargs): + def sus(*, x=3, **kwargs): pass - def supah(**kwargs): + def superb(**kwargs): pass - ensure_compatible(supah, subbest) - ensure_compatible(sub1, subbest) + # superb() => sus() + # superb(foo=1) => sus(foo=1) + # superb(x=4) => sus(x=4) + # superb(x=4, foo=1) => sus(x=4, foo=1) + # superb(**kwargs) => sus(**kwargs) + ensure_compatible(superb, sus) + ensure_compatible(sub1, sus) with self.assertRaises(TypeError): - ensure_compatible(supah, sub1) + ensure_compatible(superb, sub1) with self.assertRaises(TypeError): - ensure_compatible(subbest, sub1) + ensure_compatible(sus, sub1) + + def test_allowed_extra_args_in_overrider(self): + def superb(): + pass + + def optional_arg(arg=1): + pass + + def optional_positional_arg(arg2=2, /): + pass + + def optional_kw_only_arg(*, arg3=3): + pass + + # superb() => optional_arg() + ensure_compatible(superb, optional_arg) + # superb() => optional_positional_arg() + ensure_compatible(superb, optional_positional_arg) + # superb() => optional_kw_only_arg() + ensure_compatible(superb, optional_kw_only_arg)