Skip to content

Commit 479c7ec

Browse files
committed
[ty] Use field type context for TypedDict constructor values
1 parent 1fe1c5f commit 479c7ec

File tree

4 files changed

+43
-12
lines changed

4 files changed

+43
-12
lines changed

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,6 +2382,27 @@ def _(node: Node, person: Person):
23822382
_: Node = Person(name="Alice", parent=Node(name="Bob", parent=Person(name="Charlie", parent=None)))
23832383
```
23842384

2385+
TypedDict constructor calls should also use field type context when inferring nested recursive
2386+
values:
2387+
2388+
```py
2389+
from typing import Any, List, TypedDict, Union
2390+
from typing_extensions import NotRequired
2391+
2392+
class Comparison(TypedDict):
2393+
field: str
2394+
op: NotRequired[str]
2395+
value: Any
2396+
2397+
class Logical(TypedDict):
2398+
op: NotRequired[str]
2399+
conditions: List["Filter"]
2400+
2401+
Filter = Union[Comparison, Logical]
2402+
2403+
logical = Logical(conditions=[Comparison(field="a", value="b")])
2404+
```
2405+
23852406
## Function/assignment syntax
23862407

23872408
TypedDicts can be created using the functional syntax:

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6887,12 +6887,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
68876887
if let Some(class) = class
68886888
&& class.is_typed_dict(self.db())
68896889
{
6890+
let mut speculative = self.speculate();
68906891
validate_typed_dict_constructor(
68916892
&self.context,
68926893
TypedDictType::new(class),
68936894
arguments,
68946895
func.as_ref().into(),
6895-
|expr| self.expression_type(expr),
6896+
|expr, tcx| speculative.infer_expression(expr, tcx),
68966897
);
68976898
}
68986899

crates/ty_python_semantic/src/types/infer/builder/dict.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
4545
typed_dict,
4646
arguments,
4747
func.into(),
48-
|expr| self.expression_type(expr),
48+
|expr, _| self.expression_type(expr),
4949
);
5050

5151
return Some(Type::TypedDict(typed_dict));

crates/ty_python_semantic/src/types/typed_dict.rs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
906906
typed_dict: TypedDictType<'db>,
907907
arguments: &'ast Arguments,
908908
error_node: AnyNodeRef<'ast>,
909-
expression_type_fn: impl Fn(&ast::Expr) -> Type<'db>,
909+
mut expression_type_fn: impl FnMut(&ast::Expr, TypeContext<'db>) -> Type<'db>,
910910
) {
911911
let db = context.db();
912912

@@ -923,7 +923,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
923923
typed_dict,
924924
arguments,
925925
error_node,
926-
&expression_type_fn,
926+
&mut expression_type_fn,
927927
);
928928
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
929929
} else if is_single_positional_arg {
@@ -932,7 +932,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
932932
// Assignability already checks for required keys and type compatibility,
933933
// so we don't need separate validation.
934934
let arg = &arguments.args[0];
935-
let arg_ty = expression_type_fn(arg);
935+
let arg_ty = expression_type_fn(arg, TypeContext::default());
936936
let target_ty = Type::TypedDict(typed_dict);
937937

938938
if !arg_ty.is_assignable_to(db, target_ty) {
@@ -950,7 +950,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
950950
typed_dict,
951951
arguments,
952952
error_node,
953-
&expression_type_fn,
953+
&mut expression_type_fn,
954954
);
955955
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
956956
}
@@ -963,9 +963,10 @@ fn validate_from_dict_literal<'db, 'ast>(
963963
typed_dict: TypedDictType<'db>,
964964
arguments: &'ast Arguments,
965965
typed_dict_node: AnyNodeRef<'ast>,
966-
expression_type_fn: &impl Fn(&ast::Expr) -> Type<'db>,
966+
expression_type_fn: &mut impl FnMut(&ast::Expr, TypeContext<'db>) -> Type<'db>,
967967
) -> OrderSet<Name> {
968968
let mut provided_keys = OrderSet::new();
969+
let items = typed_dict.items(context.db());
969970

970971
if let ast::Expr::Dict(dict_expr) = &arguments.args[0] {
971972
// Validate dict entries
@@ -978,8 +979,11 @@ fn validate_from_dict_literal<'db, 'ast>(
978979
let key = key_value.to_str();
979980
provided_keys.insert(Name::new(key));
980981

981-
// Get the already-inferred argument type
982-
let value_ty = expression_type_fn(&dict_item.value);
982+
let value_tcx = items
983+
.get(key)
984+
.map(|field| TypeContext::new(Some(field.declared_ty)))
985+
.unwrap_or_default();
986+
let value_ty = expression_type_fn(&dict_item.value, value_tcx);
983987
TypedDictKeyAssignment {
984988
context,
985989
typed_dict,
@@ -1007,9 +1011,10 @@ fn validate_from_keywords<'db, 'ast>(
10071011
typed_dict: TypedDictType<'db>,
10081012
arguments: &'ast Arguments,
10091013
typed_dict_node: AnyNodeRef<'ast>,
1010-
expression_type_fn: &impl Fn(&ast::Expr) -> Type<'db>,
1014+
expression_type_fn: &mut impl FnMut(&ast::Expr, TypeContext<'db>) -> Type<'db>,
10111015
) -> OrderSet<Name> {
10121016
let db = context.db();
1017+
let items = typed_dict.items(db);
10131018

10141019
// Collect keys from explicit keyword arguments
10151020
let mut provided_keys: OrderSet<Name> = arguments
@@ -1022,7 +1027,11 @@ fn validate_from_keywords<'db, 'ast>(
10221027
for keyword in &arguments.keywords {
10231028
if let Some(arg_name) = &keyword.arg {
10241029
// Explicit keyword argument: e.g., `name="Alice"`
1025-
let value_ty = expression_type_fn(&keyword.value);
1030+
let value_tcx = items
1031+
.get(arg_name.id.as_str())
1032+
.map(|field| TypeContext::new(Some(field.declared_ty)))
1033+
.unwrap_or_default();
1034+
let value_ty = expression_type_fn(&keyword.value, value_tcx);
10261035
TypedDictKeyAssignment {
10271036
context,
10281037
typed_dict,
@@ -1041,7 +1050,7 @@ fn validate_from_keywords<'db, 'ast>(
10411050
// Unlike positional TypedDict arguments, unpacking passes all keys as explicit
10421051
// keyword arguments, so extra keys should be flagged as errors (consistent with
10431052
// explicitly providing those keys).
1044-
let unpacked_type = expression_type_fn(&keyword.value);
1053+
let unpacked_type = expression_type_fn(&keyword.value, TypeContext::default());
10451054

10461055
// Never and Dynamic types are special: they can have any keys, so we skip
10471056
// validation and mark all required keys as provided.

0 commit comments

Comments
 (0)