diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index 82d73ef455473..55651811cea74 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -1975,3 +1975,10 @@ def takes_str_or_float(x: float | str): ... takes_str_or_float(round(1.0)) ``` + +```py +def f(x: float) -> None: + reveal_type(round(x)) # revealed: int + reveal_type(round(x, None)) # revealed: int + reveal_type(round(x, 1)) # revealed: int | float +``` diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 505a505db87f0..08225eaa301ea 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -2011,7 +2011,6 @@ impl<'db, 'c> SpecializationBuilder<'db, 'c> { return Err(error); } } - (Type::TypeVar(bound_typevar), ty) | (ty, Type::TypeVar(bound_typevar)) if bound_typevar.is_inferable(self.db, self.inferable) => { @@ -2210,6 +2209,16 @@ impl<'db, 'c> SpecializationBuilder<'db, 'c> { } } + (formal @ Type::ProtocolInstance(_), actual @ Type::Union(_)) => { + // Reuse the protocol constraint solver for union-typed arguments too. + // This allows protocols like `_SupportsRound1[T]` to infer `T` from + // annotations such as `float`, which are interpreted as `int | float`. + let when = + actual.when_constraint_set_assignable_to(self.db, formal, self.constraints); + let _ = self.add_type_mappings_from_constraint_set(formal, when, &mut f); + return Ok(()); + } + (formal, Type::NominalInstance(actual_nominal)) => { // Special case: `formal` and `actual` are both tuples. if let (Some(formal_tuple), Some(actual_tuple)) = (