Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type argument inference for overloaded functions with explicit self types (Fixes #14943). #14975

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
22 changes: 2 additions & 20 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1940,35 +1940,17 @@ def bind_and_map_method(
sub_info: class where the method is used
super_info: class where the method was defined
"""
mapped_typ = cast(FunctionLike, map_type_from_supertype(typ, sub_info, super_info))
if isinstance(sym.node, (FuncDef, OverloadedFuncDef, Decorator)) and not is_static(
sym.node
):
if isinstance(sym.node, Decorator):
is_class_method = sym.node.func.is_class
else:
is_class_method = sym.node.is_class

mapped_typ = cast(FunctionLike, map_type_from_supertype(typ, sub_info, super_info))
active_self_type = self.scope.active_self_type()
if isinstance(mapped_typ, Overloaded) and active_self_type:
# If we have an overload, filter to overloads that match the self type.
# This avoids false positives for concrete subclasses of generic classes,
# see testSelfTypeOverrideCompatibility for an example.
filtered_items = [
item
for item in mapped_typ.items
if not item.arg_types or is_subtype(active_self_type, item.arg_types[0])
]
# If we don't have any filtered_items, maybe it's always a valid override
# of the superclass? However if you get to that point you're in murky type
# territory anyway, so we just preserve the type and have the behaviour match
# that of older versions of mypy.
if filtered_items:
mapped_typ = Overloaded(filtered_items)

return bind_self(mapped_typ, active_self_type, is_class_method)
else:
return cast(FunctionLike, map_type_from_supertype(typ, sub_info, super_info))
return mapped_typ

def get_op_other_domain(self, tp: FunctionLike) -> Type | None:
if isinstance(tp, CallableType):
Expand Down
9 changes: 9 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(
# Supported for both proper and non-proper
ignore_promotions: bool = False,
ignore_uninhabited: bool = False,
ignore_type_vars: bool = False,
# Proper subtype flags
erase_instances: bool = False,
keep_erased_types: bool = False,
Expand All @@ -96,6 +97,7 @@ def __init__(
self.ignore_declared_variance = ignore_declared_variance
self.ignore_promotions = ignore_promotions
self.ignore_uninhabited = ignore_uninhabited
self.ignore_type_vars = ignore_type_vars
self.erase_instances = erase_instances
self.keep_erased_types = keep_erased_types
self.options = options
Expand All @@ -119,6 +121,7 @@ def is_subtype(
ignore_declared_variance: bool = False,
ignore_promotions: bool = False,
ignore_uninhabited: bool = False,
ignore_type_vars: bool = False,
options: Options | None = None,
) -> bool:
"""Is 'left' subtype of 'right'?
Expand All @@ -139,6 +142,7 @@ def is_subtype(
ignore_declared_variance=ignore_declared_variance,
ignore_promotions=ignore_promotions,
ignore_uninhabited=ignore_uninhabited,
ignore_type_vars=ignore_type_vars,
options=options,
)
else:
Expand Down Expand Up @@ -287,6 +291,11 @@ def _is_subtype(
# ErasedType as we do for non-proper subtyping.
return True

if subtype_context.ignore_type_vars and (
isinstance(left, TypeVarType) or isinstance(right, TypeVarType)
):
return True

if isinstance(right, UnionType) and not isinstance(left, UnionType):
# Normally, when 'left' is not itself a union, the only way
# 'left' can be a subtype of the union 'right' is if it is a
Expand Down
39 changes: 36 additions & 3 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def type_object_type_from_function(
# ...
#
# We need to map B's __init__ to the type (List[T]) -> None.
signature = bind_self(signature, original_type=default_self, is_classmethod=is_new)
signature = bind_self(
signature, original_type=default_self, is_classmethod=is_new, selftypes=orig_self_types
)
signature = cast(FunctionLike, map_type_from_supertype(signature, info, def_info))

special_sig: str | None = None
Expand Down Expand Up @@ -251,7 +253,12 @@ def supported_self_type(typ: ProperType) -> bool:
F = TypeVar("F", bound=FunctionLike)


def bind_self(method: F, original_type: Type | None = None, is_classmethod: bool = False) -> F:
def bind_self(
method: F,
original_type: Type | None = None,
is_classmethod: bool = False,
selftypes: list[Type | None] | None = None,
) -> F:
"""Return a copy of `method`, with the type of its first parameter (usually
self or cls) bound to original_type.

Expand All @@ -274,10 +281,36 @@ class B(A): pass
b = B().copy() # type: B

"""

from mypy.subtypes import is_subtype

if isinstance(method, Overloaded):
# Try to remove overload items with non-matching self types first (fixes #14943)
origtype = get_proper_type(original_type)
if isinstance(origtype, Instance):
methoditems = []
if selftypes is not None:
selftypes_copy = selftypes.copy()
selftypes.clear()
for idx, methoditem in enumerate(method.items):
selftype = get_self_type(methoditem, origtype)
selftype_proper = get_proper_type(selftype)
if not isinstance(selftype_proper, Instance) or is_subtype(
origtype, selftype_proper, ignore_type_vars=True
):
methoditems.append(methoditem)
if selftypes is not None:
selftypes.append(selftypes_copy[idx])
if len(methoditems) == 0:
methoditems = method.items
if selftypes is not None:
selftypes.extend(selftypes_copy)
else:
methoditems = method.items
return cast(
F, Overloaded([bind_self(c, original_type, is_classmethod) for c in method.items])
F, Overloaded([bind_self(mi, original_type, is_classmethod) for mi in methoditems])
)

assert isinstance(method, CallableType)
func = method
if not func.arg_types:
Expand Down
39 changes: 39 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -4020,3 +4020,42 @@ class P(Protocol):

[file lib.py]
class C: ...

[case TestOverloadedMethodWithExplictSelfTypes]
from typing import Generic, overload, Protocol, TypeVar, Union

AnyStr = TypeVar("AnyStr", str, bytes)
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)

class SupportsRead(Protocol[T_co]):
def read(self) -> T_co: ...

class SupportsWrite(Protocol[T_contra]):
def write(self, s: T_contra) -> int: ...

class Input(Generic[AnyStr]):
def read(self) -> AnyStr: ...

class Output(Generic[AnyStr]):
@overload
def write(self: Output[str], s: str) -> int: ...
@overload
def write(self: Output[bytes], s: bytes) -> int: ...
def write(self, s: Union[str, bytes]) -> int: ...

def f(src: SupportsRead[AnyStr], dst: SupportsWrite[AnyStr]) -> None: ...

def g1(a: Input[bytes], b: Output[bytes]) -> None:
f(a, b)

def g2(a: Input[bytes], b: Output[bytes]) -> None:
f(a, b)

def g3(a: Input[str], b: Output[bytes]) -> None:
f(a, b) # E: Cannot infer type argument 1 of "f"

def g4(a: Input[bytes], b: Output[str]) -> None:
f(a, b) # E: Cannot infer type argument 1 of "f"

[builtins fixtures/tuple.pyi]
22 changes: 21 additions & 1 deletion test-data/unit/check-selftype.test
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class C(A[None]):
# N: def f(self, s: int) -> int
[builtins fixtures/tuple.pyi]

[case testSelfTypeOverrideCompatibilityTypeVar-xfail]
[case testSelfTypeOverrideCompatibilityTypeVar]
from typing import overload, TypeVar, Union

AT = TypeVar("AT", bound="A")
Expand Down Expand Up @@ -266,6 +266,26 @@ class B(A):
def f(*a, **kw): ...
[builtins fixtures/dict.pyi]

[case testSelfTypeOverrideCompatibilitySelfTypeVar]
from typing import Any, Generic, Self, TypeVar, overload

T_co = TypeVar('T_co', covariant=True)

class Config(Generic[T_co]):
@overload
def get(self, instance: None) -> Self: ...
@overload
def get(self, instance: Any) -> T_co: ...
def get(self, *a, **kw): ...

class MultiConfig(Config[T_co]):
@overload
def get(self, instance: None) -> Self: ...
@overload
def get(self, instance: Any) -> T_co: ...
def get(self, *a, **kw): ...
[builtins fixtures/dict.pyi]

[case testSelfTypeSuper]
from typing import TypeVar, cast

Expand Down