diff --git a/crates/ty_python_semantic/resources/mdtest/call/methods.md b/crates/ty_python_semantic/resources/mdtest/call/methods.md index d814ab5a0f647..b498384a3103a 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/methods.md +++ b/crates/ty_python_semantic/resources/mdtest/call/methods.md @@ -447,6 +447,31 @@ The `owner` argument takes precedence over the `instance` argument: reveal_type(getattr_static(C, "f").__get__("dummy", C)) # revealed: bound method .f() -> Unknown ``` +Implicit `cls` parameters should stay in class-object space when a classmethod is accessed through +`type[T]`: + +```py +from typing import Any, Type, TypeVar + +Model = TypeVar("Model", bound="BaseModel") + +class BaseModel: + @classmethod + def normalize(cls, obj: Any) -> Any: + return obj + + @classmethod + def parse_obj(cls: Type[Model], obj: Any) -> Model: + reveal_type(cls.normalize) # revealed: bound method type[Model@parse_obj].normalize(obj: Any) -> Any + + cls.normalize(obj) + cls.normalize.__func__(cls, obj) + + # error: [invalid-argument-type] + cls.normalize.__func__(cls(), obj) + return cls() +``` + ### Classmethods mixed with other decorators ```toml diff --git a/crates/ty_python_semantic/resources/mdtest/diagnostics/union_call.md b/crates/ty_python_semantic/resources/mdtest/diagnostics/union_call.md index 10f95d65e4126..07df248f2b42c 100644 --- a/crates/ty_python_semantic/resources/mdtest/diagnostics/union_call.md +++ b/crates/ty_python_semantic/resources/mdtest/diagnostics/union_call.md @@ -179,8 +179,6 @@ class B: T = TypeVar("T", A, B) def _(x: T, y: int) -> T: - # error: [invalid-argument-type] - # error: [invalid-argument-type] # error: [invalid-argument-type] return x.foo(y) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index 491b5c0198dd4..8751b89e1fa11 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -666,6 +666,274 @@ def g(b: B[T]): return f(b.x) # Fine ``` +## Calling shared methods on union-bounded TypeVars + +Calling a method that exists on all members of a union upper bound should be valid. + +```py +from typing import TypeVar + +class A: + def call_me(self) -> int: + return 1 + +class B: + def call_me(self) -> None: + return None + +TUnionBound = TypeVar("TUnionBound", bound=A | B) + +def call_shared_method(x: TUnionBound) -> None: + reveal_type(x.call_me()) # revealed: int | None + reveal_type(type(x).call_me(x)) # revealed: int | None + reveal_type(x.call_me.__self__) # revealed: TUnionBound@call_shared_method + reveal_type(x.call_me.__func__) # revealed: (def call_me(self) -> int) | (def call_me(self) -> None) +``` + +Shared inherited methods should also stay callable when iterating a generic container of a bounded +type variable: + +```py +from typing import Generic, Iterator, TypeVar + +class Rule: + def apply_value(self) -> None: + pass + + def reset_parameters(self) -> None: + pass + +class Replace(Rule): + pass + +class Add(Rule): + pass + +class RuleIUH(Rule): + def reset_parameters(self) -> None: + pass + +class ReplaceIUH(RuleIUH): + def apply_value(self) -> None: + pass + +TypeRule = TypeVar("TypeRule", bound=Replace | Add | ReplaceIUH) + +class CalibrationInterface(Generic[TypeRule]): + def __iter__(self) -> Iterator[TypeRule]: + raise NotImplementedError + + def apply_values(self) -> None: + for rule in self: + rule.apply_value() + + def reset_parameters(self) -> None: + for rule in self: + rule.reset_parameters() +``` + +If the method is not shared by all upper-bound variants, we still diagnose it: + +```py +class C: + def call_me(self) -> int: + return 1 + +class D: + pass + +TPartiallyBound = TypeVar("TPartiallyBound", bound=C | D) + +def call_missing_method(x: TPartiallyBound) -> None: + x.call_me() # error: [possibly-missing-attribute] +``` + +Explicit `self`-typed overloads should still be enforced for union-bounded type variables: + +```py +from typing import TypeVar, overload + +class BaseRequest: + @overload + def payload(self: "JsonRequest") -> dict[str, object]: ... + @overload + def payload(self: "BinaryRequest") -> bytes: ... + def payload(self): + raise NotImplementedError + +class JsonRequest(BaseRequest): + pass + +class BinaryRequest(BaseRequest): + pass + +class StreamingRequest(BaseRequest): + pass + +TRequest = TypeVar("TRequest", bound=JsonRequest | StreamingRequest) + +def call_payload(request: TRequest) -> None: + # error: [no-matching-overload] + request.payload() +``` + +Constrained `TypeVar`s should keep the same overload behavior: + +```py +from typing import TypeVar + +TConstrainedRequest = TypeVar("TConstrainedRequest", JsonRequest, StreamingRequest) + +def call_payload_constrained(request: TConstrainedRequest) -> None: + # error: [no-matching-overload] + request.payload() +``` + +Narrowing away an impossible union arm should also make plain member lookup succeed on the remaining +bound variant: + +```py +from typing import Any, TypeVar + +TMaybeList = TypeVar("TMaybeList", bound=list[Any] | None) + +def append_value(x: TMaybeList) -> TMaybeList: + if x is None: + return x + + x.append(1) + return x +``` + +Method calls on union-bounded type variables should stay assignable to the original type variable +when each bound arm returns the same receiver occurrence: + +```py +from typing import TypeVar +from typing_extensions import Self + +class MaybeA: + def maybe(self) -> "Self | None": + return self + +class MaybeB: + def maybe(self) -> "Self | None": + return self + +TMaybeReturn = TypeVar("TMaybeReturn", bound=MaybeA | MaybeB) + +def preserve_method_return_identity(x: TMaybeReturn) -> TMaybeReturn | None: + return x.maybe() +``` + +Nominal returns that only mention the selected bound arm should not be treated as +receiver-correlated: + +```py +from typing import TypeVar + +class FreshA: + def maybe(self) -> "FreshA | None": + return FreshA() + +class FreshB: + def maybe(self) -> "FreshB | None": + return FreshB() + +TFreshReturn = TypeVar("TFreshReturn", bound=FreshA | FreshB) + +def reject_nominal_return_correlation(x: TFreshReturn) -> TFreshReturn | None: + # error: [invalid-return-type] "Return type does not match returned value: expected `TFreshReturn@reject_nominal_return_correlation | None`, found `FreshA | None | FreshB`" + return x.maybe() +``` + +Calling `cls.__new__(cls)` through `self.__class__` should stay valid for ordinary bounded typevars. +This should not trigger the union/constrained receiver-rebinding path. + +```py +from typing import TypeVar + +class BaseModel: ... + +Model = TypeVar("Model", bound=BaseModel) + +def clone(self: Model) -> Model: + cls = self.__class__ + return cls.__new__(cls) +``` + +Method calls on narrowed constrained TypeVars should preserve the original typevar identity when the +method returns the narrowed receiver arm: + +```py +from typing import TypeVar +from typing_extensions import Self + +class Prefix: + def normalize(self) -> Self: + return self + +class Suffix: + def normalize(self) -> Self: + return self + +TConstrainedText = TypeVar("TConstrainedText", Prefix, Suffix) + +def apply_values(template: TConstrainedText) -> TConstrainedText: + if isinstance(template, Prefix): + template = template.normalize() + return template + return template +``` + +Bound helper methods on generic classes should stay callable when the outer `self` annotation is an +ordinary bounded typevar: + +```py +from typing import Any, Callable, Generic, TypeVar + +TypeDevices = TypeVar("TypeDevices", bound="Devices[Any]") + +class Devices(Generic[TypeDevices]): + def __compare(self, other: object, func: Callable[[Any, Any], bool]) -> bool: + return True + + def __lt__(self: TypeDevices, other: TypeDevices) -> bool: + return self.__compare(other, lambda a, b: True) +``` + +## Known limitations + +Mixed branch joins do not yet preserve the original constrained typevar when only one branch uses a +method call: + +```py +from typing import TypeVar + +class PeriodIndex: + def asfreq(self, *, freq: object) -> "PeriodIndex": + return PeriodIndex() + +class DatetimeIndex: ... +class TimedeltaIndex: ... + +FreqIndexT = TypeVar("FreqIndexT", PeriodIndex, DatetimeIndex, TimedeltaIndex) + +def asfreq_compat(index: FreqIndexT, freq: object) -> FreqIndexT: + if isinstance(index, PeriodIndex): + new_index = index.asfreq(freq=freq) + elif isinstance(index, DatetimeIndex): + new_index = DatetimeIndex() + elif isinstance(index, TimedeltaIndex): + new_index = TimedeltaIndex() + else: + raise TypeError(type(index)) + + # TODO: This should not be an error. + # error: [invalid-return-type] "Return type does not match returned value: expected `FreqIndexT@asfreq_compat`, found `PeriodIndex | DatetimeIndex | TimedeltaIndex`" + return new_index +``` + ## Constrained TypeVar in a union This is a regression test for an issue that surfaced in the primer report of an early version of diff --git "a/crates/ty_python_semantic/resources/mdtest/snapshots/union_call.md_-_Calling_a_union_of_f\342\200\246_-_Try_to_cover_all_pos\342\200\246_-_Attribute_access_on_\342\200\246_(7bdb97302c27c412).snap" "b/crates/ty_python_semantic/resources/mdtest/snapshots/union_call.md_-_Calling_a_union_of_f\342\200\246_-_Try_to_cover_all_pos\342\200\246_-_Attribute_access_on_\342\200\246_(7bdb97302c27c412).snap" index 102d5bba97dbb..282ef4d1ac4c3 100644 --- "a/crates/ty_python_semantic/resources/mdtest/snapshots/union_call.md_-_Calling_a_union_of_f\342\200\246_-_Try_to_cover_all_pos\342\200\246_-_Attribute_access_on_\342\200\246_(7bdb97302c27c412).snap" +++ "b/crates/ty_python_semantic/resources/mdtest/snapshots/union_call.md_-_Calling_a_union_of_f\342\200\246_-_Try_to_cover_all_pos\342\200\246_-_Attribute_access_on_\342\200\246_(7bdb97302c27c412).snap" @@ -27,48 +27,18 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/union_call.m 12 | 13 | def _(x: T, y: int) -> T: 14 | # error: [invalid-argument-type] -15 | # error: [invalid-argument-type] -16 | # error: [invalid-argument-type] -17 | return x.foo(y) +15 | return x.foo(y) ``` # Diagnostics ``` error[invalid-argument-type]: Argument to bound method `foo` is incorrect - --> src/mdtest_snippet.py:17:12 + --> src/mdtest_snippet.py:15:18 | -15 | # error: [invalid-argument-type] -16 | # error: [invalid-argument-type] -17 | return x.foo(y) - | ^^^^^^^^ Argument type `T@_` does not satisfy upper bound `A` of type variable `Self` - | -info: Union variant `bound method T@_.foo(x: int) -> T@_` is incompatible with this call site -info: Attempted to call union type `(bound method T@_.foo(x: int) -> T@_) | (bound method T@_.foo(x: str) -> T@_)` - -``` - -``` -error[invalid-argument-type]: Argument to bound method `foo` is incorrect - --> src/mdtest_snippet.py:17:12 - | -15 | # error: [invalid-argument-type] -16 | # error: [invalid-argument-type] -17 | return x.foo(y) - | ^^^^^^^^ Argument type `T@_` does not satisfy upper bound `B` of type variable `Self` - | -info: Union variant `bound method T@_.foo(x: str) -> T@_` is incompatible with this call site -info: Attempted to call union type `(bound method T@_.foo(x: int) -> T@_) | (bound method T@_.foo(x: str) -> T@_)` - -``` - -``` -error[invalid-argument-type]: Argument to bound method `foo` is incorrect - --> src/mdtest_snippet.py:17:18 - | -15 | # error: [invalid-argument-type] -16 | # error: [invalid-argument-type] -17 | return x.foo(y) +13 | def _(x: T, y: int) -> T: +14 | # error: [invalid-argument-type] +15 | return x.foo(y) | ^ Expected `str`, found `int` | info: Method defined here @@ -79,7 +49,7 @@ info: Method defined here | ^^^ ------ Parameter declared here 9 | return self | -info: Union variant `bound method T@_.foo(x: str) -> T@_` is incompatible with this call site -info: Attempted to call union type `(bound method T@_.foo(x: int) -> T@_) | (bound method T@_.foo(x: str) -> T@_)` +info: Union variant `bound method T@_.foo(x: str) -> T@_ & B` is incompatible with this call site +info: Attempted to call union type `(bound method T@_.foo(x: int) -> T@_ & A) | (bound method T@_.foo(x: str) -> T@_ & B)` ``` diff --git a/crates/ty_python_semantic/src/place.rs b/crates/ty_python_semantic/src/place.rs index c67aefd806ade..a5e08ed9590d2 100644 --- a/crates/ty_python_semantic/src/place.rs +++ b/crates/ty_python_semantic/src/place.rs @@ -265,7 +265,7 @@ impl<'db> Place<'db> { Place::Defined(defined) => { if let Some((dunder_get_return_ty, _)) = - defined.ty.try_call_dunder_get(db, None, owner) + defined.ty.try_call_dunder_get(db, None, owner, false) { Place::Defined(DefinedPlace { ty: dunder_get_return_ty, diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 7dfc334102fac..8ab32dd900dfc 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -77,8 +77,8 @@ pub(crate) use crate::types::narrow::{ infer_narrowing_constraint, }; use crate::types::newtype::NewType; +use crate::types::signatures::{CallableSignature, ParameterForm, walk_signature}; pub(crate) use crate::types::signatures::{Parameter, Parameters}; -use crate::types::signatures::{ParameterForm, walk_signature}; use crate::types::special_form::TypeQualifier; use crate::types::tuple::TupleSpec; use crate::types::type_alias::TypeAliasType; @@ -2582,6 +2582,233 @@ impl<'db> Type<'db> { } } + fn member_lookup_variants_for_typevar( + db: &'db dyn Db, + bound_typevar: BoundTypeVarInstance<'db>, + ) -> Option> { + match bound_typevar.typevar(db).bound_or_constraints(db) { + Some(TypeVarBoundOrConstraints::UpperBound(bound)) => { + let Type::Union(union) = bound.resolve_type_alias(db) else { + return None; + }; + + Some(UnionType::from_elements( + db, + union.elements(db).iter().copied(), + )) + } + Some(TypeVarBoundOrConstraints::Constraints(constraints)) + if constraints.elements(db).len() > 1 => + { + Some(UnionType::from_elements( + db, + constraints.elements(db).iter().copied(), + )) + } + _ => None, + } + } + + fn narrowed_typevar_member_lookup_target(self, db: &'db dyn Db) -> Option> { + match self { + Type::TypeVar(bound_typevar) => { + Self::member_lookup_variants_for_typevar(db, bound_typevar) + } + Type::Intersection(intersection) => { + let positives = intersection.positive(db); + let negatives = intersection.negative(db); + + let (selected_index, narrowed_variants) = + positives.iter().enumerate().find_map(|(index, positive)| { + positive + .narrowed_typevar_member_lookup_target(db) + .map(|variants| (index, variants)) + })?; + + let other_positives = positives + .iter() + .enumerate() + .filter_map(|(index, positive)| (index != selected_index).then_some(*positive)) + .collect::>(); + + let narrow_variant = |variant: Type<'db>| { + let mut builder = IntersectionBuilder::new(db).add_positive(variant); + for positive in &other_positives { + builder = builder.add_positive(*positive); + } + for negative in negatives { + builder = builder.add_negative(*negative); + } + builder.build() + }; + + let is_impossible_variant = |narrowed: Type<'db>| match narrowed { + Type::Intersection(narrowed_intersection) => narrowed_intersection + .with_expanded_typevars_and_newtypes(db) + .is_never(), + _ => narrowed.is_never(), + }; + + let narrowed = match narrowed_variants { + Type::Union(union) => UnionType::from_elements( + db, + union.elements(db).iter().filter_map(|variant| { + let narrowed = narrow_variant(*variant); + (!is_impossible_variant(narrowed)).then_some(narrowed) + }), + ), + variant => narrow_variant(variant), + }; + + if narrowed == self { + None + } else { + Some(narrowed) + } + } + _ => None, + } + } + + fn has_narrowed_typevar_receiver_identity(self, db: &'db dyn Db) -> bool { + match self { + Type::TypeVar(typevar) => { + !typevar.typevar(db).is_self(db) + && Self::member_lookup_variants_for_typevar(db, typevar).is_some() + } + Type::Intersection(intersection) => { + intersection + .positive(db) + .iter() + .any(|positive| match positive { + Type::TypeVar(typevar) => { + !typevar.typevar(db).is_self(db) + && Self::member_lookup_variants_for_typevar(db, *typevar).is_some() + } + Type::SubclassOf(subclass_of) => { + subclass_of.into_type_var().is_some_and(|typevar| { + !typevar.typevar(db).is_self(db) + && Self::member_lookup_variants_for_typevar(db, typevar) + .is_some() + }) + } + _ => false, + }) + } + _ => false, + } + } + + fn rewrite_function_for_narrowed_receiver( + function: FunctionType<'db>, + db: &'db dyn Db, + new_receiver_type: Type<'db>, + ) -> Option> { + let new_receiver_type = if function.is_classmethod(db) { + match new_receiver_type { + Type::ClassLiteral(..) | Type::GenericAlias(..) | Type::SubclassOf(..) => { + new_receiver_type + } + _ => SubclassOfType::try_from_instance(db, new_receiver_type) + .unwrap_or(new_receiver_type), + } + } else { + new_receiver_type + }; + + let rewritten_function = function.rebind_implicit_receiver(db, new_receiver_type); + (rewritten_function != function).then_some(rewritten_function) + } + + fn with_narrowed_lookup_receiver_identity( + self, + db: &'db dyn Db, + narrowed_target: Type<'db>, + ) -> Type<'db> { + match self { + Type::TypeVar(_) => match narrowed_target { + Type::Union(union) => UnionType::from_elements( + db, + union + .elements(db) + .iter() + .map(|variant| IntersectionType::from_two_elements(db, self, *variant)), + ), + variant => IntersectionType::from_two_elements(db, self, variant), + }, + _ => narrowed_target, + } + } + + fn rebind_narrowed_method_receiver( + self, + db: &'db dyn Db, + new_receiver_type: Type<'db>, + ) -> Type<'db> { + match self { + Type::FunctionLiteral(function) => { + Self::rewrite_function_for_narrowed_receiver(function, db, new_receiver_type) + .map(Type::FunctionLiteral) + .unwrap_or(self) + } + Type::Callable(callable) => { + Type::Callable(CallableType::new( + db, + CallableSignature::from_overloads(callable.signatures(db).iter().map( + |signature| signature.rebind_implicit_receiver(db, new_receiver_type), + )), + callable.kind(db), + )) + } + Type::BoundMethod(bound_method) => { + if bound_method + .receiver_type(db) + .has_narrowed_typevar_receiver_identity(db) + && bound_method.receiver_type(db) != new_receiver_type + { + // Preserve the existing call-binding receiver once it has already been + // narrowed; only the exposed `__self__` should change when we thread the + // original receiver identity back onto the bound method object. + Type::BoundMethod(BoundMethodType::new( + // Keep the existing function-level receiver specialization; this method + // only swaps which object is exposed through `__self__`. + db, + bound_method.function(db), + new_receiver_type, + bound_method.receiver_type(db), + )) + } else { + let function = Self::rewrite_function_for_narrowed_receiver( + bound_method.function(db), + db, + new_receiver_type, + ) + .unwrap_or(bound_method.function(db)); + Type::BoundMethod(BoundMethodType::new( + db, + function, + new_receiver_type, + new_receiver_type, + )) + } + } + Type::Union(union) => UnionType::from_elements( + db, + union + .elements(db) + .iter() + .map(|element| element.rebind_narrowed_method_receiver(db, new_receiver_type)), + ), + Type::Intersection(intersection) => intersection.map_positive(db, |positive| { + positive.rebind_narrowed_method_receiver(db, new_receiver_type) + }), + Type::TypeAlias(alias) => alias + .value_type(db) + .rebind_narrowed_method_receiver(db, new_receiver_type), + _ => self, + } + } + /// This function roughly corresponds to looking up an attribute in the `__dict__` of an object. /// For instance-like types, this goes through the classes MRO and discovers attribute assignments /// in methods, as well as class-body declarations that we consider to be evidence for the presence @@ -2719,6 +2946,7 @@ impl<'db> Type<'db> { db: &'db dyn Db, instance: Option>, owner: Type<'db>, + allow_narrowed_receiver_rebinding: bool, ) -> Option<(Type<'db>, AttributeKind)> { tracing::trace!( "try_call_dunder_get: {}, {}, {}", @@ -2727,7 +2955,12 @@ impl<'db> Type<'db> { owner.display(db) ); if let Some(fallback) = self.materialized_divergent_fallback() { - return fallback.try_call_dunder_get(db, instance, owner); + return fallback.try_call_dunder_get( + db, + instance, + owner, + allow_narrowed_receiver_rebinding, + ); } match self { @@ -2773,10 +3006,41 @@ impl<'db> Type<'db> { // unbound function). This incorrectly matches when the instance is actually // an instance of `None` return Some(( - Type::BoundMethod(BoundMethodType::new(db, function, instance.unwrap())), + Type::BoundMethod(BoundMethodType::new( + db, + function, + instance.unwrap(), + instance.unwrap(), + )), AttributeKind::NormalOrNonDataDescriptor, )); } + Type::FunctionLiteral(function) + if instance.is_some() + && allow_narrowed_receiver_rebinding + && !function.is_staticmethod(db) + && !function.is_classmethod(db) => + { + let self_type = instance.unwrap(); + if let Some(self_typevar) = self_type.as_typevar() + && !self_typevar.typevar(db).is_self(db) + && Self::member_lookup_variants_for_typevar(db, self_typevar).is_some() + { + if let Some(rewritten_function) = + Self::rewrite_function_for_narrowed_receiver(function, db, self_type) + { + return Some(( + Type::BoundMethod(BoundMethodType::new( + db, + rewritten_function, + self_type, + self_type, + )), + AttributeKind::NormalOrNonDataDescriptor, + )); + } + } + } _ => {} } @@ -2822,6 +3086,7 @@ impl<'db> Type<'db> { attribute: PlaceAndQualifiers<'db>, instance: Option>, owner: Type<'db>, + allow_narrowed_receiver_rebinding: bool, ) -> (PlaceAndQualifiers<'db>, AttributeKind) { if let PlaceAndQualifiers { place: @@ -2846,6 +3111,7 @@ impl<'db> Type<'db> { .with_qualifiers(qualifiers), instance, owner, + allow_narrowed_receiver_rebinding, ); } @@ -2881,7 +3147,12 @@ impl<'db> Type<'db> { .map_with_boundness(db, |elem| { Place::Defined(DefinedPlace { ty: elem - .try_call_dunder_get(db, instance, owner) + .try_call_dunder_get( + db, + instance, + owner, + allow_narrowed_receiver_rebinding, + ) .map_or(*elem, |(ty, _)| ty), origin, definedness: boundness, @@ -2891,7 +3162,7 @@ impl<'db> Type<'db> { .with_qualifiers(qualifiers), // TODO: avoid the duplication here: if union.elements(db).iter().all(|elem| { - elem.try_call_dunder_get(db, instance, owner) + elem.try_call_dunder_get(db, instance, owner, allow_narrowed_receiver_rebinding) .is_some_and(|(_, kind)| kind.is_data()) }) { AttributeKind::DataDescriptor @@ -2917,7 +3188,12 @@ impl<'db> Type<'db> { .map_with_boundness(db, |elem| { Place::Defined(DefinedPlace { ty: elem - .try_call_dunder_get(db, instance, owner) + .try_call_dunder_get( + db, + instance, + owner, + allow_narrowed_receiver_rebinding, + ) .map_or(*elem, |(ty, _)| ty), origin, definedness, @@ -2940,9 +3216,12 @@ impl<'db> Type<'db> { }), qualifiers: _, } => { - if let Some((return_ty, attribute_kind)) = - attribute_ty.try_call_dunder_get(db, instance, owner) - { + if let Some((return_ty, attribute_kind)) = attribute_ty.try_call_dunder_get( + db, + instance, + owner, + allow_narrowed_receiver_rebinding, + ) { ( Place::Defined(DefinedPlace { ty: return_ty, @@ -3025,6 +3304,7 @@ impl<'db> Type<'db> { fallback: PlaceAndQualifiers<'db>, policy: InstanceFallbackShadowsNonDataDescriptor, member_policy: MemberLookupPolicy, + allow_narrowed_receiver_rebinding: bool, ) -> PlaceAndQualifiers<'db> { let ( PlaceAndQualifiers { @@ -3037,6 +3317,7 @@ impl<'db> Type<'db> { self.class_member_with_policy(db, name.into(), member_policy), Some(self), self.to_meta_type(db), + allow_narrowed_receiver_rebinding, ); let PlaceAndQualifiers { @@ -3162,6 +3443,16 @@ impl<'db> Type<'db> { db: &'db dyn Db, name: Name, policy: MemberLookupPolicy, + ) -> PlaceAndQualifiers<'db> { + self.member_lookup_with_policy_impl(db, name, policy, true) + } + + fn member_lookup_with_policy_impl( + self, + db: &'db dyn Db, + name: Name, + policy: MemberLookupPolicy, + allow_narrowed_receiver_rebinding: bool, ) -> PlaceAndQualifiers<'db> { tracing::trace!("member_lookup_with_policy: {}.{}", self.display(db), name); if let Some(fallback) = self.materialized_divergent_fallback() { @@ -3172,17 +3463,53 @@ impl<'db> Type<'db> { return Place::bound(self.dunder_class(db)).into(); } + if let Some(target) = self.narrowed_typevar_member_lookup_target(db) { + let lookup_target = self.with_narrowed_lookup_receiver_identity(db, target); + let result = lookup_target.member_lookup_with_policy_impl( + db, + name, + policy, + allow_narrowed_receiver_rebinding, + ); + + if allow_narrowed_receiver_rebinding && self.has_narrowed_typevar_receiver_identity(db) + { + return result.map_type(|ty| ty.rebind_narrowed_method_receiver(db, self)); + } + + return result; + } + let name_str = name.as_str(); match self { Type::Union(union) => union.map_with_boundness_and_qualifiers(db, |elem| { - elem.member_lookup_with_policy(db, name_str.into(), policy) + elem.member_lookup_with_policy_impl( + db, + name_str.into(), + policy, + allow_narrowed_receiver_rebinding, + ) }), - Type::Intersection(intersection) => intersection - .map_with_boundness_and_qualifiers(db, |elem| { - elem.member_lookup_with_policy(db, name_str.into(), policy) - }), + Type::Intersection(intersection) => { + let result = intersection.map_with_boundness_and_qualifiers(db, |elem| { + elem.member_lookup_with_policy_impl( + db, + name_str.into(), + policy, + allow_narrowed_receiver_rebinding, + ) + }); + + if allow_narrowed_receiver_rebinding + && self.has_narrowed_typevar_receiver_identity(db) + { + result.map_type(|ty| ty.rebind_narrowed_method_receiver(db, self)) + } else { + result + } + } Type::Dynamic(..) | Type::Divergent(_) | Type::Never => Place::bound(self).into(), @@ -3302,25 +3629,50 @@ impl<'db> Type<'db> { _ => { KnownClass::MethodType .to_instance(db) - .member_lookup_with_policy(db, name.clone(), policy) + .member_lookup_with_policy_impl( + db, + name.clone(), + policy, + allow_narrowed_receiver_rebinding, + ) .or_fall_back_to(db, || { // If an attribute is not available on the bound method object, // it will be looked up on the underlying function object: Type::FunctionLiteral(bound_method.function(db)) - .member_lookup_with_policy(db, name, policy) + .member_lookup_with_policy_impl( + db, + name, + policy, + allow_narrowed_receiver_rebinding, + ) }) } }, Type::KnownBoundMethod(method) => method .class() .to_instance(db) - .member_lookup_with_policy(db, name, policy), + .member_lookup_with_policy_impl( + db, + name, + policy, + allow_narrowed_receiver_rebinding, + ), Type::WrapperDescriptor(_) => KnownClass::WrapperDescriptorType .to_instance(db) - .member_lookup_with_policy(db, name, policy), + .member_lookup_with_policy_impl( + db, + name, + policy, + allow_narrowed_receiver_rebinding, + ), Type::DataclassDecorator(_) => KnownClass::FunctionType .to_instance(db) - .member_lookup_with_policy(db, name, policy), + .member_lookup_with_policy_impl( + db, + name, + policy, + allow_narrowed_receiver_rebinding, + ), Type::Callable(_) | Type::DataclassTransformer(_) if name_str == "__call__" => { Place::bound(self).into() @@ -3328,11 +3680,20 @@ impl<'db> Type<'db> { Type::Callable(callable) if callable.is_function_like(db) => KnownClass::FunctionType .to_instance(db) - .member_lookup_with_policy(db, name, policy), + .member_lookup_with_policy_impl( + db, + name, + policy, + allow_narrowed_receiver_rebinding, + ), - Type::Callable(_) | Type::DataclassTransformer(_) => { - Type::object().member_lookup_with_policy(db, name, policy) - } + Type::Callable(_) | Type::DataclassTransformer(_) => Type::object() + .member_lookup_with_policy_impl( + db, + name, + policy, + allow_narrowed_receiver_rebinding, + ), Type::NominalInstance(instance) if matches!(name.as_str(), "major" | "minor") @@ -3399,12 +3760,20 @@ impl<'db> Type<'db> { Type::NewTypeInstance(new_type_instance) if self.as_union_like(db).is_some() => { new_type_instance .concrete_base_type(db) - .member_lookup_with_policy(db, name, policy) + .member_lookup_with_policy_impl( + db, + name, + policy, + allow_narrowed_receiver_rebinding, + ) } - Type::TypeAlias(alias) => alias - .value_type(db) - .member_lookup_with_policy(db, name, policy), + Type::TypeAlias(alias) => alias.value_type(db).member_lookup_with_policy_impl( + db, + name, + policy, + allow_narrowed_receiver_rebinding, + ), _ if policy.no_instance_fallback() => self.invoke_descriptor_protocol( db, @@ -3412,8 +3781,8 @@ impl<'db> Type<'db> { Place::Undefined.into(), InstanceFallbackShadowsNonDataDescriptor::No, policy, + allow_narrowed_receiver_rebinding, ), - Type::LiteralValue(literal) if literal.as_enum().is_some() && matches!(name_str, "name" | "_name_" | "value" | "_value_") => @@ -3512,6 +3881,7 @@ impl<'db> Type<'db> { fallback, InstanceFallbackShadowsNonDataDescriptor::No, policy, + allow_narrowed_receiver_rebinding, ); if result.is_class_var() && self.is_typed_dict() { @@ -3553,11 +3923,32 @@ impl<'db> Type<'db> { let self_instance = self .to_instance(db) .expect("`to_instance` always returns `Some` for `ClassLiteral`, `GenericAlias`, and `SubclassOf`"); - let class_attr_plain = - class_attr_plain.map_type(|ty| ty.bind_self_typevars(db, self_instance)); + let narrowed_receiver = allow_narrowed_receiver_rebinding + .then(|| self_instance.as_typevar()) + .flatten() + .filter(|typevar| { + !typevar.typevar(db).is_self(db) + && Self::member_lookup_variants_for_typevar(db, *typevar).is_some() + && name_str != "__new__" + }) + .map(|_| self_instance); + let class_attr_plain = class_attr_plain.map_type(|ty| { + let ty = ty.bind_self_typevars(db, self_instance); + if let Some(new_receiver_type) = narrowed_receiver { + ty.rebind_narrowed_method_receiver(db, new_receiver_type) + } else { + ty + } + }); - let class_attr_fallback = - Self::try_call_dunder_get_on_attribute(db, class_attr_plain, None, self).0; + let class_attr_fallback = Self::try_call_dunder_get_on_attribute( + db, + class_attr_plain, + None, + self, + allow_narrowed_receiver_rebinding, + ) + .0; let result = self.invoke_descriptor_protocol( db, @@ -3565,6 +3956,7 @@ impl<'db> Type<'db> { class_attr_fallback, InstanceFallbackShadowsNonDataDescriptor::Yes, policy, + allow_narrowed_receiver_rebinding, ); // A class is an instance of its metaclass. If attribute lookup on the class @@ -3768,9 +4160,10 @@ impl<'db> Type<'db> { Type::BoundMethod(bound_method) => { let signature = bound_method.function(db).signature(db); - CallableBinding::from_overloads(self, signature.overloads.iter().cloned()) - .with_bound_type(bound_method.self_instance(db)) - .into() + let binding = + CallableBinding::from_overloads(self, signature.overloads.iter().cloned()) + .with_bound_type(bound_method.receiver_type(db)); + binding.into() } Type::KnownBoundMethod(method) => { @@ -5625,6 +6018,7 @@ impl<'db> Type<'db> { db, method.function(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor), method.self_instance(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor), + method.receiver_type(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor), )), Type::NominalInstance(instance) if matches!(type_mapping, TypeMapping::Promote(PromotionMode::On, PromotionKind::Regular)) => { @@ -5920,6 +6314,12 @@ impl<'db> Type<'db> { typevars, visitor, ); + method.receiver_type(db).find_legacy_typevars_impl( + db, + binding_context, + typevars, + visitor, + ); method.function(db).find_legacy_typevars_impl( db, binding_context, diff --git a/crates/ty_python_semantic/src/types/bound_super.rs b/crates/ty_python_semantic/src/types/bound_super.rs index e82e75888511f..d22029edd714b 100644 --- a/crates/ty_python_semantic/src/types/bound_super.rs +++ b/crates/ty_python_semantic/src/types/bound_super.rs @@ -885,7 +885,7 @@ impl<'db> BoundSuperType<'db> { attribute: PlaceAndQualifiers<'db>, ) -> Option> { let (instance, owner) = self.owner(db).descriptor_binding(db)?; - Some(Type::try_call_dunder_get_on_attribute(db, attribute, instance, owner).0) + Some(Type::try_call_dunder_get_on_attribute(db, attribute, instance, owner, false).0) } /// Similar to `Type::find_name_in_mro_with_policy`, but performs lookup starting *after* the diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 243ab7f9972a4..3a07d87edf714 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -11,6 +11,7 @@ mod constructor; use std::borrow::Cow; +use std::cell::RefCell; use std::collections::HashSet; use std::fmt; @@ -50,6 +51,7 @@ use crate::types::signatures::{ }; use crate::types::tuple::{TupleLength, TupleSpec, TupleType}; use crate::types::typevar::BoundTypeVarIdentity; +use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard}; use crate::types::{ BoundMethodType, BoundTypeVarInstance, CallableType, ClassLiteral, DATACLASS_FLAGS, DataclassFlags, DataclassParams, EvaluationMode, GenericAlias, InternedConstraintSet, @@ -1056,7 +1058,7 @@ impl<'db> Bindings<'db> { match overload.parameter_types() { [_, Some(owner)] => { overload.set_return_type(Type::BoundMethod( - BoundMethodType::new(db, function, *owner), + BoundMethodType::new(db, function, *owner, *owner), )); } [Some(instance), None] => { @@ -1065,6 +1067,7 @@ impl<'db> Bindings<'db> { db, function, instance.to_meta_type(db), + instance.to_meta_type(db), ), )); } @@ -1077,7 +1080,7 @@ impl<'db> Bindings<'db> { overload.set_return_type(Type::FunctionLiteral(function)); } else { overload.set_return_type(Type::BoundMethod(BoundMethodType::new( - db, function, *first, + db, function, *first, *first, ))); } } @@ -1091,7 +1094,7 @@ impl<'db> Bindings<'db> { match overload.parameter_types() { [_, _, Some(owner)] => { overload.set_return_type(Type::BoundMethod( - BoundMethodType::new(db, *function, *owner), + BoundMethodType::new(db, *function, *owner, *owner), )); } @@ -1101,6 +1104,7 @@ impl<'db> Bindings<'db> { db, *function, instance.to_meta_type(db), + instance.to_meta_type(db), ), )); } @@ -1116,7 +1120,9 @@ impl<'db> Bindings<'db> { } [_, Some(instance), _] => { overload.set_return_type(Type::BoundMethod( - BoundMethodType::new(db, *function, *instance), + BoundMethodType::new( + db, *function, *instance, *instance, + ), )); } @@ -2544,7 +2550,16 @@ impl<'db> CallableBinding<'db> { return; }; for overload in &mut self.overloads { + let had_implicit_receiver = overload + .signature + .parameters() + .as_slice() + .first() + .is_some_and(Parameter::is_implicit_receiver); overload.signature = overload.signature.bind_self(db, Some(bound_self)); + if had_implicit_receiver { + overload.signature = overload.signature.clone().prune_unused_generic_context(db); + } } } @@ -4455,16 +4470,79 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { partially_specialized_declared_type: &FxHashSet>, specialization_errors: &mut Vec>, ) -> bool { + #[derive(Default)] + struct CollectMentionedTypeVars<'db> { + typevars: RefCell>>, + recursion_guard: TypeCollector<'db>, + } + + impl<'db> TypeVisitor<'db> for CollectMentionedTypeVars<'db> { + fn should_visit_lazy_type_attributes(&self) -> bool { + false + } + + fn visit_bound_type_var_type( + &self, + db: &'db dyn Db, + bound_typevar: BoundTypeVarInstance<'db>, + ) { + let identity = if bound_typevar.is_paramspec(db) { + bound_typevar.without_paramspec_attr(db).identity(db) + } else { + bound_typevar.identity(db) + }; + self.typevars.borrow_mut().insert(identity); + walk_type_with_recursion_guard( + db, + Type::TypeVar(bound_typevar), + self, + &self.recursion_guard, + ); + } + + fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) { + walk_type_with_recursion_guard(db, ty, self, &self.recursion_guard); + } + } + let mut assignable_to_declared_type = true; + let declared_type_mentions_generic_context = + |declared_type: Type<'db>, generic_context: GenericContext<'db>| { + let mentioned = CollectMentionedTypeVars::default(); + mentioned.visit_type(self.db, declared_type); + + generic_context.variables(self.db).any(|bound_typevar| { + mentioned + .typevars + .borrow() + .contains(&if bound_typevar.is_paramspec(self.db) { + bound_typevar + .without_paramspec_attr(self.db) + .identity(self.db) + } else { + bound_typevar.identity(self.db) + }) + }) + }; let parameters = self.signature.parameters(); - for (argument_index, adjusted_argument_index, _, argument_types) in + for (argument_index, adjusted_argument_index, argument, argument_types) in self.enumerate_argument_types() { for (parameter_index, variadic_argument_type) in self.argument_matches[argument_index].iter() { let declared_type = parameters[parameter_index].annotated_type(); + if matches!(self.signature_type, Type::BoundMethod(_)) + && !matches!(argument, Argument::Synthetic) + { + let Some(generic_context) = self.signature.generic_context else { + continue; + }; + if !declared_type_mentions_generic_context(declared_type, generic_context) { + continue; + } + } let argument_type = argument_types.get_for_declared_type(declared_type); let specialization_result = builder.infer_map( diff --git a/crates/ty_python_semantic/src/types/class/static_literal.rs b/crates/ty_python_semantic/src/types/class/static_literal.rs index f02c542e4bb0f..525c8514e21d9 100644 --- a/crates/ty_python_semantic/src/types/class/static_literal.rs +++ b/crates/ty_python_semantic/src/types/class/static_literal.rs @@ -1275,7 +1275,7 @@ impl<'db> StaticClassLiteral<'db> { if let Some(ref mut default_ty) = default_ty { *default_ty = default_ty - .try_call_dunder_get(db, None, Type::from(self)) + .try_call_dunder_get(db, None, Type::from(self), false) .map(|(return_ty, _)| return_ty) .unwrap_or_else(Type::unknown); } diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index b73018374f336..b33130f132956 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -960,6 +960,35 @@ impl<'db> FunctionType<'db> { ) } + pub(crate) fn rebind_implicit_receiver( + self, + db: &'db dyn Db, + new_receiver_type: Type<'db>, + ) -> Self { + let updated_signature = CallableSignature::from_overloads( + self.signature(db) + .overloads + .iter() + .map(|signature| signature.rebind_implicit_receiver(db, new_receiver_type)), + ); + let updated_last_definition_signature = self + .last_definition_signature(db) + .rebind_implicit_receiver(db, new_receiver_type); + + if updated_signature == *self.signature(db) + && updated_last_definition_signature == *self.last_definition_signature(db) + { + self + } else { + Self::new( + db, + self.literal(db), + Some(updated_signature), + Some(updated_last_definition_signature), + ) + } + } + pub(crate) fn with_dataclass_transformer_params( self, db: &'db dyn Db, @@ -1194,7 +1223,7 @@ impl<'db> FunctionType<'db> { db: &'db dyn Db, self_instance: Type<'db>, ) -> BoundMethodType<'db> { - BoundMethodType::new(db, self, self_instance) + BoundMethodType::new(db, self, self_instance, self_instance) } pub(crate) fn find_legacy_typevars_impl( diff --git a/crates/ty_python_semantic/src/types/method.rs b/crates/ty_python_semantic/src/types/method.rs index 436c09209e0b9..739b86d70261a 100644 --- a/crates/ty_python_semantic/src/types/method.rs +++ b/crates/ty_python_semantic/src/types/method.rs @@ -24,6 +24,12 @@ pub struct BoundMethodType<'db> { /// The instance on which this method has been called. Corresponds to the `__self__` /// attribute on a bound method object pub(super) self_instance: Type<'db>, + /// The receiver type used when binding the implicit `self` or `cls` parameter for calls. + /// + /// This is usually identical to `self_instance`, but narrowed receiver rebinding can preserve + /// a more specific same-occurrence receiver here while exposing the original object through + /// `__self__`. + pub(super) receiver_type: Type<'db>, } // The Salsa heap is tracked separately. @@ -36,6 +42,7 @@ pub(super) fn walk_bound_method_type<'db, V: visitor::TypeVisitor<'db> + ?Sized> ) { visitor.visit_function_type(db, method.function(db)); visitor.visit_type(db, method.self_instance(db)); + visitor.visit_type(db, method.receiver_type(db)); } fn into_callable_type_cycle_initial<'db>( @@ -52,19 +59,17 @@ impl<'db> BoundMethodType<'db> { /// This is normally the bound-instance type (the type of `self` or `cls`), but if the bound method is /// a `@classmethod`, then it should be an instance of that bound-instance type. pub(crate) fn typing_self_type(self, db: &'db dyn Db) -> Type<'db> { - let mut self_instance = self.self_instance(db); + let mut self_instance = self.receiver_type(db); if self.function(db).is_classmethod(db) { self_instance = self_instance.to_instance(db).unwrap_or_else(Type::unknown); } self_instance } - pub(crate) fn map_self_type( - self, - db: &'db dyn Db, - f: impl FnOnce(Type<'db>) -> Type<'db>, - ) -> Self { - Self::new(db, self.function(db), f(self.self_instance(db))) + pub(crate) fn map_self_type(self, db: &'db dyn Db, f: impl Fn(Type<'db>) -> Type<'db>) -> Self { + let self_instance = f(self.self_instance(db)); + let receiver_type = f(self.receiver_type(db)); + Self::new(db, self.function(db), self_instance, receiver_type) } #[salsa::tracked(cycle_initial=into_callable_type_cycle_initial, heap_size=ruff_memory_usage::heap_size)] @@ -97,6 +102,8 @@ impl<'db> BoundMethodType<'db> { .recursive_type_normalized_impl(db, div, nested)?, self.self_instance(db) .recursive_type_normalized_impl(db, div, true)?, + self.receiver_type(db) + .recursive_type_normalized_impl(db, div, true)?, )) } } @@ -114,7 +121,7 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { // bound self parameter, are contravariant.) self.check_function_pair(db, source.function(db), target.function(db)) .and(db, self.constraints, || { - self.check_type_pair(db, target.self_instance(db), source.self_instance(db)) + self.check_type_pair(db, target.receiver_type(db), source.receiver_type(db)) }) } } diff --git a/crates/ty_python_semantic/src/types/property_tests/type_generation.rs b/crates/ty_python_semantic/src/types/property_tests/type_generation.rs index bd757640e8c71..8f9a018e16d80 100644 --- a/crates/ty_python_semantic/src/types/property_tests/type_generation.rs +++ b/crates/ty_python_semantic/src/types/property_tests/type_generation.rs @@ -143,6 +143,7 @@ fn create_bound_method<'db>( db, function.expect_function_literal(), builtins_class.to_instance(db).unwrap(), + builtins_class.to_instance(db).unwrap(), )) } diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 143beb474791f..656741ad4e8a7 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -699,6 +699,7 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { Place::Undefined.into(), InstanceFallbackShadowsNonDataDescriptor::No, MemberLookupPolicy::default(), + false, ) .place else { diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index ecc3be7ab5b29..717c55ddaf98c 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -10,6 +10,7 @@ //! argument types and return types. For each callable type in the union, the call expression's //! arguments must match _at least one_ overload. +use std::cell::RefCell; use std::slice::Iter; use itertools::{EitherOrBoth, Itertools}; @@ -27,6 +28,7 @@ use crate::types::infer::infer_deferred_types; use crate::types::relation::{ HasRelationToVisitor, IsDisjointVisitor, TypeRelation, TypeRelationChecker, }; +use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard}; use crate::types::{ ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, CallableType, FindLegacyTypeVarsVisitor, KnownClass, MaterializationKind, ParamSpecAttrKind, SelfBinding, @@ -662,9 +664,8 @@ impl<'db> Signature<'db> { self_type: impl FnOnce() -> Option>, ) { if let Some(first_parameter) = self.parameters.value.first_mut() - && first_parameter.is_positional() + && first_parameter.is_implicit_receiver() && first_parameter.annotated_type.is_unknown() - && first_parameter.inferred_annotation && let Some(self_type) = self_type() { first_parameter.annotated_type = self_type; @@ -730,6 +731,7 @@ impl<'db> Signature<'db> { ); return_ty = return_ty.apply_type_mapping(db, &self_mapping, TypeContext::default()); } + Self { generic_context: self .generic_context @@ -740,6 +742,113 @@ impl<'db> Signature<'db> { } } + /// Replace the leading implicit receiver annotation with a narrower receiver type. + /// + /// This is used for method lookup on narrowed typevars, where the runtime receiver is still + /// "the same" typevar occurrence but one union arm has narrowed its upper bound. + pub(crate) fn rebind_implicit_receiver( + &self, + db: &'db dyn Db, + new_receiver_type: Type<'db>, + ) -> Self { + let Some(first_parameter) = self.parameters.value.first() else { + return self.clone(); + }; + + if !first_parameter.is_implicit_receiver() + || first_parameter + .annotated_type() + .is_equivalent_to(db, new_receiver_type) + { + return self.clone(); + } + + let mut parameters = self.parameters.value.clone(); + parameters[0].annotated_type = new_receiver_type; + + Self { + generic_context: self.generic_context, + definition: self.definition, + parameters: Parameters::new(db, parameters), + return_ty: self.return_ty, + } + } + + fn prune_unused_generic_context_impl(self, db: &'db dyn Db) -> Self { + #[derive(Default)] + struct CollectUsedTypeVars<'db> { + typevars: RefCell>>, + recursion_guard: TypeCollector<'db>, + } + + impl<'db> TypeVisitor<'db> for CollectUsedTypeVars<'db> { + fn should_visit_lazy_type_attributes(&self) -> bool { + false + } + + fn visit_bound_type_var_type( + &self, + db: &'db dyn Db, + bound_typevar: BoundTypeVarInstance<'db>, + ) { + let bound_typevar = if bound_typevar.is_paramspec(db) { + bound_typevar.without_paramspec_attr(db) + } else { + bound_typevar + }; + self.typevars.borrow_mut().insert(bound_typevar); + walk_type_with_recursion_guard( + db, + Type::TypeVar(bound_typevar), + self, + &self.recursion_guard, + ); + } + + fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) { + walk_type_with_recursion_guard(db, ty, self, &self.recursion_guard); + } + } + + let Some(generic_context) = self.generic_context else { + return self; + }; + + let used_typevars = CollectUsedTypeVars::default(); + for (index, parameter) in self.parameters.as_slice().iter().enumerate() { + if index == 0 && parameter.is_implicit_receiver() { + continue; + } + + used_typevars.visit_type(db, parameter.annotated_type()); + if let Some(default_ty) = parameter.default_type() { + used_typevars.visit_type(db, default_ty); + } + } + used_typevars.visit_type(db, self.return_ty); + + let remaining = generic_context + .variables(db) + .filter(|bound_typevar| { + used_typevars + .typevars + .borrow() + .iter() + .any(|used_typevar| used_typevar.is_same_typevar_as(db, *bound_typevar)) + }) + .collect::>(); + + Self { + generic_context: (!remaining.is_empty()) + .then(|| GenericContext::from_typevar_instances(db, remaining)), + ..self + } + } + + pub(crate) fn prune_unused_generic_context(self, db: &'db dyn Db) -> Self { + self.prune_unused_generic_context_impl(db) + } + pub(crate) fn apply_self(&self, db: &'db dyn Db, self_type: Type<'db>) -> Self { let self_mapping = TypeMapping::BindSelf(SelfBinding::new( db, @@ -2898,8 +3007,9 @@ pub(crate) struct Parameter<'db> { /// Does the type of this parameter come from an explicit annotation, or was it inferred from /// the context, like `Unknown` for any normal un-annotated parameter, `Self` for the `self` /// parameter of instance method, or `type[Self]` for `cls` parameter of classmethods. This - /// field is only used to decide whether to display the annotated type; it has no effect on the - /// type semantics of the parameter. + /// field is primarily used to decide whether to display the annotated type; some method-binding + /// logic also uses it to distinguish implicit receiver annotations from explicit `self` + /// annotations. pub(crate) inferred_annotation: bool, /// Variadic parameters can have starred annotations, e.g. @@ -3218,6 +3328,10 @@ impl<'db> Parameter<'db> { !self.inferred_annotation } + pub(crate) fn is_implicit_receiver(&self) -> bool { + self.is_positional() && self.inferred_annotation + } + /// Name of the parameter (if it has one). pub(crate) fn name(&self) -> Option<&ast::name::Name> { match &self.kind {