Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter overload items based on self type during type inference #17873

Merged
merged 11 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import (
ARG_OPT,
ARG_POS,
ARG_STAR,
ARG_STAR2,
Expand Down Expand Up @@ -305,9 +306,28 @@ class B(A): pass

"""
if isinstance(method, Overloaded):
items = [
bind_self(c, original_type, is_classmethod, ignore_instances) for c in method.items
]
items = []
original_type = get_proper_type(original_type)
for c in method.items:
if isinstance(original_type, Instance):
# Filter based on whether declared self type can match actual object type.
# For example, if self has type C[int] and method is accessed on a C[str] value,
# omit this item. This is best effort since bind_self can be called in many
# contexts, and doing complete validation might trigger infinite recursion.
#
# Note that overload item filtering normally happens elsewhere. This is needed
# at least during constraint inference.
keep = is_valid_self_type_best_effort(c, original_type)
else:
keep = True
if keep:
items.append(bind_self(c, original_type, is_classmethod, ignore_instances))
if len(items) == 0:
# We must return a valid overloaded type, so pick the first item if none
# are matching (arbitrarily).
items.append(
bind_self(method.items[0], original_type, is_classmethod, ignore_instances)
)
return cast(F, Overloaded(items))
assert isinstance(method, CallableType)
func = method
Expand Down Expand Up @@ -379,6 +399,43 @@ class B(A): pass
return cast(F, res)


def is_valid_self_type_best_effort(c: CallableType, self_type: Instance) -> bool:
"""Quickly check if self_type might match the self in a callable.

Avoid performing any complex type operations. This is performance-critical.

Default to returning True if we don't know (or it would be too expensive).
"""
if (
self_type.args
and c.arg_types
and isinstance((arg_type := get_proper_type(c.arg_types[0])), Instance)
and c.arg_kinds[0] in (ARG_POS, ARG_OPT)
and arg_type.args
and self_type.type.fullname != "functools._SingleDispatchCallable"
):
if self_type.type is not arg_type.type:
# We can't map to supertype, since it could trigger expensive checks for
# protocol types, so we consevatively assume this is fine.
return True

# Fast path: no explicit annotation on self
if all(
(
type(arg) is TypeVarType
and type(arg.upper_bound) is Instance
and arg.upper_bound.type.fullname == "builtins.object"
)
for arg in arg_type.args
):
return True

from mypy.meet import is_overlapping_types

return is_overlapping_types(self_type, c.arg_types[0])
return True


def erase_to_bound(t: Type) -> Type:
# TODO: use value restrictions to produce a union?
t = get_proper_type(t)
Expand Down
88 changes: 88 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -4127,3 +4127,91 @@ class P(Protocol):

class C(P): ...
C(0) # OK

[case testTypeVarValueConstraintAgainstGenericProtocol]
from typing import TypeVar, Generic, Protocol, overload

T_contra = TypeVar("T_contra", contravariant=True)
AnyStr = TypeVar("AnyStr", str, bytes)

class SupportsWrite(Protocol[T_contra]):
def write(self, s: T_contra, /) -> None: ...

class Buffer: ...

class IO(Generic[AnyStr]):
@overload
def write(self: IO[bytes], s: Buffer, /) -> None: ...
@overload
def write(self, s: AnyStr, /) -> None: ...
def write(self, s): ...

def foo(fdst: SupportsWrite[AnyStr]) -> None: ...

x: IO[str]
foo(x)

[case testTypeVarValueConstraintAgainstGenericProtocol2]
from typing import Generic, Protocol, TypeVar, overload

AnyStr = TypeVar("AnyStr", str, bytes)
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)

class SupportsRead(Generic[T_co]):
def read(self) -> T_co: ...

class SupportsWrite(Protocol[T_contra]):
def write(self, s: T_contra) -> object: ...

def copyfileobj(fsrc: SupportsRead[AnyStr], fdst: SupportsWrite[AnyStr]) -> None: ...

class WriteToMe(Generic[AnyStr]):
@overload
def write(self: WriteToMe[str], s: str) -> int: ...
@overload
def write(self: WriteToMe[bytes], s: bytes) -> int: ...
def write(self, s): ...

class WriteToMeOrReadFromMe(WriteToMe[AnyStr], SupportsRead[AnyStr]): ...

copyfileobj(WriteToMeOrReadFromMe[bytes](), WriteToMe[bytes]())

[case testOverloadedMethodWithExplictSelfTypes]
from typing import Generic, overload, Protocol, TypeVar, Union

AnyStr = TypeVar("AnyStr", str, bytes)
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)

class SupportsRead(Protocol[T_co]):
def read(self) -> T_co: ...

class SupportsWrite(Protocol[T_contra]):
def write(self, s: T_contra) -> int: ...

class Input(Generic[AnyStr]):
def read(self) -> AnyStr: ...

class Output(Generic[AnyStr]):
@overload
def write(self: Output[str], s: str) -> int: ...
@overload
def write(self: Output[bytes], s: bytes) -> int: ...
def write(self, s: Union[str, bytes]) -> int: ...

def f(src: SupportsRead[AnyStr], dst: SupportsWrite[AnyStr]) -> None: ...

def g1(a: Input[bytes], b: Output[bytes]) -> None:
f(a, b)

def g2(a: Input[bytes], b: Output[bytes]) -> None:
f(a, b)

def g3(a: Input[str], b: Output[bytes]) -> None:
f(a, b) # E: Cannot infer type argument 1 of "f"

def g4(a: Input[bytes], b: Output[str]) -> None:
f(a, b) # E: Cannot infer type argument 1 of "f"

[builtins fixtures/tuple.pyi]
Loading