@@ -342,10 +342,21 @@ pub type InferenceResult<T> = Result<T, ErrorSet>;
342342
343343#[ derive( Clone , Debug , Eq , Hash , PartialEq , salsa:: Update ) ]
344344enum InferenceErrorStatus < ' db > {
345- Pending ( InferenceError < ' db > ) ,
345+ /// There is a pending error.
346+ Pending ( PendingInferenceError < ' db > ) ,
347+ /// There was an error but it was already consumed.
346348 Consumed ( DiagnosticAdded ) ,
347349}
348350
351+ /// A pending inference error.
352+ #[ derive( Clone , Debug , Eq , Hash , PartialEq , salsa:: Update ) ]
353+ struct PendingInferenceError < ' db > {
354+ /// The actual error.
355+ err : InferenceError < ' db > ,
356+ /// The optional location of the error.
357+ stable_ptr : Option < SyntaxStablePtrId < ' db > > ,
358+ }
359+
349360/// A mapping of an impl var's trait items to concrete items
350361#[ derive( Debug , Default , PartialEq , Eq , Hash , Clone , SemanticObject , salsa:: Update ) ]
351362pub struct ImplVarTraitItemMappings < ' db > {
@@ -743,23 +754,18 @@ impl<'db, 'id> Inference<'db, 'id> {
743754 /// Returns whether the inference was successful. If not, the error may be found by
744755 /// `.error_state()`.
745756 pub fn solve ( & mut self ) -> InferenceResult < ( ) > {
746- self . solve_ex ( ) . map_err ( |( err_set, _) | err_set)
747- }
748-
749- /// Same as `solve`, but returns the error stable pointer if an error occurred.
750- fn solve_ex ( & mut self ) -> Result < ( ) , ( ErrorSet , Option < SyntaxStablePtrId < ' db > > ) > {
751757 let ambiguous = std:: mem:: take ( & mut self . ambiguous ) . into_iter ( ) ;
752758 self . pending . extend ( ambiguous. map ( |( var, _) | var) ) ;
753759 while let Some ( var) = self . pending . pop_front ( ) {
754760 // First inference error stops inference.
755- self . solve_single_pending ( var) . map_err ( |err_set | {
756- ( err_set , self . stable_ptrs . get ( & InferenceVar :: Impl ( var) ) . copied ( ) )
761+ self . solve_single_pending ( var) . inspect_err ( |_err_set | {
762+ self . add_error_stable_ptr ( InferenceVar :: Impl ( var) ) ;
757763 } ) ?;
758764 }
759765 while let Some ( var) = self . negative_pending . pop_front ( ) {
760766 // First inference error stops inference.
761- self . solve_single_negative_pending ( var) . map_err ( |err_set | {
762- ( err_set , self . stable_ptrs . get ( & InferenceVar :: NegativeImpl ( var) ) . copied ( ) )
767+ self . solve_single_negative_pending ( var) . inspect_err ( |_err_set | {
768+ self . add_error_stable_ptr ( InferenceVar :: NegativeImpl ( var) ) ;
763769 } ) ?;
764770 }
765771 Ok ( ( ) )
@@ -845,12 +851,9 @@ impl<'db, 'id> Inference<'db, 'id> {
845851
846852 /// Finalizes the inference by inferring uninferred numeric literals as felt252.
847853 /// Returns an error and does not report it.
848- pub fn finalize_without_reporting (
849- & mut self ,
850- ) -> Result < ( ) , ( ErrorSet , Option < SyntaxStablePtrId < ' db > > ) > {
854+ pub fn finalize_without_reporting ( & mut self ) -> Result < ( ) , ErrorSet > {
851855 if self . error_status . is_err ( ) {
852- // TODO(yuval): consider adding error location to the set error.
853- return Err ( ( ErrorSet , None ) ) ;
856+ return Err ( ErrorSet ) ;
854857 }
855858 let info = self . db . core_info ( ) ;
856859 let numeric_trait_id = info. numeric_literal_trt ;
@@ -859,7 +862,7 @@ impl<'db, 'id> Inference<'db, 'id> {
859862 // Conform all uninferred numeric literals to felt252.
860863 loop {
861864 let mut changed = false ;
862- self . solve_ex ( ) ?;
865+ self . solve ( ) ?;
863866 for ( var, _) in self . ambiguous . clone ( ) {
864867 let impl_var = self . impl_var ( var) . clone ( ) ;
865868 if impl_var. concrete_trait_id . trait_id ( self . db ) != numeric_trait_id {
@@ -873,8 +876,8 @@ impl<'db, 'id> Inference<'db, 'id> {
873876 if self . rewrite ( ty) . no_err ( ) == felt_ty {
874877 continue ;
875878 }
876- self . conform_ty ( ty, felt_ty) . map_err ( |err_set | {
877- ( err_set , self . stable_ptrs . get ( & InferenceVar :: Impl ( impl_var. id ) ) . copied ( ) )
879+ self . conform_ty ( ty, felt_ty) . inspect_err ( |_err_set | {
880+ self . add_error_stable_ptr ( InferenceVar :: Impl ( impl_var. id ) ) ;
878881 } ) ?;
879882 changed = true ;
880883 break ;
@@ -891,7 +894,7 @@ impl<'db, 'id> Inference<'db, 'id> {
891894 let Some ( ( var, err) ) = self . first_undetermined_variable ( ) else {
892895 return Ok ( ( ) ) ;
893896 } ;
894- Err ( ( self . set_error ( err) , self . stable_ptrs . get ( & var) . copied ( ) ) )
897+ Err ( self . set_error_on_var ( err, var) )
895898 }
896899
897900 /// Finalizes the inference and report diagnostics if there are any errors.
@@ -902,12 +905,8 @@ impl<'db, 'id> Inference<'db, 'id> {
902905 diagnostics : & mut SemanticDiagnostics < ' db > ,
903906 stable_ptr : SyntaxStablePtrId < ' db > ,
904907 ) {
905- if let Err ( ( err_set, err_stable_ptr) ) = self . finalize_without_reporting ( ) {
906- let diag = self . report_on_pending_error (
907- err_set,
908- diagnostics,
909- err_stable_ptr. unwrap_or ( stable_ptr) ,
910- ) ;
908+ if let Err ( err_set) = self . finalize_without_reporting ( ) {
909+ let diag = self . report_on_pending_error ( err_set, diagnostics, stable_ptr) ;
911910
912911 let ty_missing = TypeId :: missing ( self . db , diag) ;
913912 for var in & self . data . type_vars {
@@ -991,7 +990,8 @@ impl<'db, 'id> Inference<'db, 'id> {
991990 }
992991 if !impl_id. is_var_free ( self . db ) && self . impl_contains_var ( impl_id, InferenceVar :: Impl ( var) )
993992 {
994- return Err ( self . set_error ( InferenceError :: Cycle ( InferenceVar :: Impl ( var) ) ) ) ;
993+ let inference_var = InferenceVar :: Impl ( var) ;
994+ return Err ( self . set_error_on_var ( InferenceError :: Cycle ( inference_var) , inference_var) ) ;
995995 }
996996 self . impl_assignment . insert ( var, impl_id) ;
997997 if let Some ( mappings) = self . impl_vars_trait_item_mappings . remove ( & var) {
@@ -1005,9 +1005,11 @@ impl<'db, 'id> Inference<'db, 'id> {
10051005 let ty0 = self . rewrite ( ty) . no_err ( ) ;
10061006 let ty1 = self . rewrite ( impl_ty) . no_err ( ) ;
10071007
1008- let error =
1009- InferenceError :: ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } ;
1010- self . error_status = Err ( InferenceErrorStatus :: Pending ( error) ) ;
1008+ let err = InferenceError :: ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } ;
1009+ self . error_status = Err ( InferenceErrorStatus :: Pending ( PendingInferenceError {
1010+ err,
1011+ stable_ptr : self . stable_ptrs . get ( & InferenceVar :: Impl ( var) ) . cloned ( ) ,
1012+ } ) ) ;
10111013 return Err ( err_set) ;
10121014 }
10131015 }
@@ -1100,7 +1102,7 @@ impl<'db, 'id> Inference<'db, 'id> {
11001102 assert ! ( !self . type_assignment. contains_key( & var. id) , "Cannot reassign variable." ) ;
11011103 let inference_var = InferenceVar :: Type ( var. id ) ;
11021104 if !ty. is_var_free ( self . db ) && self . ty_contains_var ( ty, inference_var) {
1103- return Err ( self . set_error ( InferenceError :: Cycle ( inference_var) ) ) ;
1105+ return Err ( self . set_error_on_var ( InferenceError :: Cycle ( inference_var) , inference_var ) ) ;
11041106 }
11051107 // If assigning var to var - making sure assigning to the lower id for proper canonization.
11061108 if let TypeLongId :: Var ( other) = ty. long ( self . db )
@@ -1328,22 +1330,51 @@ impl<'db, 'id> Inference<'db, 'id> {
13281330 /// Does nothing if an error is already set.
13291331 /// Returns an `ErrorSet` that can be used in reporting the error.
13301332 pub fn set_error ( & mut self , err : InferenceError < ' db > ) -> ErrorSet {
1333+ self . set_error_ex ( err, None )
1334+ }
1335+
1336+ /// Sets an error in the inference state, with an optional location for the diagnostics
1337+ /// reporting. Does nothing if an error is already set.
1338+ /// Returns an `ErrorSet` that can be used in reporting the error.
1339+ pub fn set_error_ex (
1340+ & mut self ,
1341+ err : InferenceError < ' db > ,
1342+ stable_ptr : Option < SyntaxStablePtrId < ' db > > ,
1343+ ) -> ErrorSet {
13311344 if self . error_status . is_err ( ) {
13321345 return ErrorSet ;
13331346 }
13341347 self . error_status = Err ( if let InferenceError :: Reported ( diag_added) = err {
13351348 InferenceErrorStatus :: Consumed ( diag_added)
13361349 } else {
1337- InferenceErrorStatus :: Pending ( err)
1350+ InferenceErrorStatus :: Pending ( PendingInferenceError { err, stable_ptr } )
13381351 } ) ;
13391352 ErrorSet
13401353 }
13411354
1355+ /// Sets an error in the inference state, with a var to fetch location for the diagnostics
1356+ /// reporting. Does nothing if an error is already set.
1357+ /// Returns an `ErrorSet` that can be used in reporting the error.
1358+ pub fn set_error_on_var ( & mut self , err : InferenceError < ' db > , var : InferenceVar ) -> ErrorSet {
1359+ self . set_error_ex ( err, self . stable_ptrs . get ( & var) . cloned ( ) )
1360+ }
1361+
13421362 /// Returns whether an error is set (either pending or consumed).
13431363 pub fn is_error_set ( & self ) -> InferenceResult < ( ) > {
13441364 self . error_status . as_ref ( ) . copied ( ) . map_err ( |_| ErrorSet )
13451365 }
13461366
1367+ /// If there is no stable ptr for the pending error, add it by the given var.
1368+ fn add_error_stable_ptr ( & mut self , var : InferenceVar ) {
1369+ let var_stable_ptr = self . stable_ptrs . get ( & var) . copied ( ) ;
1370+ if let Err ( InferenceErrorStatus :: Pending ( PendingInferenceError { err : _, stable_ptr } ) ) =
1371+ & mut self . error_status
1372+ && stable_ptr. is_none ( )
1373+ {
1374+ * stable_ptr = var_stable_ptr;
1375+ }
1376+ }
1377+
13471378 /// Consumes the error but doesn't report it. If there is no error, or the error is consumed,
13481379 /// returns None. This should be used with caution. Always prefer to use
13491380 /// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer.
@@ -1353,7 +1384,7 @@ impl<'db, 'id> Inference<'db, 'id> {
13531384 & mut self ,
13541385 err_set : ErrorSet ,
13551386 ) -> Option < InferenceError < ' db > > {
1356- self . consume_error_inner ( err_set, skip_diagnostic ( ) )
1387+ Some ( self . consume_error_inner ( err_set, skip_diagnostic ( ) ) ? . err )
13571388 }
13581389
13591390 /// Consumes the error that is already reported. If there is no error, or the error is consumed,
@@ -1376,10 +1407,16 @@ impl<'db, 'id> Inference<'db, 'id> {
13761407 & mut self ,
13771408 _err_set : ErrorSet ,
13781409 diag_added : DiagnosticAdded ,
1379- ) -> Option < InferenceError < ' db > > {
1410+ ) -> Option < PendingInferenceError < ' db > > {
13801411 match & mut self . error_status {
13811412 Err ( InferenceErrorStatus :: Pending ( error) ) => {
1382- let pending_error = std:: mem:: replace ( error, InferenceError :: Reported ( diag_added) ) ;
1413+ let pending_error = std:: mem:: replace (
1414+ error,
1415+ PendingInferenceError {
1416+ err : InferenceError :: Reported ( diag_added) ,
1417+ stable_ptr : None ,
1418+ } ,
1419+ ) ;
13831420 self . error_status = Err ( InferenceErrorStatus :: Consumed ( diag_added) ) ;
13841421 Some ( pending_error)
13851422 }
@@ -1404,16 +1441,16 @@ impl<'db, 'id> Inference<'db, 'id> {
14041441 } ;
14051442 match state_error {
14061443 InferenceErrorStatus :: Consumed ( diag_added) => * diag_added,
1407- InferenceErrorStatus :: Pending ( error ) => {
1408- let diag_added = match error {
1444+ InferenceErrorStatus :: Pending ( pending ) => {
1445+ let diag_added = match & pending . err {
14091446 InferenceError :: TypeNotInferred ( _) if diagnostics. error_count > 0 => {
14101447 // If we have other diagnostics, there is no need to TypeNotInferred.
14111448
14121449 // Note that `diagnostics` is not empty, so it is safe to return
14131450 // 'DiagnosticAdded' here.
14141451 skip_diagnostic ( )
14151452 }
1416- diag => diag. report ( diagnostics, stable_ptr) ,
1453+ diag => diag. report ( diagnostics, pending . stable_ptr . unwrap_or ( stable_ptr ) ) ,
14171454 } ;
14181455 self . error_status = Err ( InferenceErrorStatus :: Consumed ( diag_added) ) ;
14191456 diag_added
@@ -1428,7 +1465,7 @@ impl<'db, 'id> Inference<'db, 'id> {
14281465 err_set : ErrorSet ,
14291466 report : impl FnOnce ( ) -> DiagnosticAdded ,
14301467 ) {
1431- if matches ! ( self . error_status, Err ( InferenceErrorStatus :: Pending ( _ ) ) ) {
1468+ if matches ! ( self . error_status, Err ( InferenceErrorStatus :: Pending { .. } ) ) {
14321469 self . consume_reported_error ( err_set, report ( ) ) ;
14331470 }
14341471 }
0 commit comments