Skip to content

Commit

Permalink
New type inference: add support for upper bounds and values (#15813)
Browse files Browse the repository at this point in the history
This is a third PR in series following
#15287 and
#15754. This one is quite simple: I
just add basic support for polymorphic inference involving type
variables with upper bounds and values. A complete support would be
quite complicated, and it will be a corner case to already rare
situation. Finally, it is written in a way that is easy to tune in the
future.

I also use this PR to add some unit tests for all three PRs so far,
other two PRs only added integration tests (and I clean up existing unit
tests as well).
  • Loading branch information
ilevkivskyi authored Aug 9, 2023
1 parent a7c4852 commit 8c21953
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 40 deletions.
80 changes: 69 additions & 11 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from mypy.expandtype import expand_type
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
from mypy.join import join_types
from mypy.meet import meet_types
from mypy.meet import meet_type_list, meet_types
from mypy.subtypes import is_subtype
from mypy.typeops import get_type_vars
from mypy.types import (
AnyType,
Instance,
NoneType,
ProperType,
Type,
TypeOfAny,
Expand Down Expand Up @@ -108,15 +110,15 @@ def solve_constraints(
else:
candidate = AnyType(TypeOfAny.special_form)
res.append(candidate)
return res, [originals[tv] for tv in free_vars]
return res, free_vars


def solve_with_dependent(
vars: list[TypeVarId],
constraints: list[Constraint],
original_vars: list[TypeVarId],
originals: dict[TypeVarId, TypeVarLikeType],
) -> tuple[Solutions, list[TypeVarId]]:
) -> tuple[Solutions, list[TypeVarLikeType]]:
"""Solve set of constraints that may depend on each other, like T <: List[S].
The whole algorithm consists of five steps:
Expand All @@ -135,23 +137,24 @@ def solve_with_dependent(
raw_batches = list(topsort(prepare_sccs(sccs, dmap)))

free_vars = []
free_solutions = {}
for scc in raw_batches[0]:
# If there are no bounds on this SCC, then the only meaningful solution we can
# express, is that each variable is equal to a new free variable. For example,
# if we have T <: S, S <: U, we deduce: T = S = U = <free>.
if all(not lowers[tv] and not uppers[tv] for tv in scc):
# For convenience with current type application machinery, we use a stable
# choice that prefers the original type variables (not polymorphic ones) in SCC.
# TODO: be careful about upper bounds (or values) when introducing free vars.
free_vars.append(sorted(scc, key=lambda x: (x not in original_vars, x.raw_id))[0])
best_free = choose_free([originals[tv] for tv in scc], original_vars)
if best_free:
free_vars.append(best_free.id)
free_solutions[best_free.id] = best_free

# Update lowers/uppers with free vars, so these can now be used
# as valid solutions.
for l, u in graph.copy():
for l, u in graph:
if l in free_vars:
lowers[u].add(originals[l])
lowers[u].add(free_solutions[l])
if u in free_vars:
uppers[l].add(originals[u])
uppers[l].add(free_solutions[u])

# Flatten the SCCs that are independent, we can solve them together,
# since we don't need to update any targets in between.
Expand All @@ -166,7 +169,7 @@ def solve_with_dependent(
for flat_batch in batches:
res = solve_iteratively(flat_batch, graph, lowers, uppers)
solutions.update(res)
return solutions, free_vars
return solutions, [free_solutions[tv] for tv in free_vars]


def solve_iteratively(
Expand Down Expand Up @@ -276,6 +279,61 @@ def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None:
return candidate


def choose_free(
scc: list[TypeVarLikeType], original_vars: list[TypeVarId]
) -> TypeVarLikeType | None:
"""Choose the best solution for an SCC containing only type variables.
This is needed to preserve e.g. the upper bound in a situation like this:
def dec(f: Callable[[T], S]) -> Callable[[T], S]: ...
@dec
def test(x: U) -> U: ...
where U <: A.
"""

if len(scc) == 1:
# Fast path, choice is trivial.
return scc[0]

common_upper_bound = meet_type_list([t.upper_bound for t in scc])
common_upper_bound_p = get_proper_type(common_upper_bound)
# We include None for when strict-optional is disabled.
if isinstance(common_upper_bound_p, (UninhabitedType, NoneType)):
# This will cause to infer <nothing>, which is better than a free TypeVar
# that has an upper bound <nothing>.
return None

values: list[Type] = []
for tv in scc:
if isinstance(tv, TypeVarType) and tv.values:
if values:
# It is too tricky to support multiple TypeVars with values
# within the same SCC.
return None
values = tv.values.copy()

if values and not is_trivial_bound(common_upper_bound_p):
# If there are both values and upper bound present, we give up,
# since type variables having both are not supported.
return None

# For convenience with current type application machinery, we use a stable
# choice that prefers the original type variables (not polymorphic ones) in SCC.
best = sorted(scc, key=lambda x: (x.id not in original_vars, x.id.raw_id))[0]
if isinstance(best, TypeVarType):
return best.copy_modified(values=values, upper_bound=common_upper_bound)
if is_trivial_bound(common_upper_bound_p):
# TODO: support more cases for ParamSpecs/TypeVarTuples
return best
return None


def is_trivial_bound(tp: ProperType) -> bool:
return isinstance(tp, Instance) and tp.type.fullname == "builtins.object"


def normalize_constraints(
constraints: list[Constraint], vars: list[TypeVarId]
) -> list[Constraint]:
Expand Down
Loading

0 comments on commit 8c21953

Please sign in to comment.