Skip to content

Commit c41af09

Browse files
author
Grant Wuerker
committed
ADT recursion
1 parent 574c1f0 commit c41af09

File tree

9 files changed

+228
-58
lines changed

9 files changed

+228
-58
lines changed

Diff for: Cargo.lock

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: crates/common2/Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ camino = "1.1.4"
1515
smol_str = "0.1.24"
1616
salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" }
1717
parser = { path = "../parser2", package = "fe-parser2" }
18+
rustc-hash = "1.1.0"
19+
ena = "0.14"

Diff for: crates/common2/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pub mod diagnostics;
22
pub mod input;
3+
pub mod recursive_def;
34

45
pub use input::{InputFile, InputIngot};
56

Diff for: crates/common2/src/recursive_def.rs

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
use std::{fmt::Debug, hash::Hash};
2+
3+
use ena::unify::{InPlaceUnificationTable, UnifyKey};
4+
use rustc_hash::FxHashMap;
5+
6+
/// Represents a definition that contains a direct reference to itself.
7+
///
8+
/// Recursive definitions are not valid and must be reported to the user.
9+
/// It is preferable to group definitions together such that recursions
10+
/// are reported in-whole rather than separately. `RecursiveDef` can be
11+
/// used with `RecursiveDefHelper` to perform this grouping operation.
12+
///
13+
/// The fields `from` and `to` are the relevant identifiers and `site` can
14+
/// be used to carry diagnostic information.
15+
#[derive(Eq, PartialEq, Clone, Debug, Hash)]
16+
pub struct RecursiveDef<T, U>
17+
where
18+
T: PartialEq + Copy,
19+
{
20+
pub from: T,
21+
pub to: T,
22+
pub site: U,
23+
}
24+
25+
impl<T, U> RecursiveDef<T, U>
26+
where
27+
T: PartialEq + Copy,
28+
{
29+
pub fn new(from: T, to: T, site: U) -> Self {
30+
Self { from, to, site }
31+
}
32+
}
33+
34+
#[derive(PartialEq, Debug, Clone, Copy)]
35+
struct RecursiveDefKey(u32);
36+
37+
impl UnifyKey for RecursiveDefKey {
38+
type Value = ();
39+
40+
fn index(&self) -> u32 {
41+
self.0
42+
}
43+
44+
fn from_index(idx: u32) -> Self {
45+
Self(idx)
46+
}
47+
48+
fn tag() -> &'static str {
49+
"RecursiveDefKey"
50+
}
51+
}
52+
53+
pub struct RecursiveDefHelper<T, U>
54+
where
55+
T: Eq + Clone + Debug + Copy,
56+
{
57+
defs: Vec<RecursiveDef<T, U>>,
58+
table: InPlaceUnificationTable<RecursiveDefKey>,
59+
keys: FxHashMap<T, RecursiveDefKey>,
60+
}
61+
62+
impl<T, U> RecursiveDefHelper<T, U>
63+
where
64+
T: Eq + Clone + Debug + Copy + Hash,
65+
{
66+
pub fn new(defs: Vec<RecursiveDef<T, U>>) -> Self {
67+
let mut table = InPlaceUnificationTable::new();
68+
let keys: FxHashMap<_, _> = defs
69+
.iter()
70+
.map(|def| (def.from, table.new_key(())))
71+
.collect();
72+
73+
for def in defs.iter() {
74+
table.union(keys[&def.from], keys[&def.to])
75+
}
76+
77+
Self { defs, table, keys }
78+
}
79+
80+
/// Removes a disjoint set of recursive definitions from the helper
81+
/// and returns it, if one exists.
82+
pub fn remove_disjoint_set(&mut self) -> Option<Vec<RecursiveDef<T, U>>> {
83+
let mut disjoint_set = vec![];
84+
let mut remaining_set = vec![];
85+
let mut union_key: Option<&RecursiveDefKey> = None;
86+
87+
while let Some(def) = self.defs.pop() {
88+
let cur_key = &self.keys[&def.from];
89+
90+
if union_key.is_none() || self.table.unioned(*union_key.unwrap(), *cur_key) {
91+
disjoint_set.push(def)
92+
} else {
93+
remaining_set.push(def)
94+
}
95+
96+
if union_key.is_none() {
97+
union_key = Some(cur_key)
98+
}
99+
}
100+
101+
self.defs = remaining_set;
102+
103+
if union_key.is_some() {
104+
Some(disjoint_set)
105+
} else {
106+
None
107+
}
108+
}
109+
}
110+
111+
#[test]
112+
fn one_recursion() {
113+
let defs = vec![RecursiveDef::new(0, 1, ()), RecursiveDef::new(1, 0, ())];
114+
let mut helper = RecursiveDefHelper::new(defs);
115+
assert!(helper.remove_disjoint_set().is_some());
116+
assert!(helper.remove_disjoint_set().is_none());
117+
}
118+
119+
#[test]
120+
fn two_recursions() {
121+
let defs = vec![
122+
RecursiveDef::new(0, 1, ()),
123+
RecursiveDef::new(1, 0, ()),
124+
RecursiveDef::new(2, 3, ()),
125+
RecursiveDef::new(3, 4, ()),
126+
RecursiveDef::new(4, 2, ()),
127+
];
128+
let mut helper = RecursiveDefHelper::new(defs);
129+
assert!(helper.remove_disjoint_set().is_some());
130+
assert!(helper.remove_disjoint_set().is_some());
131+
assert!(helper.remove_disjoint_set().is_none());
132+
}

Diff for: crates/hir-analysis/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ pub struct Jar(
7272
ty::diagnostics::ImplTraitDefDiagAccumulator,
7373
ty::diagnostics::ImplDefDiagAccumulator,
7474
ty::diagnostics::FuncDefDiagAccumulator,
75+
ty::diagnostics::RecursiveAdtDefAccumulator,
7576
);
7677

7778
pub trait HirAnalysisDb: salsa::DbWithJar<Jar> + HirDb {

Diff for: crates/hir-analysis/src/ty/def_analysis.rs

+18-11
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ use super::{
2121
collect_impl_block_constraints, collect_super_traits, AssumptionListId, SuperTraitCycle,
2222
},
2323
constraint_solver::{is_goal_satisfiable, GoalSatisfiability},
24-
diagnostics::{ImplDiag, TraitConstraintDiag, TraitLowerDiag, TyDiagCollection, TyLowerDiag},
24+
diagnostics::{
25+
ImplDiag, RecursiveAdtDef, TraitConstraintDiag, TraitLowerDiag, TyDiagCollection,
26+
TyLowerDiag,
27+
},
2528
trait_def::{ingot_trait_env, Implementor, TraitDef, TraitMethod},
2629
trait_lower::{lower_trait, lower_trait_ref, TraitRefLowerError},
2730
ty_def::{AdtDef, AdtRef, AdtRefId, FuncDef, InvalidCause, TyData, TyId},
@@ -33,7 +36,8 @@ use crate::{
3336
ty::{
3437
diagnostics::{
3538
AdtDefDiagAccumulator, FuncDefDiagAccumulator, ImplDefDiagAccumulator,
36-
ImplTraitDefDiagAccumulator, TraitDefDiagAccumulator, TypeAliasDefDiagAccumulator,
39+
ImplTraitDefDiagAccumulator, RecursiveAdtDefAccumulator, TraitDefDiagAccumulator,
40+
TypeAliasDefDiagAccumulator,
3741
},
3842
method_table::collect_methods,
3943
trait_lower::lower_impl_trait,
@@ -62,8 +66,8 @@ pub fn analyze_adt(db: &dyn HirAnalysisDb, adt_ref: AdtRefId) {
6266
AdtDefDiagAccumulator::push(db, diag);
6367
}
6468

65-
if let Some(diag) = check_recursive_adt(db, adt_ref) {
66-
AdtDefDiagAccumulator::push(db, diag);
69+
if let Some(def) = check_recursive_adt(db, adt_ref) {
70+
RecursiveAdtDefAccumulator::push(db, def);
6771
}
6872
}
6973

@@ -764,7 +768,7 @@ impl<'db> Visitor for DefAnalyzer<'db> {
764768
pub(crate) fn check_recursive_adt(
765769
db: &dyn HirAnalysisDb,
766770
adt: AdtRefId,
767-
) -> Option<TyDiagCollection> {
771+
) -> Option<RecursiveAdtDef> {
768772
let adt_def = lower_adt(db, adt);
769773
for field in adt_def.fields(db) {
770774
for ty in field.iter_types(db) {
@@ -781,7 +785,7 @@ fn check_recursive_adt_impl(
781785
db: &dyn HirAnalysisDb,
782786
cycle: &salsa::Cycle,
783787
adt: AdtRefId,
784-
) -> Option<TyDiagCollection> {
788+
) -> Option<RecursiveAdtDef> {
785789
let participants: FxHashSet<_> = cycle
786790
.participant_keys()
787791
.map(|key| check_recursive_adt::key_from_id(key.key_index()))
@@ -792,11 +796,14 @@ fn check_recursive_adt_impl(
792796
for (ty_idx, ty) in field.iter_types(db).enumerate() {
793797
for field_adt_ref in ty.collect_direct_adts(db) {
794798
if participants.contains(&field_adt_ref) && participants.contains(&adt) {
795-
let diag = TyLowerDiag::recursive_type(
796-
adt.name_span(db),
797-
adt_def.variant_ty_span(db, field_idx, ty_idx),
798-
);
799-
return Some(diag.into());
799+
return Some(RecursiveAdtDef::new(
800+
adt,
801+
field_adt_ref,
802+
(
803+
adt.name_span(db),
804+
adt_def.variant_ty_span(db, field_idx, ty_idx),
805+
),
806+
));
800807
}
801808
}
802809
}

Diff for: crates/hir-analysis/src/ty/diagnostics.rs

+30-27
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use std::collections::BTreeSet;
22

3-
use common::diagnostics::{
4-
CompleteDiagnostic, DiagnosticPass, GlobalErrorCode, LabelStyle, Severity, SubDiagnostic,
3+
use common::{
4+
diagnostics::{
5+
CompleteDiagnostic, DiagnosticPass, GlobalErrorCode, LabelStyle, Severity, SubDiagnostic,
6+
},
7+
recursive_def::RecursiveDef,
58
};
69
use hir::{
710
diagnostics::DiagnosticVoucher,
@@ -11,11 +14,12 @@ use hir::{
1114
};
1215
use itertools::Itertools;
1316

17+
use crate::HirAnalysisDb;
18+
1419
use super::{
1520
constraint::PredicateId,
16-
ty_def::{Kind, TyId},
21+
ty_def::{AdtRefId, Kind, TyId},
1722
};
18-
use crate::HirAnalysisDb;
1923

2024
#[salsa::accumulator]
2125
pub struct AdtDefDiagAccumulator(pub(super) TyDiagCollection);
@@ -29,6 +33,10 @@ pub struct ImplDefDiagAccumulator(pub(super) TyDiagCollection);
2933
pub struct FuncDefDiagAccumulator(pub(super) TyDiagCollection);
3034
#[salsa::accumulator]
3135
pub struct TypeAliasDefDiagAccumulator(pub(super) TyDiagCollection);
36+
#[salsa::accumulator]
37+
pub struct RecursiveAdtDefAccumulator(pub(super) RecursiveAdtDef);
38+
39+
pub type RecursiveAdtDef = RecursiveDef<AdtRefId, (DynLazySpan, DynLazySpan)>;
3240

3341
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::From)]
3442
pub enum TyDiagCollection {
@@ -53,10 +61,7 @@ impl TyDiagCollection {
5361
pub enum TyLowerDiag {
5462
ExpectedStarKind(DynLazySpan),
5563
InvalidTypeArgKind(DynLazySpan, String),
56-
RecursiveType {
57-
primary_span: DynLazySpan,
58-
field_span: DynLazySpan,
59-
},
64+
AdtRecursion(Vec<RecursiveAdtDef>),
6065

6166
UnboundTypeAliasParam {
6267
span: DynLazySpan,
@@ -140,11 +145,8 @@ impl TyLowerDiag {
140145
Self::InvalidTypeArgKind(span, msg)
141146
}
142147

143-
pub(super) fn recursive_type(primary_span: DynLazySpan, field_span: DynLazySpan) -> Self {
144-
Self::RecursiveType {
145-
primary_span,
146-
field_span,
147-
}
148+
pub(super) fn adt_recursion(defs: Vec<RecursiveAdtDef>) -> Self {
149+
Self::AdtRecursion(defs)
148150
}
149151

150152
pub(super) fn unbound_type_alias_param(
@@ -249,7 +251,7 @@ impl TyLowerDiag {
249251
match self {
250252
Self::ExpectedStarKind(_) => 0,
251253
Self::InvalidTypeArgKind(_, _) => 1,
252-
Self::RecursiveType { .. } => 2,
254+
Self::AdtRecursion { .. } => 2,
253255
Self::UnboundTypeAliasParam { .. } => 3,
254256
Self::TypeAliasCycle { .. } => 4,
255257
Self::InconsistentKindBound(_, _) => 5,
@@ -270,7 +272,7 @@ impl TyLowerDiag {
270272
match self {
271273
Self::ExpectedStarKind(_) => "expected `*` kind in this context".to_string(),
272274
Self::InvalidTypeArgKind(_, _) => "invalid type argument kind".to_string(),
273-
Self::RecursiveType { .. } => "recursive type is not allowed".to_string(),
275+
Self::AdtRecursion { .. } => "recursive type is not allowed".to_string(),
274276

275277
Self::UnboundTypeAliasParam { .. } => {
276278
"all type parameters of type alias must be given".to_string()
@@ -326,22 +328,23 @@ impl TyLowerDiag {
326328
span.resolve(db),
327329
)],
328330

329-
Self::RecursiveType {
330-
primary_span,
331-
field_span,
332-
} => {
333-
vec![
334-
SubDiagnostic::new(
331+
Self::AdtRecursion(defs) => {
332+
let mut diags = vec![];
333+
334+
for RecursiveAdtDef { site, .. } in defs {
335+
diags.push(SubDiagnostic::new(
335336
LabelStyle::Primary,
336337
"recursive type definition".to_string(),
337-
primary_span.resolve(db),
338-
),
339-
SubDiagnostic::new(
338+
site.0.resolve(db),
339+
));
340+
diags.push(SubDiagnostic::new(
340341
LabelStyle::Secondary,
341342
"recursion occurs here".to_string(),
342-
field_span.resolve(db),
343-
),
344-
]
343+
site.1.resolve(db),
344+
));
345+
}
346+
347+
diags
345348
}
346349

347350
Self::UnboundTypeAliasParam {

0 commit comments

Comments
 (0)