diff --git a/pyrefly/lib/alt/attr.rs b/pyrefly/lib/alt/attr.rs index 87220f3a09..5b53708fc8 100644 --- a/pyrefly/lib/alt/attr.rs +++ b/pyrefly/lib/alt/attr.rs @@ -479,7 +479,7 @@ enum AttributeBase1 { TypedDict(TypedDictInner), /// Attribute lookup on a base as part of a subset check against a protocol. ProtocolSubset(Box), - Intersect(Vec, Vec), + Intersect(Option, Vec, Vec), /// Bound methods prefer exposing builtin `types.MethodType` attributes but fall back to the /// underlying function's attributes when the builtin ones are missing. BoundMethod(BoundMethodType), @@ -682,7 +682,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { AttributeBase1::ProtocolSubset(inner) => { self.collect_attribute_candidates_from_base(inner, candidates); } - AttributeBase1::Intersect(options, fallback) => { + AttributeBase1::Intersect(_, options, fallback) => { for b in options { self.collect_attribute_candidates_from_base(b, candidates); } @@ -1862,7 +1862,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } }, - AttributeBase1::Intersect(bases, fallback) => { + AttributeBase1::Intersect(self_type, bases, fallback) => { // Try each base and collect successful lookups, filtering out // GenericAlias lookups when the found attribute is inherited from // `object`. Parametrized classes like `Foo[int]` are an intersection of @@ -1887,7 +1887,16 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { continue; } let mut acc_candidate = LookupResult::empty(); - self.lookup_attr_static1(b.clone(), attr_name, &mut acc_candidate); + if let Some(self_type) = &self_type { + self.lookup_attr_from_attribute_base1_with_self_type( + b.clone(), + self_type, + attr_name, + &mut acc_candidate, + ); + } else { + self.lookup_attr_static1(b.clone(), attr_name, &mut acc_candidate); + } if acc_candidate.not_found.is_empty() && acc_candidate.internal_error.is_empty() { candidates.push(acc_candidate.found); @@ -1898,10 +1907,78 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } else { // TODO: Intersect the candidates instead of using the fallback. for b in fallback { - self.lookup_attr_static1(b.clone(), attr_name, acc); + if let Some(self_type) = &self_type { + self.lookup_attr_from_attribute_base1_with_self_type( + b.clone(), + self_type, + attr_name, + acc, + ); + } else { + self.lookup_attr_static1(b.clone(), attr_name, acc); + } + } + } + } + } + } + + fn lookup_attr_from_attribute_base1_with_self_type( + &self, + base: AttributeBase1, + self_type: &Type, + attr_name: &Name, + acc: &mut LookupResult, + ) { + match &base { + AttributeBase1::ClassInstance(class) => { + if let Some(attr) = self.try_nn_module_dict_attr(class, attr_name) { + acc.found_class_attribute(attr, base); + return; + } + let metadata = self.get_metadata_for_class(class.class_object()); + match self.get_enum_or_instance_attribute_with_self_type( + class, + &metadata, + attr_name, + self_type.clone(), + ) { + Some(attr) => acc.found_class_attribute(attr, base), + None if metadata.has_base_any() => { + acc.found_type(self.heap.mk_any_implicit(), base) + } + None if metadata + .named_tuple_metadata() + .is_some_and(|m| m.has_dynamic_fields) => + { + acc.found_type(self.heap.mk_any_implicit(), base) + } + None => { + acc.not_found(NotFoundOn::ClassInstance(class.class_object().dupe(), base)) + } + } + } + AttributeBase1::Quantified(_, bound) | AttributeBase1::SelfType(bound) => { + match self.get_instance_attribute_with_self_type( + bound, + self_type.clone(), + attr_name, + ) { + Some(attr) => acc.found_class_attribute(attr, base), + None => { + let metadata = self.get_metadata_for_class(bound.class_object()); + if metadata.has_base_any() { + acc.found_type(self.heap.mk_any_implicit(), base) + } else { + acc.not_found(NotFoundOn::ClassInstance( + bound.class_object().dupe(), + base, + )) + } } } } + _ => self.lookup_attr_static1(base, attr_name, acc), } } @@ -2240,6 +2317,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { AttributeBase1::ClassInstance(self.stdlib.generic_alias().clone()); // Since GenericAlias also exposes all class attributes, we need to intersect the two bases acc.push(AttributeBase1::Intersect( + None, vec![generic_alias_base.clone(), class_base], vec![generic_alias_base], )); @@ -2550,13 +2628,18 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { )), }, Type::Intersect(x) => { + let self_type = Type::Intersect(x.clone()); let mut acc_intersect = Vec::new(); for t in x.0 { self.as_attribute_base1(t, &mut acc_intersect); } let mut acc_fallback = Vec::new(); self.as_attribute_base1(x.1, &mut acc_fallback); - acc.push(AttributeBase1::Intersect(acc_intersect, acc_fallback)); + acc.push(AttributeBase1::Intersect( + Some(self_type), + acc_intersect, + acc_fallback, + )); } Type::ElementOfTypeVarTuple(_) => { acc.push(AttributeBase1::ClassInstance(self.stdlib.object().clone())) @@ -2944,7 +3027,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { // TODO(samzhou19815): Support autocomplete for properties {} } - AttributeBase1::Intersect(bases, _) => { + AttributeBase1::Intersect(_, bases, _) => { for b in bases { self.completions_inner1(b, expected_attribute_name, res); } diff --git a/pyrefly/lib/alt/class/class_field.rs b/pyrefly/lib/alt/class/class_field.rs index 045f72959b..0557f9c2be 100644 --- a/pyrefly/lib/alt/class/class_field.rs +++ b/pyrefly/lib/alt/class/class_field.rs @@ -626,11 +626,55 @@ impl ClassField { fn instantiate_for(&self, heap: &TypeHeap, instance: &Instance) -> Self { self.instantiate_helper(&mut |ty| { - ty.subst_self_type_mut(&instance.to_type(heap)); + if let Some(self_type) = instance.self_return_type_override() { + Self::subst_callable_return_self_type_mut(ty, self_type, &instance.to_type(heap)); + } else { + ty.subst_self_type_mut(&instance.to_type(heap)); + } instance.instantiate_member(ty) }) } + fn subst_callable_return_self_type_mut( + ty: &mut Type, + self_type: &Type, + fallback_self_type: &Type, + ) { + let subst_callable = |callable: &mut Callable| { + callable.ret.subst_self_type_mut(self_type); + callable + .params + .visit_mut(&mut |ty| ty.subst_self_type_mut(fallback_self_type)); + }; + match ty { + Type::Callable(callable) => subst_callable(callable), + Type::Function(func) => subst_callable(&mut func.signature), + Type::Forall(forall) => match &mut forall.body { + Forallable::Callable(callable) => subst_callable(callable), + Forallable::Function(func) => subst_callable(&mut func.signature), + _ => ty.subst_self_type_mut(fallback_self_type), + }, + Type::Overload(overload) => { + for sig in overload.signatures.iter_mut() { + match sig { + OverloadType::Function(func) => subst_callable(&mut func.signature), + OverloadType::Forall(forall) => subst_callable(&mut forall.body.signature), + } + } + } + Type::Union(union) => { + for member in union.members.iter_mut() { + Self::subst_callable_return_self_type_mut( + member, + self_type, + fallback_self_type, + ); + } + } + _ => ty.subst_self_type_mut(fallback_self_type), + } + } + fn instantiate_for_class_targs( &self, targs: &TArgs, @@ -3814,6 +3858,22 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { .map(|field| self.as_instance_attribute(name, &field, &Instance::of_class(cls))) } + pub fn get_instance_attribute_with_self_type( + &self, + cls: &ClassType, + self_type: Type, + name: &Name, + ) -> Option { + self.get_class_member(cls.class_object(), name) + .map(|field| { + self.as_instance_attribute( + name, + &field, + &Instance::of_class_with_self_type(cls, self_type), + ) + }) + } + pub fn get_self_attribute(&self, cls: &ClassType, name: &Name) -> Option { self.get_class_member(cls.class_object(), name) .map(|field| self.as_instance_attribute(name, &field, &Instance::of_self_type(cls))) diff --git a/pyrefly/lib/alt/class/enums.rs b/pyrefly/lib/alt/class/enums.rs index cc5cf64644..d9fdb0e288 100644 --- a/pyrefly/lib/alt/class/enums.rs +++ b/pyrefly/lib/alt/class/enums.rs @@ -157,6 +157,17 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { .or_else(|| self.get_instance_attribute(class, attr_name)) } + pub fn get_enum_or_instance_attribute_with_self_type( + &self, + class: &ClassType, + metadata: &ClassMetadata, + attr_name: &Name, + self_type: Type, + ) -> Option { + self.special_case_enum_attr_lookup(class, None, metadata, attr_name) + .or_else(|| self.get_instance_attribute_with_self_type(class, self_type, attr_name)) + } + /// Checks for a special-cased enum attribute on an enum literal, falling back to a regular instance attribute lookup. pub fn get_enum_literal_or_instance_attribute( &self, diff --git a/pyrefly/lib/alt/types/instance.rs b/pyrefly/lib/alt/types/instance.rs index 366157c74e..252e4834c8 100644 --- a/pyrefly/lib/alt/types/instance.rs +++ b/pyrefly/lib/alt/types/instance.rs @@ -27,6 +27,7 @@ pub enum InstanceKind { TypedDict, TypeVar(Quantified), SelfType, + ClassWithSelfType(Type), Protocol(Type), Metaclass(ClassBase), LiteralString, @@ -83,6 +84,14 @@ impl<'a> Instance<'a> { } } + pub fn of_class_with_self_type(cls: &'a ClassType, self_type: Type) -> Self { + Self { + kind: InstanceKind::ClassWithSelfType(self_type), + class: cls.class_object(), + targs: cls.targs(), + } + } + pub fn of_protocol(cls: &'a ClassType, self_type: Type) -> Self { Self { kind: InstanceKind::Protocol(self_type), @@ -113,9 +122,16 @@ impl<'a> Instance<'a> { self.targs.substitute_into_mut(raw_member) } + pub fn self_return_type_override(&self) -> Option<&Type> { + match &self.kind { + InstanceKind::ClassWithSelfType(self_type) => Some(self_type), + _ => None, + } + } + pub fn to_type(&self, heap: &TypeHeap) -> Type { match &self.kind { - InstanceKind::ClassType => { + InstanceKind::ClassType | InstanceKind::ClassWithSelfType(_) => { heap.mk_class_type(ClassType::new(self.class.dupe(), self.targs.clone())) } InstanceKind::TypedDict => { @@ -161,6 +177,7 @@ impl<'a> Instance<'a> { self.targs.clone(), ))), InstanceKind::ClassType + | InstanceKind::ClassWithSelfType(_) | InstanceKind::Protocol(..) | InstanceKind::Metaclass(..) | InstanceKind::TypeVar(..) diff --git a/pyrefly/lib/test/narrow.rs b/pyrefly/lib/test/narrow.rs index 93c8526199..7b7cdd195d 100644 --- a/pyrefly/lib/test/narrow.rs +++ b/pyrefly/lib/test/narrow.rs @@ -1524,6 +1524,44 @@ def test[T: int | str](value: T) -> T: "#, ); +testcase!( + test_self_returning_method_on_typevar_intersection, + r#" +from typing import Self, reveal_type + +class BaseModel: + def model_copy(self) -> Self: ... + +class ParentModel(BaseModel): + field: int | str + +class ChildModel(BaseModel): + field: int + +def test[T: ParentModel](value: T) -> T: + if isinstance(value, ChildModel): + reveal_type(value) # E: ChildModel & T + value_copy = value.model_copy() + reveal_type(value_copy) # E: ChildModel & T + return value_copy + return value + "#, +); + +testcase!( + test_non_self_method_on_generic_typevar_intersection, + r#" +class Mixin[T]: + def remove_all_commands(self) -> None: ... + +class Command[T]: ... + +def test[CogT](command: Command[CogT]) -> None: + if isinstance(command, Mixin): + command.remove_all_commands() + "#, +); + testcase!( test_issubclass_typevar_nondisjoint_classes, r#"