diff --git a/pyrefly/lib/alt/operators.rs b/pyrefly/lib/alt/operators.rs index 30ef8244fa..7e9d87e138 100644 --- a/pyrefly/lib/alt/operators.rs +++ b/pyrefly/lib/alt/operators.rs @@ -440,6 +440,15 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ]; self.try_binop_calls(&calls_to_try, range, errors, &context) }; + let preserve_quantified_if_same_restriction = |q: &Quantified, result: Type| { + if matches!(q.restriction, Restriction::Constraints(_)) + && result == q.bound_type(self.stdlib, self.heap) + { + Type::Quantified(Box::new(q.clone())) + } else { + result + } + }; self.distribute_over_union(lhs, |lhs| { self.distribute_over_union(rhs, |rhs| { // If an Any appears on the RHS, do not refine the return type based on the LHS. @@ -542,20 +551,23 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { && let Restriction::Constraints(constraints) = &left_q.restriction => { - self.unions(constraints.map(|constraint| { + let result = self.unions(constraints.map(|constraint| { self.binop_types(x, constraint, constraint, errors) - })) + })); + preserve_quantified_if_same_restriction(left_q, result) } // We skip non-union bounds to avoid accidentally erasing `Self` typevars. (Type::Quantified(left_q), _) if let Some(left_restriction) = self.as_union_restriction(left_q) => { - self.binop_types(x, &left_restriction, rhs, errors) + let result = self.binop_types(x, &left_restriction, rhs, errors); + preserve_quantified_if_same_restriction(left_q, result) } (_, Type::Quantified(right_q)) if let Some(right_restriction) = self.as_union_restriction(right_q) => { - self.binop_types(x, lhs, &right_restriction, errors) + let result = self.binop_types(x, lhs, &right_restriction, errors); + preserve_quantified_if_same_restriction(right_q, result) } _ => binop_call(x.op, lhs, rhs, x.range), } diff --git a/pyrefly/lib/test/operators.rs b/pyrefly/lib/test/operators.rs index abf55cd50f..b2d1f79739 100644 --- a/pyrefly/lib/test/operators.rs +++ b/pyrefly/lib/test/operators.rs @@ -42,6 +42,19 @@ class A(Generic[N]): "#, ); +testcase!( + test_preserve_constrained_typevar_through_binop, + r#" +from typing import TypeVar + +T = TypeVar("T", str, int) + +def foo(a: T) -> T: + doubled = 2 * a + return a + doubled +"#, +); + testcase!( test_bound_typevar_comparison, r#"