Skip to content

Commit d345652

Browse files
committed
[ty] Add a materialization visitor
1 parent 0abbde5 commit d345652

File tree

7 files changed

+268
-23
lines changed

7 files changed

+268
-23
lines changed

crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,54 @@ static_assert(is_subtype_of(Bottom[JsonDict], Bottom[JsonDict]))
489489
static_assert(is_subtype_of(Bottom[JsonDict], Top[JsonDict]))
490490
```
491491

492+
### Equivalence of top materializations of mutually recursive invariant aliases
493+
494+
```py
495+
from typing import Callable
496+
from ty_extensions import static_assert, is_equivalent_to, is_subtype_of, Top
497+
498+
class Box[T]:
499+
pass
500+
501+
type A = Callable[[B], None]
502+
type B = Callable[[A], None]
503+
504+
static_assert(is_equivalent_to(Top[Box[A]], Top[Box[B]]))
505+
static_assert(is_subtype_of(Top[Box[A]], Top[Box[B]]))
506+
static_assert(is_subtype_of(Top[Box[B]], Top[Box[A]]))
507+
```
508+
509+
### Assignment through recursive aliases
510+
511+
```py
512+
from __future__ import annotations
513+
514+
type JSON = str | int | float | bool | list[JSON] | list[JSON_OBJECT] | dict[str, JSON] | None
515+
type JSON_OBJECT = dict[str, JSON]
516+
517+
x: JSON_OBJECT = {"hello": 23}
518+
519+
def f() -> JSON_OBJECT:
520+
return {"hello": 23}
521+
```
522+
523+
### Recursive dict alias in method return
524+
525+
```py
526+
from __future__ import annotations
527+
from dataclasses import dataclass
528+
529+
type NodeDict = dict[str, str | list[NodeDict]]
530+
531+
@dataclass
532+
class Node:
533+
label: str
534+
children: list[Node]
535+
536+
def to_dict(self) -> NodeDict:
537+
return {"label": self.label, "children": [child.to_dict() for child in self.children]}
538+
```
539+
492540
### Cyclic defaults
493541

494542
```py

crates/ty_python_semantic/src/types.rs

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ use ruff_diagnostics::{Edit, Fix};
44
use rustc_hash::FxHashMap;
55

66
use std::borrow::Cow;
7+
use std::cell::RefCell;
78
use std::iter;
9+
use std::rc::Rc;
810
use std::time::Duration;
911

1012
use bitflags::bitflags;
@@ -235,8 +237,108 @@ fn definition_expression_type<'db>(
235237
}
236238
}
237239

240+
struct ApplyDefaultTypeMapping;
241+
struct ApplyTopMaterialization;
242+
struct ApplyBottomMaterialization;
243+
struct ApplyMaterializationEquivalence;
244+
245+
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
246+
enum ActiveTypeMapping {
247+
Default,
248+
TopMaterialization,
249+
BottomMaterialization,
250+
}
251+
238252
/// A [`TypeTransformer`] that is used in `apply_type_mapping` methods.
239-
pub(crate) type ApplyTypeMappingVisitor<'db> = TypeTransformer<'db, TypeMapping<'db, 'db>>;
253+
///
254+
/// Materialization is the only mapping mode that needs to visit the same type under two different
255+
/// mappings within a single recursive call chain (`Top` and `Bottom`). Keep separate cycle caches
256+
/// for those modes so invariant checks can safely reuse one visitor.
257+
pub(crate) struct ApplyTypeMappingVisitor<'db> {
258+
default: TypeTransformer<'db, ApplyDefaultTypeMapping>,
259+
top_materialization: TypeTransformer<'db, ApplyTopMaterialization>,
260+
bottom_materialization: TypeTransformer<'db, ApplyBottomMaterialization>,
261+
materialization_equivalence:
262+
Rc<CycleDetector<ApplyMaterializationEquivalence, (Type<'db>, Type<'db>), bool>>,
263+
active_type_mappings: RefCell<Vec<ActiveTypeMapping>>,
264+
}
265+
266+
impl<'db> ApplyTypeMappingVisitor<'db> {
267+
pub(crate) fn visit(&self, ty: Type<'db>, func: impl FnOnce() -> Type<'db>) -> Type<'db> {
268+
let active_type_mapping = self
269+
.active_type_mappings
270+
.borrow()
271+
.last()
272+
.copied()
273+
.unwrap_or(ActiveTypeMapping::Default);
274+
275+
match active_type_mapping {
276+
ActiveTypeMapping::Default => self.default.visit(ty, func),
277+
ActiveTypeMapping::TopMaterialization => self.top_materialization.visit(ty, func),
278+
ActiveTypeMapping::BottomMaterialization => self.bottom_materialization.visit(ty, func),
279+
}
280+
}
281+
282+
pub(crate) fn with_type_mapping<T>(
283+
&self,
284+
type_mapping: &TypeMapping<'_, 'db>,
285+
func: impl FnOnce() -> T,
286+
) -> T {
287+
let active_type_mapping = match type_mapping {
288+
TypeMapping::Materialize(MaterializationKind::Top) => {
289+
ActiveTypeMapping::TopMaterialization
290+
}
291+
TypeMapping::Materialize(MaterializationKind::Bottom) => {
292+
ActiveTypeMapping::BottomMaterialization
293+
}
294+
_ => ActiveTypeMapping::Default,
295+
};
296+
297+
self.active_type_mappings
298+
.borrow_mut()
299+
.push(active_type_mapping);
300+
301+
let result = func();
302+
303+
let previous = self.active_type_mappings.borrow_mut().pop();
304+
debug_assert_eq!(previous, Some(active_type_mapping));
305+
306+
result
307+
}
308+
309+
pub(crate) fn is_equivalent_to_materialization(
310+
&self,
311+
db: &'db dyn Db,
312+
left: Type<'db>,
313+
right: Type<'db>,
314+
) -> bool {
315+
self.materialization_equivalence.visit((left, right), || {
316+
left.is_equivalent_to_with_materialization_visitor(db, right, self)
317+
})
318+
}
319+
320+
pub(crate) fn for_new_materialization_root(&self) -> Self {
321+
Self {
322+
default: TypeTransformer::default(),
323+
top_materialization: TypeTransformer::default(),
324+
bottom_materialization: TypeTransformer::default(),
325+
materialization_equivalence: Rc::clone(&self.materialization_equivalence),
326+
active_type_mappings: RefCell::default(),
327+
}
328+
}
329+
}
330+
331+
impl<'db> Default for ApplyTypeMappingVisitor<'db> {
332+
fn default() -> Self {
333+
Self {
334+
default: TypeTransformer::default(),
335+
top_materialization: TypeTransformer::default(),
336+
bottom_materialization: TypeTransformer::default(),
337+
materialization_equivalence: Rc::new(CycleDetector::new(true)),
338+
active_type_mappings: RefCell::default(),
339+
}
340+
}
341+
}
240342

241343
/// A [`CycleDetector`] that is used in `find_legacy_typevars` methods.
242344
pub(crate) type FindLegacyTypeVarsVisitor<'db> = CycleDetector<FindLegacyTypeVars, Type<'db>, ()>;
@@ -5516,7 +5618,7 @@ impl<'db> Type<'db> {
55165618
return self;
55175619
}
55185620

5519-
match self {
5621+
visitor.with_type_mapping(type_mapping, || match self {
55205622
Type::TypeVar(bound_typevar) => bound_typevar.apply_type_mapping_impl(db, type_mapping, visitor),
55215623
Type::KnownInstance(known_instance) => known_instance.apply_type_mapping_impl(db, type_mapping, tcx, visitor),
55225624

@@ -5769,7 +5871,7 @@ impl<'db> Type<'db> {
57695871
| Type::ClassLiteral(_)
57705872
| Type::BoundSuper(_)
57715873
| Type::SpecialForm(_) => self,
5772-
}
5874+
})
57735875
}
57745876

57755877
/// Locates any legacy `TypeVar`s in this type, and adds them to a set. This is used to build

crates/ty_python_semantic/src/types/class.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,11 +1124,13 @@ impl<'db> ClassType<'db> {
11241124
let constraints = ConstraintSetBuilder::new();
11251125
let relation_visitor = HasRelationToVisitor::default(&constraints);
11261126
let disjointness_visitor = IsDisjointVisitor::default(&constraints);
1127+
let materialization_visitor = ApplyTypeMappingVisitor::default();
11271128
let checker = TypeRelationChecker::subtyping(
11281129
&constraints,
11291130
InferableTypeVars::None,
11301131
&relation_visitor,
11311132
&disjointness_visitor,
1133+
&materialization_visitor,
11321134
);
11331135
checker
11341136
.check_class_pair(db, self, target)

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,11 @@ impl<'db> Specialization<'db> {
12301230
TypeVarVariance::Invariant => {
12311231
let top_materialization =
12321232
vartype.materialize(db, MaterializationKind::Top, visitor);
1233-
if !vartype.is_equivalent_to(db, top_materialization) {
1233+
if !visitor.is_equivalent_to_materialization(
1234+
db,
1235+
*vartype,
1236+
top_materialization,
1237+
) {
12341238
has_dynamic_invariant_typevar = true;
12351239
}
12361240
*vartype
@@ -1270,11 +1274,13 @@ impl<'db> Specialization<'db> {
12701274
) -> ConstraintSet<'db, 'c> {
12711275
let relation_visitor = HasRelationToVisitor::default(constraints);
12721276
let disjointness_visitor = IsDisjointVisitor::default(constraints);
1277+
let materialization_visitor = ApplyTypeMappingVisitor::default();
12731278
let checker = DisjointnessChecker::new(
12741279
constraints,
12751280
inferable,
12761281
&relation_visitor,
12771282
&disjointness_visitor,
1283+
&materialization_visitor,
12781284
);
12791285
checker.check_specialization_pair(db, self, other)
12801286
}
@@ -1455,10 +1461,20 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> {
14551461
target_type: Type<'db>,
14561462
target_materialization: MaterializationKind,
14571463
) -> ConstraintSet<'db, 'c> {
1458-
let source_top = source_type.top_materialization(db);
1459-
let source_bottom = source_type.bottom_materialization(db);
1460-
let target_top = target_type.top_materialization(db);
1461-
let target_bottom = target_type.bottom_materialization(db);
1464+
let source_top =
1465+
source_type.materialize(db, MaterializationKind::Top, self.materialization_visitor);
1466+
let source_bottom = source_type.materialize(
1467+
db,
1468+
MaterializationKind::Bottom,
1469+
self.materialization_visitor,
1470+
);
1471+
let target_top =
1472+
target_type.materialize(db, MaterializationKind::Top, self.materialization_visitor);
1473+
let target_bottom = target_type.materialize(
1474+
db,
1475+
MaterializationKind::Bottom,
1476+
self.materialization_visitor,
1477+
);
14621478

14631479
let is_subtype_of = |source: Type<'db>, target: Type<'db>| {
14641480
// TODO:

crates/ty_python_semantic/src/types/instance.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,11 +719,13 @@ impl<'db> ProtocolInstanceType<'db> {
719719
let constraints = ConstraintSetBuilder::new();
720720
let relation_visitor = HasRelationToVisitor::default(&constraints);
721721
let disjointness_visitor = IsDisjointVisitor::default(&constraints);
722+
let materialization_visitor = ApplyTypeMappingVisitor::default();
722723
let checker = TypeRelationChecker::subtyping(
723724
&constraints,
724725
InferableTypeVars::None,
725726
&relation_visitor,
726727
&disjointness_visitor,
728+
&materialization_visitor,
727729
);
728730
checker
729731
.check_type_satisfies_protocol(db, Type::object(), protocol)

0 commit comments

Comments
 (0)