Skip to content

Commit 8f2df05

Browse files
authored
Improved var-based inference diagnostics locations. (#9006)
1 parent e562879 commit 8f2df05

File tree

4 files changed

+80
-43
lines changed

4 files changed

+80
-43
lines changed

crates/cairo-lang-semantic/src/expr/inference.rs

Lines changed: 75 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -342,10 +342,21 @@ pub type InferenceResult<T> = Result<T, ErrorSet>;
342342

343343
#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
344344
enum InferenceErrorStatus<'db> {
345-
Pending(InferenceError<'db>),
345+
/// There is a pending error.
346+
Pending(PendingInferenceError<'db>),
347+
/// There was an error but it was already consumed.
346348
Consumed(DiagnosticAdded),
347349
}
348350

351+
/// A pending inference error.
352+
#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
353+
struct PendingInferenceError<'db> {
354+
/// The actual error.
355+
err: InferenceError<'db>,
356+
/// The optional location of the error.
357+
stable_ptr: Option<SyntaxStablePtrId<'db>>,
358+
}
359+
349360
/// A mapping of an impl var's trait items to concrete items
350361
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, SemanticObject, salsa::Update)]
351362
pub struct ImplVarTraitItemMappings<'db> {
@@ -743,23 +754,18 @@ impl<'db, 'id> Inference<'db, 'id> {
743754
/// Returns whether the inference was successful. If not, the error may be found by
744755
/// `.error_state()`.
745756
pub fn solve(&mut self) -> InferenceResult<()> {
746-
self.solve_ex().map_err(|(err_set, _)| err_set)
747-
}
748-
749-
/// Same as `solve`, but returns the error stable pointer if an error occurred.
750-
fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId<'db>>)> {
751757
let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
752758
self.pending.extend(ambiguous.map(|(var, _)| var));
753759
while let Some(var) = self.pending.pop_front() {
754760
// First inference error stops inference.
755-
self.solve_single_pending(var).map_err(|err_set| {
756-
(err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
761+
self.solve_single_pending(var).inspect_err(|_err_set| {
762+
self.add_error_stable_ptr(InferenceVar::Impl(var));
757763
})?;
758764
}
759765
while let Some(var) = self.negative_pending.pop_front() {
760766
// First inference error stops inference.
761-
self.solve_single_negative_pending(var).map_err(|err_set| {
762-
(err_set, self.stable_ptrs.get(&InferenceVar::NegativeImpl(var)).copied())
767+
self.solve_single_negative_pending(var).inspect_err(|_err_set| {
768+
self.add_error_stable_ptr(InferenceVar::NegativeImpl(var));
763769
})?;
764770
}
765771
Ok(())
@@ -845,12 +851,9 @@ impl<'db, 'id> Inference<'db, 'id> {
845851

846852
/// Finalizes the inference by inferring uninferred numeric literals as felt252.
847853
/// Returns an error and does not report it.
848-
pub fn finalize_without_reporting(
849-
&mut self,
850-
) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId<'db>>)> {
854+
pub fn finalize_without_reporting(&mut self) -> Result<(), ErrorSet> {
851855
if self.error_status.is_err() {
852-
// TODO(yuval): consider adding error location to the set error.
853-
return Err((ErrorSet, None));
856+
return Err(ErrorSet);
854857
}
855858
let info = self.db.core_info();
856859
let numeric_trait_id = info.numeric_literal_trt;
@@ -859,7 +862,7 @@ impl<'db, 'id> Inference<'db, 'id> {
859862
// Conform all uninferred numeric literals to felt252.
860863
loop {
861864
let mut changed = false;
862-
self.solve_ex()?;
865+
self.solve()?;
863866
for (var, _) in self.ambiguous.clone() {
864867
let impl_var = self.impl_var(var).clone();
865868
if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
@@ -873,8 +876,8 @@ impl<'db, 'id> Inference<'db, 'id> {
873876
if self.rewrite(ty).no_err() == felt_ty {
874877
continue;
875878
}
876-
self.conform_ty(ty, felt_ty).map_err(|err_set| {
877-
(err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
879+
self.conform_ty(ty, felt_ty).inspect_err(|_err_set| {
880+
self.add_error_stable_ptr(InferenceVar::Impl(impl_var.id));
878881
})?;
879882
changed = true;
880883
break;
@@ -891,7 +894,7 @@ impl<'db, 'id> Inference<'db, 'id> {
891894
let Some((var, err)) = self.first_undetermined_variable() else {
892895
return Ok(());
893896
};
894-
Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
897+
Err(self.set_error_on_var(err, var))
895898
}
896899

897900
/// Finalizes the inference and report diagnostics if there are any errors.
@@ -902,12 +905,8 @@ impl<'db, 'id> Inference<'db, 'id> {
902905
diagnostics: &mut SemanticDiagnostics<'db>,
903906
stable_ptr: SyntaxStablePtrId<'db>,
904907
) {
905-
if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
906-
let diag = self.report_on_pending_error(
907-
err_set,
908-
diagnostics,
909-
err_stable_ptr.unwrap_or(stable_ptr),
910-
);
908+
if let Err(err_set) = self.finalize_without_reporting() {
909+
let diag = self.report_on_pending_error(err_set, diagnostics, stable_ptr);
911910

912911
let ty_missing = TypeId::missing(self.db, diag);
913912
for var in &self.data.type_vars {
@@ -991,7 +990,8 @@ impl<'db, 'id> Inference<'db, 'id> {
991990
}
992991
if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
993992
{
994-
return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
993+
let inference_var = InferenceVar::Impl(var);
994+
return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
995995
}
996996
self.impl_assignment.insert(var, impl_id);
997997
if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
@@ -1005,9 +1005,11 @@ impl<'db, 'id> Inference<'db, 'id> {
10051005
let ty0 = self.rewrite(ty).no_err();
10061006
let ty1 = self.rewrite(impl_ty).no_err();
10071007

1008-
let error =
1009-
InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 };
1010-
self.error_status = Err(InferenceErrorStatus::Pending(error));
1008+
let err = InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 };
1009+
self.error_status = Err(InferenceErrorStatus::Pending(PendingInferenceError {
1010+
err,
1011+
stable_ptr: self.stable_ptrs.get(&InferenceVar::Impl(var)).cloned(),
1012+
}));
10111013
return Err(err_set);
10121014
}
10131015
}
@@ -1100,7 +1102,7 @@ impl<'db, 'id> Inference<'db, 'id> {
11001102
assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
11011103
let inference_var = InferenceVar::Type(var.id);
11021104
if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
1103-
return Err(self.set_error(InferenceError::Cycle(inference_var)));
1105+
return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
11041106
}
11051107
// If assigning var to var - making sure assigning to the lower id for proper canonization.
11061108
if let TypeLongId::Var(other) = ty.long(self.db)
@@ -1328,22 +1330,51 @@ impl<'db, 'id> Inference<'db, 'id> {
13281330
/// Does nothing if an error is already set.
13291331
/// Returns an `ErrorSet` that can be used in reporting the error.
13301332
pub fn set_error(&mut self, err: InferenceError<'db>) -> ErrorSet {
1333+
self.set_error_ex(err, None)
1334+
}
1335+
1336+
/// Sets an error in the inference state, with an optional location for the diagnostics
1337+
/// reporting. Does nothing if an error is already set.
1338+
/// Returns an `ErrorSet` that can be used in reporting the error.
1339+
pub fn set_error_ex(
1340+
&mut self,
1341+
err: InferenceError<'db>,
1342+
stable_ptr: Option<SyntaxStablePtrId<'db>>,
1343+
) -> ErrorSet {
13311344
if self.error_status.is_err() {
13321345
return ErrorSet;
13331346
}
13341347
self.error_status = Err(if let InferenceError::Reported(diag_added) = err {
13351348
InferenceErrorStatus::Consumed(diag_added)
13361349
} else {
1337-
InferenceErrorStatus::Pending(err)
1350+
InferenceErrorStatus::Pending(PendingInferenceError { err, stable_ptr })
13381351
});
13391352
ErrorSet
13401353
}
13411354

1355+
/// Sets an error in the inference state, with a var to fetch location for the diagnostics
1356+
/// reporting. Does nothing if an error is already set.
1357+
/// Returns an `ErrorSet` that can be used in reporting the error.
1358+
pub fn set_error_on_var(&mut self, err: InferenceError<'db>, var: InferenceVar) -> ErrorSet {
1359+
self.set_error_ex(err, self.stable_ptrs.get(&var).cloned())
1360+
}
1361+
13421362
/// Returns whether an error is set (either pending or consumed).
13431363
pub fn is_error_set(&self) -> InferenceResult<()> {
13441364
self.error_status.as_ref().copied().map_err(|_| ErrorSet)
13451365
}
13461366

1367+
/// If there is no stable ptr for the pending error, add it by the given var.
1368+
fn add_error_stable_ptr(&mut self, var: InferenceVar) {
1369+
let var_stable_ptr = self.stable_ptrs.get(&var).copied();
1370+
if let Err(InferenceErrorStatus::Pending(PendingInferenceError { err: _, stable_ptr })) =
1371+
&mut self.error_status
1372+
&& stable_ptr.is_none()
1373+
{
1374+
*stable_ptr = var_stable_ptr;
1375+
}
1376+
}
1377+
13471378
/// Consumes the error but doesn't report it. If there is no error, or the error is consumed,
13481379
/// returns None. This should be used with caution. Always prefer to use
13491380
/// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer.
@@ -1353,7 +1384,7 @@ impl<'db, 'id> Inference<'db, 'id> {
13531384
&mut self,
13541385
err_set: ErrorSet,
13551386
) -> Option<InferenceError<'db>> {
1356-
self.consume_error_inner(err_set, skip_diagnostic())
1387+
Some(self.consume_error_inner(err_set, skip_diagnostic())?.err)
13571388
}
13581389

13591390
/// Consumes the error that is already reported. If there is no error, or the error is consumed,
@@ -1376,10 +1407,16 @@ impl<'db, 'id> Inference<'db, 'id> {
13761407
&mut self,
13771408
_err_set: ErrorSet,
13781409
diag_added: DiagnosticAdded,
1379-
) -> Option<InferenceError<'db>> {
1410+
) -> Option<PendingInferenceError<'db>> {
13801411
match &mut self.error_status {
13811412
Err(InferenceErrorStatus::Pending(error)) => {
1382-
let pending_error = std::mem::replace(error, InferenceError::Reported(diag_added));
1413+
let pending_error = std::mem::replace(
1414+
error,
1415+
PendingInferenceError {
1416+
err: InferenceError::Reported(diag_added),
1417+
stable_ptr: None,
1418+
},
1419+
);
13831420
self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
13841421
Some(pending_error)
13851422
}
@@ -1404,16 +1441,16 @@ impl<'db, 'id> Inference<'db, 'id> {
14041441
};
14051442
match state_error {
14061443
InferenceErrorStatus::Consumed(diag_added) => *diag_added,
1407-
InferenceErrorStatus::Pending(error) => {
1408-
let diag_added = match error {
1444+
InferenceErrorStatus::Pending(pending) => {
1445+
let diag_added = match &pending.err {
14091446
InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
14101447
// If we have other diagnostics, there is no need to TypeNotInferred.
14111448

14121449
// Note that `diagnostics` is not empty, so it is safe to return
14131450
// 'DiagnosticAdded' here.
14141451
skip_diagnostic()
14151452
}
1416-
diag => diag.report(diagnostics, stable_ptr),
1453+
diag => diag.report(diagnostics, pending.stable_ptr.unwrap_or(stable_ptr)),
14171454
};
14181455
self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
14191456
diag_added
@@ -1428,7 +1465,7 @@ impl<'db, 'id> Inference<'db, 'id> {
14281465
err_set: ErrorSet,
14291466
report: impl FnOnce() -> DiagnosticAdded,
14301467
) {
1431-
if matches!(self.error_status, Err(InferenceErrorStatus::Pending(_))) {
1468+
if matches!(self.error_status, Err(InferenceErrorStatus::Pending { .. })) {
14321469
self.consume_reported_error(err_set, report());
14331470
}
14341471
}

crates/cairo-lang-semantic/src/items/imp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3052,7 +3052,7 @@ fn implicit_impl_impl_semantic_data<'db>(
30523052
let impl_lookup_context = resolver.impl_lookup_context();
30533053
let resolved_impl = concrete_trait_impl_concrete_trait.and_then(|concrete_trait_id| {
30543054
let imp = resolver.inference().new_impl_var(concrete_trait_id, None, impl_lookup_context);
3055-
resolver.inference().finalize_without_reporting().map_err(|(err_set, _)| {
3055+
resolver.inference().finalize_without_reporting().map_err(|err_set| {
30563056
diagnostics.report(
30573057
impl_def_id.stable_ptr(db).untyped(),
30583058
ImplicitImplNotInferred { trait_impl_id, concrete_trait_id },

crates/cairo-lang-semantic/src/items/tests/trait_type

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3185,6 +3185,6 @@ struct S<impl M: MyTrait> {
31853185

31863186
//! > expected_diagnostics
31873187
error: `test::M::InputType` type mismatch: `core::felt252` and `core::integer::u32`.
3188-
--> lib.cairo:12:13
3189-
fn foo() -> S<M> {
3190-
^^^^
3188+
--> lib.cairo:13:5
3189+
S { x: 3_felt252 }
3190+
^

crates/cairo-lang-semantic/src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ pub fn get_impl_at_context<'db>(
814814
// It's ok to consume the errors without reporting as this is a helper function meant to find an
815815
// impl and return it, but it's ok if the impl can't be found.
816816
let impl_id = inference.new_impl_var(concrete_trait_id, stable_ptr, lookup_context);
817-
if let Err((err_set, _)) = inference.finalize_without_reporting() {
817+
if let Err(err_set) = inference.finalize_without_reporting() {
818818
return Err(inference
819819
.consume_error_without_reporting(err_set)
820820
.expect("Error couldn't be already consumed"));

0 commit comments

Comments
 (0)