Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/typed_dict.md
Original file line number Diff line number Diff line change
Expand Up @@ -2382,6 +2382,27 @@ def _(node: Node, person: Person):
_: Node = Person(name="Alice", parent=Node(name="Bob", parent=Person(name="Charlie", parent=None)))
```

TypedDict constructor calls should also use field type context when inferring nested recursive
values:

```py
from typing import Any, List, TypedDict, Union
from typing_extensions import NotRequired

class Comparison(TypedDict):
field: str
op: NotRequired[str]
value: Any

class Logical(TypedDict):
op: NotRequired[str]
conditions: List["Filter"]

Filter = Union[Comparison, Logical]

logical = Logical(conditions=[Comparison(field="a", value="b")])
```

## Function/assignment syntax

TypedDicts can be created using the functional syntax:
Expand Down
3 changes: 2 additions & 1 deletion crates/ty_python_semantic/src/types/infer/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6887,12 +6887,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if let Some(class) = class
&& class.is_typed_dict(self.db())
{
let mut speculative = self.speculate();
validate_typed_dict_constructor(
&self.context,
TypedDictType::new(class),
arguments,
func.as_ref().into(),
|expr| self.expression_type(expr),
|expr, tcx| speculative.infer_expression(expr, tcx),
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
typed_dict,
arguments,
func.into(),
|expr| self.expression_type(expr),
|expr, _| self.expression_type(expr),
);

return Some(Type::TypedDict(typed_dict));
Expand Down
29 changes: 19 additions & 10 deletions crates/ty_python_semantic/src/types/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
typed_dict: TypedDictType<'db>,
arguments: &'ast Arguments,
error_node: AnyNodeRef<'ast>,
expression_type_fn: impl Fn(&ast::Expr) -> Type<'db>,
mut expression_type_fn: impl FnMut(&ast::Expr, TypeContext<'db>) -> Type<'db>,
) {
let db = context.db();

Expand All @@ -923,7 +923,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
typed_dict,
arguments,
error_node,
&expression_type_fn,
&mut expression_type_fn,
);
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
} else if is_single_positional_arg {
Expand All @@ -932,7 +932,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
// Assignability already checks for required keys and type compatibility,
// so we don't need separate validation.
let arg = &arguments.args[0];
let arg_ty = expression_type_fn(arg);
let arg_ty = expression_type_fn(arg, TypeContext::default());
let target_ty = Type::TypedDict(typed_dict);

if !arg_ty.is_assignable_to(db, target_ty) {
Expand All @@ -950,7 +950,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
typed_dict,
arguments,
error_node,
&expression_type_fn,
&mut expression_type_fn,
);
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
}
Expand All @@ -963,9 +963,10 @@ fn validate_from_dict_literal<'db, 'ast>(
typed_dict: TypedDictType<'db>,
arguments: &'ast Arguments,
typed_dict_node: AnyNodeRef<'ast>,
expression_type_fn: &impl Fn(&ast::Expr) -> Type<'db>,
expression_type_fn: &mut impl FnMut(&ast::Expr, TypeContext<'db>) -> Type<'db>,
) -> OrderSet<Name> {
let mut provided_keys = OrderSet::new();
let items = typed_dict.items(context.db());

if let ast::Expr::Dict(dict_expr) = &arguments.args[0] {
// Validate dict entries
Expand All @@ -978,8 +979,11 @@ fn validate_from_dict_literal<'db, 'ast>(
let key = key_value.to_str();
provided_keys.insert(Name::new(key));

// Get the already-inferred argument type
let value_ty = expression_type_fn(&dict_item.value);
let value_tcx = items
.get(key)
.map(|field| TypeContext::new(Some(field.declared_ty)))
.unwrap_or_default();
let value_ty = expression_type_fn(&dict_item.value, value_tcx);
TypedDictKeyAssignment {
context,
typed_dict,
Expand Down Expand Up @@ -1007,9 +1011,10 @@ fn validate_from_keywords<'db, 'ast>(
typed_dict: TypedDictType<'db>,
arguments: &'ast Arguments,
typed_dict_node: AnyNodeRef<'ast>,
expression_type_fn: &impl Fn(&ast::Expr) -> Type<'db>,
expression_type_fn: &mut impl FnMut(&ast::Expr, TypeContext<'db>) -> Type<'db>,
) -> OrderSet<Name> {
let db = context.db();
let items = typed_dict.items(db);

// Collect keys from explicit keyword arguments
let mut provided_keys: OrderSet<Name> = arguments
Expand All @@ -1022,7 +1027,11 @@ fn validate_from_keywords<'db, 'ast>(
for keyword in &arguments.keywords {
if let Some(arg_name) = &keyword.arg {
// Explicit keyword argument: e.g., `name="Alice"`
let value_ty = expression_type_fn(&keyword.value);
let value_tcx = items
.get(arg_name.id.as_str())
.map(|field| TypeContext::new(Some(field.declared_ty)))
.unwrap_or_default();
let value_ty = expression_type_fn(&keyword.value, value_tcx);
TypedDictKeyAssignment {
context,
typed_dict,
Expand All @@ -1041,7 +1050,7 @@ fn validate_from_keywords<'db, 'ast>(
// Unlike positional TypedDict arguments, unpacking passes all keys as explicit
// keyword arguments, so extra keys should be flagged as errors (consistent with
// explicitly providing those keys).
let unpacked_type = expression_type_fn(&keyword.value);
let unpacked_type = expression_type_fn(&keyword.value, TypeContext::default());

// Never and Dynamic types are special: they can have any keys, so we skip
// validation and mark all required keys as provided.
Expand Down
Loading