Skip to content

Use checkmember.py to check protocol subtyping #18943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/mypy_primer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ jobs:
--debug \
--additional-flags="--debug-serialize" \
--output concise \
--show-speed-regression \
| tee diff_${{ matrix.shard-index }}.txt
) || [ $? -eq 1 ]
- if: ${{ matrix.shard-index == 0 }}
Expand Down
4 changes: 2 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def check_first_pass(self) -> None:
Deferred functions will be processed by check_second_pass().
"""
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional):
with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self):
self.errors.set_file(
self.path, self.tree.fullname, scope=self.tscope, options=self.options
)
Expand Down Expand Up @@ -496,7 +496,7 @@ def check_second_pass(
This goes through deferred nodes, returning True if there were any.
"""
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional):
with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self):
if not todo and not self.deferred_nodes:
return False
self.errors.set_file(
Expand Down
55 changes: 24 additions & 31 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
is_self: bool = False,
rvalue: Expression | None = None,
suppress_errors: bool = False,
preserve_type_var_ids: bool = False,
) -> None:
self.is_lvalue = is_lvalue
self.is_super = is_super
Expand All @@ -112,6 +113,10 @@ def __init__(
assert is_lvalue
self.rvalue = rvalue
self.suppress_errors = suppress_errors
# This attribute is only used to preserve old protocol member access logic.
# It is needed to avoid infinite recursion in cases involving self-referential
# generic methods, see find_member() for details. Do not use for other purposes!
self.preserve_type_var_ids = preserve_type_var_ids

def named_type(self, name: str) -> Instance:
return self.chk.named_type(name)
Expand Down Expand Up @@ -142,6 +147,7 @@ def copy_modified(
no_deferral=self.no_deferral,
rvalue=self.rvalue,
suppress_errors=self.suppress_errors,
preserve_type_var_ids=self.preserve_type_var_ids,
)
if self_type is not None:
mx.self_type = self_type
Expand Down Expand Up @@ -231,8 +237,6 @@ def analyze_member_access(
def _analyze_member_access(
name: str, typ: Type, mx: MemberContext, override_info: TypeInfo | None = None
) -> Type:
# TODO: This and following functions share some logic with subtypes.find_member;
# consider refactoring.
typ = get_proper_type(typ)
if isinstance(typ, Instance):
return analyze_instance_member_access(name, typ, mx, override_info)
Expand Down Expand Up @@ -355,7 +359,8 @@ def analyze_instance_member_access(
return AnyType(TypeOfAny.special_form)
assert isinstance(method.type, Overloaded)
signature = method.type
signature = freshen_all_functions_type_vars(signature)
if not mx.preserve_type_var_ids:
signature = freshen_all_functions_type_vars(signature)
if not method.is_static:
signature = check_self_arg(
signature, mx.self_type, method.is_class, mx.context, name, mx.msg
Expand Down Expand Up @@ -928,7 +933,8 @@ def analyze_var(
def expand_without_binding(
typ: Type, var: Var, itype: Instance, original_itype: Instance, mx: MemberContext
) -> Type:
typ = freshen_all_functions_type_vars(typ)
if not mx.preserve_type_var_ids:
typ = freshen_all_functions_type_vars(typ)
typ = expand_self_type_if_needed(typ, mx, var, original_itype)
expanded = expand_type_by_instance(typ, itype)
freeze_all_type_vars(expanded)
Expand All @@ -938,7 +944,8 @@ def expand_without_binding(
def expand_and_bind_callable(
functype: FunctionLike, var: Var, itype: Instance, name: str, mx: MemberContext
) -> Type:
functype = freshen_all_functions_type_vars(functype)
if not mx.preserve_type_var_ids:
functype = freshen_all_functions_type_vars(functype)
typ = get_proper_type(expand_self_type(var, functype, mx.original_type))
assert isinstance(typ, FunctionLike)
typ = check_self_arg(typ, mx.self_type, var.is_classmethod, mx.context, name, mx.msg)
Expand Down Expand Up @@ -1033,10 +1040,12 @@ def f(self: S) -> T: ...
return functype
else:
selfarg = get_proper_type(item.arg_types[0])
# This level of erasure matches the one in checker.check_func_def(),
# better keep these two checks consistent.
if subtypes.is_subtype(
# This matches similar special-casing in bind_self(), see more details there.
self_callable = name == "__call__" and isinstance(selfarg, CallableType)
if self_callable or subtypes.is_subtype(
dispatched_arg_type,
# This level of erasure matches the one in checker.check_func_def(),
# better keep these two checks consistent.
erase_typevars(erase_to_bound(selfarg)),
# This is to work around the fact that erased ParamSpec and TypeVarTuple
# callables are not always compatible with non-erased ones both ways.
Expand Down Expand Up @@ -1197,15 +1206,10 @@ def analyze_class_attribute_access(
is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or (
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_class
)
is_staticmethod = (is_decorated and cast(Decorator, node.node).func.is_static) or (
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_static
)
t = get_proper_type(t)
if isinstance(t, FunctionLike) and is_classmethod:
t = check_self_arg(t, mx.self_type, False, mx.context, name, mx.msg)
result = add_class_tvars(
t, isuper, is_classmethod, is_staticmethod, mx.self_type, original_vars=original_vars
)
result = add_class_tvars(t, isuper, is_classmethod, mx, original_vars=original_vars)
# __set__ is not called on class objects.
if not mx.is_lvalue:
result = analyze_descriptor_access(result, mx)
Expand Down Expand Up @@ -1337,8 +1341,7 @@ def add_class_tvars(
t: ProperType,
isuper: Instance | None,
is_classmethod: bool,
is_staticmethod: bool,
original_type: Type,
mx: MemberContext,
original_vars: Sequence[TypeVarLikeType] | None = None,
) -> Type:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function does not appear to be a performance bottleneck (at least in self check).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JukkaL If you will have time, could you please check if there is any slowness because of bind_self() and check_self_arg()? Although they are not modified, they may be called much more often now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_self_arg could be more expensive -- it appears to consume an extra ~0.5% of runtime in this PR. We are now spending maybe 2-3% of CPU in it, so it's quite hot, but it already was pretty hot before this PR. This could be noise though.

I didn't see any major change in bind_self when doing self check, though it's pretty hot both before and after, though less hot than check_self_arg.

"""Instantiate type variables during analyze_class_attribute_access,
Expand All @@ -1356,9 +1359,6 @@ class B(A[str]): pass
isuper: Current instance mapped to the superclass where method was defined, this
is usually done by map_instance_to_supertype()
is_classmethod: True if this method is decorated with @classmethod
is_staticmethod: True if this method is decorated with @staticmethod
original_type: The value of the type B in the expression B.foo() or the corresponding
component in case of a union (this is used to bind the self-types)
original_vars: Type variables of the class callable on which the method was accessed
Returns:
Expanded method type with added type variables (when needed).
Expand All @@ -1379,11 +1379,11 @@ class B(A[str]): pass
# (i.e. appear in the return type of the class object on which the method was accessed).
if isinstance(t, CallableType):
tvars = original_vars if original_vars is not None else []
t = freshen_all_functions_type_vars(t)
if not mx.preserve_type_var_ids:
t = freshen_all_functions_type_vars(t)
if is_classmethod:
t = bind_self(t, original_type, is_classmethod=True)
if is_classmethod or is_staticmethod:
assert isuper is not None
t = bind_self(t, mx.self_type, is_classmethod=True)
if isuper is not None:
t = expand_type_by_instance(t, isuper)
freeze_all_type_vars(t)
return t.copy_modified(variables=list(tvars) + list(t.variables))
Expand All @@ -1392,14 +1392,7 @@ class B(A[str]): pass
[
cast(
CallableType,
add_class_tvars(
item,
isuper,
is_classmethod,
is_staticmethod,
original_type,
original_vars=original_vars,
),
add_class_tvars(item, isuper, is_classmethod, mx, original_vars=original_vars),
)
for item in t.items
]
Expand Down
3 changes: 2 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Final, TypeVar, cast, overload

from mypy.nodes import ARG_STAR, FakeInfo, Var
from mypy.state import state
from mypy.types import (
ANY_STRATEGY,
AnyType,
Expand Down Expand Up @@ -544,6 +543,8 @@ def remove_trivial(types: Iterable[Type]) -> list[Type]:
* Remove everything else if there is an `object`
* Remove strict duplicate types
"""
from mypy.state import state
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This nested imports also causes a small performance regression (maybe 0.1% to 0.2%).


removed_none = False
new_types = []
all_types = set()
Expand Down
9 changes: 7 additions & 2 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2220,8 +2220,13 @@ def report_protocol_problems(
exp = get_proper_type(exp)
got = get_proper_type(got)
setter_suffix = " setter type" if is_lvalue else ""
if not isinstance(exp, (CallableType, Overloaded)) or not isinstance(
got, (CallableType, Overloaded)
if (
not isinstance(exp, (CallableType, Overloaded))
or not isinstance(got, (CallableType, Overloaded))
# If expected type is a type object, it means it is a nested class.
# Showing constructor signature in errors would be confusing in this case,
# since we don't check the signature, only subclassing of type objects.
or exp.is_type_obj()
):
self.note(
"{}: expected{} {}, got {}".format(
Expand Down
8 changes: 5 additions & 3 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,13 @@ class C: pass
from __future__ import annotations

from abc import abstractmethod
from typing import Any, Callable, NamedTuple, TypeVar
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar

from mypy_extensions import mypyc_attr, trait

from mypy.errorcodes import ErrorCode
from mypy.lookup import lookup_fully_qualified
from mypy.message_registry import ErrorMessage
from mypy.messages import MessageBuilder
from mypy.nodes import (
ArgKind,
CallExpr,
Expand All @@ -138,7 +137,6 @@ class C: pass
TypeInfo,
)
from mypy.options import Options
from mypy.tvar_scope import TypeVarLikeScope
from mypy.types import (
CallableType,
FunctionLike,
Expand All @@ -149,6 +147,10 @@ class C: pass
UnboundType,
)

if TYPE_CHECKING:
from mypy.messages import MessageBuilder
from mypy.tvar_scope import TypeVarLikeScope


@trait
class TypeAnalyzerPluginInterface:
Expand Down
20 changes: 16 additions & 4 deletions mypy/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
from contextlib import contextmanager
from typing import Final

from mypy.checker_shared import TypeCheckerSharedApi

# These are global mutable state. Don't add anything here unless there's a very
# good reason.


class StrictOptionalState:
class SubtypeState:
# Wrap this in a class since it's faster that using a module-level attribute.

def __init__(self, strict_optional: bool) -> None:
# Value varies by file being processed
def __init__(self, strict_optional: bool, type_checker: TypeCheckerSharedApi | None) -> None:
# Values vary by file being processed
self.strict_optional = strict_optional
self.type_checker = type_checker

@contextmanager
def strict_optional_set(self, value: bool) -> Iterator[None]:
Expand All @@ -24,6 +27,15 @@ def strict_optional_set(self, value: bool) -> Iterator[None]:
finally:
self.strict_optional = saved

@contextmanager
def type_checker_set(self, value: TypeCheckerSharedApi) -> Iterator[None]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dependency on TypeCheckerSharedApi probably makes various import cycles worse, and I assume this why there are some additional nested imports. Defining type_checker_set in a new module would improve things, right? Splitting this module seems better than making import cycles bigger, and it should also reduce the performance regression.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defining type_checker_set in a new module would improve things, right?

Yeah, I think this should be better. I will play with this.

saved = self.type_checker
self.type_checker = value
try:
yield
finally:
self.type_checker = saved


state: Final = StrictOptionalState(strict_optional=True)
state: Final = SubtypeState(strict_optional=True, type_checker=None)
find_occurrences: tuple[str, str] | None = None
Loading