Skip to content

Commit d0df10a

Browse files
committed
[ty] Fix Step 5 overload ambiguity for multi-argument calls
1 parent 14bd2b2 commit d0df10a

File tree

4 files changed

+113
-65
lines changed

4 files changed

+113
-65
lines changed

crates/ty_python_semantic/resources/mdtest/call/overloads.md

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ from typing_extensions import reveal_type
880880
def _(a: int | None):
881881
reveal_type(
882882
# error: [no-matching-overload]
883-
# revealed: Unknown
883+
# revealed: Any
884884
f(
885885
A(),
886886
a1=a,
@@ -1231,8 +1231,8 @@ def _(list_int: list[int], list_any: list[Any]):
12311231
# All materializations of `list[Any]` are assignable to `list[int]` and `list[Any]`, but the
12321232
# return type of first and second overloads are not equivalent, so the overload matching
12331233
# is ambiguous.
1234-
reveal_type(f(list_any)) # revealed: Unknown
1235-
reveal_type(f(*(list_any,))) # revealed: Unknown
1234+
reveal_type(f(list_any)) # revealed: Any
1235+
reveal_type(f(*(list_any,))) # revealed: Any
12361236
```
12371237

12381238
### Single tuple argument
@@ -1277,8 +1277,8 @@ def _(int_str: tuple[int, str], int_any: tuple[int, Any], any_any: tuple[Any, An
12771277

12781278
# All materializations of `tuple[Any, Any]` are assignable to the parameters of all the
12791279
# overloads, but the return types aren't equivalent, so the overload matching is ambiguous
1280-
reveal_type(f(any_any)) # revealed: Unknown
1281-
reveal_type(f(*(any_any,))) # revealed: Unknown
1280+
reveal_type(f(any_any)) # revealed: Any
1281+
reveal_type(f(*(any_any,))) # revealed: Any
12821282
```
12831283

12841284
### `Unknown` passed into an overloaded function annotated with protocols
@@ -1309,15 +1309,15 @@ def f(a: Foo, b: list[str], c: list[LiteralString], e):
13091309
reveal_type(a.join(b)) # revealed: str
13101310
reveal_type(a.join(c)) # revealed: LiteralString
13111311

1312-
# since both overloads match and they have return types that are not equivalent,
1312+
# Since both overloads match and they have return types that are not equivalent,
13131313
# step (5) of the overload evaluation algorithm says we must evaluate the result of the
1314-
# call as `Unknown`.
1314+
# call as `Any`.
13151315
#
13161316
# Note: although the spec does not state as such (since intersections in general are not
13171317
# specified currently), `(str | LiteralString) & Unknown` might also be a reasonable type
13181318
# here (the union of all overload returns, intersected with `Unknown`) -- here that would
13191319
# simplify to `str & Unknown`.
1320-
reveal_type(a.join(e)) # revealed: Unknown
1320+
reveal_type(a.join(e)) # revealed: Any
13211321
```
13221322

13231323
### Multiple arguments
@@ -1367,8 +1367,8 @@ def _(list_int: list[int], list_any: list[Any], int_str: tuple[int, str], int_an
13671367
# All materializations of first argument is assignable to the second overload and for the second
13681368
# argument, they're assignable to the third overload, so no overloads are filtered out; the
13691369
# return types of the remaining overloads are not equivalent, so overload matching is ambiguous
1370-
reveal_type(f(list_int, any_any)) # revealed: Unknown
1371-
reveal_type(f(*(list_int, any_any))) # revealed: Unknown
1370+
reveal_type(f(list_int, any_any)) # revealed: Any
1371+
reveal_type(f(*(list_int, any_any))) # revealed: Any
13721372
```
13731373

13741374
### `LiteralString` and `str`
@@ -1400,8 +1400,8 @@ def _(literal: LiteralString, string: str, any: Any):
14001400

14011401
# `Any` matches both overloads, but the return types are not equivalent.
14021402
# Pyright and mypy both reveal `str` here, contrary to the spec.
1403-
reveal_type(f(any)) # revealed: Unknown
1404-
reveal_type(f(*(any,))) # revealed: Unknown
1403+
reveal_type(f(any)) # revealed: Any
1404+
reveal_type(f(*(any,))) # revealed: Any
14051405
```
14061406

14071407
### Generics
@@ -1436,11 +1436,11 @@ def _(list_int: list[int], list_str: list[str], list_any: list[Any], any: Any):
14361436
reveal_type(f(list_str)) # revealed: str
14371437
reveal_type(f(*(list_str,))) # revealed: str
14381438

1439-
reveal_type(f(list_any)) # revealed: Unknown
1440-
reveal_type(f(*(list_any,))) # revealed: Unknown
1439+
reveal_type(f(list_any)) # revealed: Any
1440+
reveal_type(f(*(list_any,))) # revealed: Any
14411441

1442-
reveal_type(f(any)) # revealed: Unknown
1443-
reveal_type(f(*(any,))) # revealed: Unknown
1442+
reveal_type(f(any)) # revealed: Any
1443+
reveal_type(f(*(any,))) # revealed: Any
14441444
```
14451445

14461446
### Generics (multiple arguments)
@@ -1513,7 +1513,45 @@ def _(a_int: A[int], a_str: A[str], a_any: A[Any]):
15131513
def _(b_int: B[int], b_str: B[str], b_any: B[Any]):
15141514
reveal_type(b_int.method()) # revealed: int
15151515
reveal_type(b_str.method()) # revealed: str
1516-
reveal_type(b_any.method()) # revealed: Unknown
1516+
reveal_type(b_any.method()) # revealed: Any
1517+
```
1518+
1519+
### Ambiguous `Any` overloads (multiple arguments)
1520+
1521+
```toml
1522+
[environment]
1523+
python-version = "3.12"
1524+
```
1525+
1526+
`overloaded.pyi`:
1527+
1528+
```pyi
1529+
from typing import Any, overload
1530+
1531+
class A[T]:
1532+
def get(self) -> T: ...
1533+
1534+
@overload
1535+
def op(l: A[None], r: A[None]) -> A[None]: ...
1536+
@overload
1537+
def op(l: A[None], r: A[Any]) -> A[None]: ...
1538+
@overload
1539+
def op(l: A[Any], r: A[None]) -> A[None]: ...
1540+
@overload
1541+
def op(l: A[Any], r: A[Any]) -> A[Any]: ...
1542+
```
1543+
1544+
```py
1545+
from typing import Any, assert_type
1546+
1547+
from overloaded import A, op
1548+
1549+
def _(x: A[None], y: A[Any]) -> None:
1550+
assert_type(op(x, x), A[None])
1551+
assert_type(op(x, y), A[None])
1552+
assert_type(op(y, x), A[None])
1553+
assert_type(op(y, y), Any)
1554+
reveal_type(op(y, y)) # revealed: Any
15171555
```
15181556

15191557
### Variadic argument
@@ -1557,11 +1595,11 @@ def _(arg: list[Any]):
15571595
# Matches both overload and the return types are equivalent
15581596
reveal_type(f1(*arg)) # revealed: A
15591597
# Matches both overload but the return types aren't equivalent
1560-
reveal_type(f2(*arg)) # revealed: Unknown
1598+
reveal_type(f2(*arg)) # revealed: Any
15611599
# Filters out the final overload and the return types are equivalent
15621600
reveal_type(f3(*arg)) # revealed: A
15631601
# Filters out the final overload but the return types aren't equivalent
1564-
reveal_type(f4(*arg)) # revealed: Unknown
1602+
reveal_type(f4(*arg)) # revealed: Any
15651603
```
15661604

15671605
### Varidic argument with generics
@@ -1620,15 +1658,15 @@ def _(args1: list[int], args2: list[Any]):
16201658
reveal_type(f2()) # revealed: tuple[Any, ...]
16211659
reveal_type(f2(1, 2)) # revealed: tuple[Literal[1], Literal[2]]
16221660
# TODO: Should be `tuple[Literal[1], Literal[2]]`
1623-
reveal_type(f2(x1=1, x2=2)) # revealed: Unknown
1661+
reveal_type(f2(x1=1, x2=2)) # revealed: Any
16241662
# TODO: Should be `tuple[Literal[2], Literal[1]]`
1625-
reveal_type(f2(x2=1, x1=2)) # revealed: Unknown
1663+
reveal_type(f2(x2=1, x1=2)) # revealed: Any
16261664
reveal_type(f2(1, 2, z=3)) # revealed: tuple[Any, ...]
16271665

16281666
reveal_type(f3(1, 2)) # revealed: tuple[Literal[1], Literal[2]]
16291667
reveal_type(f3(1, 2, 3)) # revealed: tuple[Any, ...]
16301668
# TODO: Should be `tuple[Literal[1], Literal[2]]`
1631-
reveal_type(f3(x1=1, x2=2)) # revealed: Unknown
1669+
reveal_type(f3(x1=1, x2=2)) # revealed: Any
16321670
reveal_type(f3(z=1)) # revealed: dict[str, Any]
16331671

16341672
# error: [no-matching-overload]
@@ -1785,8 +1823,8 @@ from typing import Any
17851823
from overloaded import A, B, C, f
17861824

17871825
def _(arg: tuple[A | B, Any]):
1788-
reveal_type(f(arg)) # revealed: A | Unknown
1789-
reveal_type(f(*(arg,))) # revealed: A | Unknown
1826+
reveal_type(f(arg)) # revealed: A | Any
1827+
reveal_type(f(*(arg,))) # revealed: A | Any
17901828
```
17911829

17921830
#### Both argument lists ambiguous
@@ -1819,8 +1857,8 @@ from typing import Any
18191857
from overloaded import A, B, C, f
18201858

18211859
def _(arg: tuple[A | B, Any]):
1822-
reveal_type(f(arg)) # revealed: Unknown
1823-
reveal_type(f(*(arg,))) # revealed: Unknown
1860+
reveal_type(f(arg)) # revealed: Any
1861+
reveal_type(f(*(arg,))) # revealed: Any
18241862
```
18251863

18261864
### Unknown argument with TypeVar overload
@@ -1850,10 +1888,10 @@ from nonexistent_module import something_unknown # error: [unresolved-import]
18501888

18511889
reveal_type(something_unknown) # revealed: Unknown
18521890

1853-
# The result should be `Unknown`, not `Literal[b""]`.
1854-
reveal_type(f(something_unknown)) # revealed: Unknown
1855-
reveal_type(f((something_unknown, something_unknown, something_unknown))) # revealed: Unknown
1856-
reveal_type(f((something_unknown, None, something_unknown))) # revealed: Unknown
1891+
# The result should be `Any`, not `Literal[b""]`.
1892+
reveal_type(f(something_unknown)) # revealed: Any
1893+
reveal_type(f((something_unknown, something_unknown, something_unknown))) # revealed: Any
1894+
reveal_type(f((something_unknown, None, something_unknown))) # revealed: Any
18571895

18581896
# Concrete arguments should still resolve correctly.
18591897
def _(s: str):

crates/ty_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Argument_type_expans…_-_Optimization___Limit_…_(cd61048adbc17331).snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/call/overloads.md
3838
4 | def _(a: int | None):
3939
5 | reveal_type(
4040
6 | # error: [no-matching-overload]
41-
7 | # revealed: Unknown
41+
7 | # revealed: Any
4242
8 | f(
4343
9 | A(),
4444
10 | a1=a,
@@ -82,7 +82,7 @@ error[no-matching-overload]: No overload of function `f` matches arguments
8282
--> src/mdtest_snippet.py:8:9
8383
|
8484
6 | # error: [no-matching-overload]
85-
7 | # revealed: Unknown
85+
7 | # revealed: Any
8686
8 | / f(
8787
9 | | A(),
8888
10 | | a1=a,

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1361,7 +1361,7 @@ def _(
13611361
reveal_type(person[str_key]) # revealed: Unknown
13621362

13631363
# No error here:
1364-
reveal_type(person[unknown_key]) # revealed: Unknown
1364+
reveal_type(person[unknown_key]) # revealed: Any
13651365

13661366
reveal_type(being["name"]) # revealed: str
13671367

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Paramete
4444
use crate::types::tuple::{TupleLength, TupleSpec, TupleType};
4545
use crate::types::{
4646
BoundMethodType, BoundTypeVarIdentity, BoundTypeVarInstance, CallableSignature, CallableType,
47-
CallableTypeKind, ClassLiteral, DATACLASS_FLAGS, DataclassFlags, DataclassParams,
47+
CallableTypeKind, ClassLiteral, DATACLASS_FLAGS, DataclassFlags, DataclassParams, DynamicType,
4848
FieldInstance, GenericAlias, InternedConstraintSet, IntersectionType, KnownBoundMethodType,
4949
KnownClass, KnownInstanceType, LiteralValueTypeKind, MemberLookupPolicy, NominalInstanceType,
5050
PropertyInstanceType, SpecialFormType, TypeAliasType, TypeContext, TypeVarBoundOrConstraints,
@@ -2557,44 +2557,33 @@ impl<'db> CallableBinding<'db> {
25572557
// unmatched for the given argument types.
25582558
let mut filter_remaining_overloads = false;
25592559

2560-
for (upto, current_index) in matching_overload_indexes.iter().enumerate() {
2561-
if filter_remaining_overloads {
2562-
self.overloads[*current_index].mark_as_unmatched_overload();
2563-
continue;
2564-
}
2565-
2560+
// Build a tuple of participating parameter types for a single overload.
2561+
//
2562+
// A given participating parameter can receive multiple argument sources (for example,
2563+
// through variadics), so we union those types for that parameter index.
2564+
let participating_parameter_tuple_for_overload = |overload: &Binding<'db>| {
25662565
let mut union_parameter_types = std::iter::repeat_with(|| UnionBuilder::new(db))
25672566
.take(max_parameter_count)
25682567
.collect::<Vec<_>>();
25692568

2570-
// The number of parameters that have been skipped because they don't participate in
2571-
// the filtering process. This is used to make sure the types are added to the
2572-
// corresponding parameter index in `union_parameter_types`.
2573-
let mut skipped_parameters = 0;
2574-
25752569
for argument_index in 0..arguments.len() {
2576-
for overload_index in &matching_overload_indexes[..=upto] {
2577-
let overload = &self.overloads[*overload_index];
2578-
for parameter_index in &overload.argument_matches[argument_index].parameters {
2579-
if !participating_parameter_indexes.contains(parameter_index) {
2580-
skipped_parameters += 1;
2581-
continue;
2582-
}
2583-
// TODO: For an unannotated `self` / `cls` parameter, the type should be
2584-
// `typing.Self` / `type[typing.Self]`
2585-
let mut parameter_type =
2586-
overload.signature.parameters()[*parameter_index].annotated_type();
2587-
if let Some(specialization) = overload.specialization {
2588-
parameter_type =
2589-
parameter_type.apply_specialization(db, specialization);
2590-
}
2591-
union_parameter_types[parameter_index.saturating_sub(skipped_parameters)]
2592-
.add_in_place(parameter_type);
2570+
for parameter_index in &overload.argument_matches[argument_index].parameters {
2571+
if !participating_parameter_indexes.contains(parameter_index) {
2572+
continue;
25932573
}
2574+
2575+
// TODO: For an unannotated `self` / `cls` parameter, the type should be
2576+
// `typing.Self` / `type[typing.Self]`
2577+
let mut parameter_type =
2578+
overload.signature.parameters()[*parameter_index].annotated_type();
2579+
if let Some(specialization) = overload.specialization {
2580+
parameter_type = parameter_type.apply_specialization(db, specialization);
2581+
}
2582+
union_parameter_types[*parameter_index].add_in_place(parameter_type);
25942583
}
25952584
}
25962585

2597-
let parameter_types = Type::heterogeneous_tuple(
2586+
Type::heterogeneous_tuple(
25982587
db,
25992588
union_parameter_types.into_iter().filter_map(|builder| {
26002589
if builder.is_empty() {
@@ -2603,9 +2592,30 @@ impl<'db> CallableBinding<'db> {
26032592
Some(builder.build())
26042593
}
26052594
}),
2595+
)
2596+
};
2597+
2598+
for (upto, current_index) in matching_overload_indexes.iter().enumerate() {
2599+
if filter_remaining_overloads {
2600+
self.overloads[*current_index].mark_as_unmatched_overload();
2601+
continue;
2602+
}
2603+
2604+
// Use a union of per-overload parameter tuples rather than a tuple of per-parameter
2605+
// unions, so we preserve cross-argument correlations from each overload.
2606+
let parameter_types = UnionType::from_elements(
2607+
db,
2608+
matching_overload_indexes[..=upto]
2609+
.iter()
2610+
.map(|overload_index| {
2611+
participating_parameter_tuple_for_overload(&self.overloads[*overload_index])
2612+
}),
26062613
);
26072614

2608-
if top_materialized_argument_type.is_assignable_to(db, parameter_types) {
2615+
if top_materialized_argument_type
2616+
.when_assignable_to(db, parameter_types, InferableTypeVars::None)
2617+
.is_always_satisfied(db)
2618+
{
26092619
filter_remaining_overloads = true;
26102620
}
26112621
}
@@ -2763,7 +2773,7 @@ impl<'db> CallableBinding<'db> {
27632773
return match overload_call_return_type {
27642774
OverloadCallReturnType::ArgumentTypeExpansion(return_type) => return_type,
27652775
OverloadCallReturnType::ArgumentTypeExpansionLimitReached(_)
2766-
| OverloadCallReturnType::Ambiguous => Type::unknown(),
2776+
| OverloadCallReturnType::Ambiguous => Type::Dynamic(DynamicType::Any),
27672777
};
27682778
}
27692779
if let Some((_, first_overload)) = self.matching_overloads().next() {

0 commit comments

Comments
 (0)