@@ -355,7 +355,7 @@ impl Operation {
355355 ( Self :: Custom ( CustomPredicateRef { batch, index } , args) , Custom ( cpr, s_args) )
356356 if batch == & cpr. batch && index == & cpr. index =>
357357 {
358- check_custom_pred ( params, cpr, args, s_args) ?
358+ check_custom_pred ( params, cpr, args, s_args) . map ( |_| true ) ?
359359 }
360360 _ => return Err ( deduction_err ( ) ) ,
361361 } ;
@@ -370,79 +370,88 @@ pub fn check_st_tmpl(
370370 st_arg : & StatementArg ,
371371 // Map from wildcards to values that we have seen so far.
372372 wildcard_map : & mut [ Option < Value > ] ,
373- ) -> bool {
373+ ) -> Result < ( ) > {
374374 // Check that the value `v` at wildcard `wc` exists in the map or set it.
375- fn check_or_set ( v : Value , wc : & Wildcard , wildcard_map : & mut [ Option < Value > ] ) -> bool {
375+ fn check_or_set ( v : Value , wc : & Wildcard , wildcard_map : & mut [ Option < Value > ] ) -> Result < ( ) > {
376376 if let Some ( prev) = & wildcard_map[ wc. index ] {
377377 if * prev != v {
378- // TODO: Return nice error
379- return false ;
378+ return Err ( Error :: invalid_wildcard_assignment (
379+ wc. clone ( ) ,
380+ v,
381+ prev. clone ( ) ,
382+ ) ) ;
380383 }
381384 } else {
382385 wildcard_map[ wc. index ] = Some ( v) ;
383386 }
384- true
387+ Ok ( ( ) )
385388 }
386389
387390 match ( st_tmpl_arg, st_arg) {
388- ( StatementTmplArg :: None , StatementArg :: None ) => true ,
389- ( StatementTmplArg :: Literal ( lhs) , StatementArg :: Literal ( rhs) ) if lhs == rhs => true ,
391+ ( StatementTmplArg :: None , StatementArg :: None ) => Ok ( ( ) ) ,
392+ ( StatementTmplArg :: Literal ( lhs) , StatementArg :: Literal ( rhs) ) if lhs == rhs => Ok ( ( ) ) ,
390393 (
391394 StatementTmplArg :: AnchoredKey ( pod_id_wc, key_tmpl) ,
392395 StatementArg :: Key ( AnchoredKey { pod_id, key } ) ,
393396 ) => {
394397 let pod_id_ok = check_or_set ( Value :: from ( * pod_id) , pod_id_wc, wildcard_map) ;
395- pod_id_ok && ( key_tmpl == key)
398+ pod_id_ok. and_then ( |_| {
399+ ( key_tmpl == key) . then_some ( ( ) ) . ok_or (
400+ Error :: mismatched_anchored_key_in_statement_tmpl_arg (
401+ pod_id_wc. clone ( ) ,
402+ * pod_id,
403+ key_tmpl. clone ( ) ,
404+ key. clone ( ) ,
405+ ) ,
406+ )
407+ } )
396408 }
397409 ( StatementTmplArg :: Wildcard ( wc) , StatementArg :: Literal ( v) ) => {
398410 check_or_set ( v. clone ( ) , wc, wildcard_map)
399411 }
400- _ => {
401- println ! ( "DBG {:?} {:?}" , st_tmpl_arg, st_arg ) ;
402- false
403- }
412+ _ => Err ( Error :: mismatched_statement_tmpl_arg (
413+ st_tmpl_arg. clone ( ) ,
414+ st_arg . clone ( ) ,
415+ ) ) ,
404416 }
405417}
406418
407419pub fn resolve_wildcard_values (
408420 params : & Params ,
409421 pred : & CustomPredicate ,
410422 args : & [ Statement ] ,
411- ) -> Option < Vec < Value > > {
423+ ) -> Result < Vec < Value > > {
412424 // Check that all wildcard have consistent values as assigned in the statements while storing a
413425 // map of their values.
414426 // NOTE: We assume the statements have the same order as defined in the custom predicate. For
415427 // disjunctions we expect Statement::None for the unused statements.
416428 let mut wildcard_map = vec ! [ None ; params. max_custom_predicate_wildcards] ;
417429 for ( st_tmpl, st) in pred. statements . iter ( ) . zip ( args) {
418430 let st_args = st. args ( ) ;
419- for ( st_tmpl_arg, st_arg) in st_tmpl. args . iter ( ) . zip ( & st_args) {
420- if !check_st_tmpl ( st_tmpl_arg, st_arg, & mut wildcard_map) {
421- // TODO: Better errors. Example:
422- // println!("{} doesn't match {}", st_arg, st_tmpl_arg);
423- // println!("{} doesn't match {}", st, st_tmpl);
424- return None ;
425- }
426- }
431+ st_tmpl
432+ . args
433+ . iter ( )
434+ . zip ( & st_args)
435+ . try_for_each ( |( st_tmpl_arg, st_arg) | {
436+ check_st_tmpl ( st_tmpl_arg, st_arg, & mut wildcard_map)
437+ } ) ?;
427438 }
428439
429440 // NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
430441 // they are beyond the number of used wildcards in this custom predicate, or they could be
431442 // private arguments that are unused in a particular disjunction.
432- Some (
433- wildcard_map
434- . into_iter ( )
435- . map ( |opt| opt. unwrap_or ( Value :: from ( 0 ) ) )
436- . collect ( ) ,
437- )
443+ Ok ( wildcard_map
444+ . into_iter ( )
445+ . map ( |opt| opt. unwrap_or ( Value :: from ( 0 ) ) )
446+ . collect ( ) )
438447}
439448
440449fn check_custom_pred (
441450 params : & Params ,
442451 custom_pred_ref : & CustomPredicateRef ,
443452 args : & [ Statement ] ,
444453 s_args : & [ Value ] ,
445- ) -> Result < bool > {
454+ ) -> Result < ( ) > {
446455 let pred = custom_pred_ref. predicate ( ) ;
447456 if pred. statements . len ( ) != args. len ( ) {
448457 return Err ( Error :: diff_amount (
@@ -476,23 +485,33 @@ fn check_custom_pred(
476485 }
477486 }
478487
479- let wildcard_map = match resolve_wildcard_values ( params, pred, args) {
480- Some ( wc_map) => wc_map,
481- None => return Ok ( false ) ,
482- } ;
488+ let wildcard_map = resolve_wildcard_values ( params, pred, args) ?;
483489
484- // Check that the resolved wildcard match the statement arguments.
485- for ( s_arg, wc_value) in s_args. iter ( ) . zip ( wildcard_map. iter ( ) ) {
490+ // Check that the resolved wildcards match the statement arguments.
491+ for ( arg_index , ( s_arg, wc_value) ) in s_args. iter ( ) . zip ( wildcard_map. iter ( ) ) . enumerate ( ) {
486492 if * wc_value != * s_arg {
487- return Ok ( false ) ;
493+ return Err ( Error :: mismatched_wildcard_value_and_statement_arg (
494+ wc_value. clone ( ) ,
495+ s_arg. clone ( ) ,
496+ arg_index,
497+ pred. clone ( ) ,
498+ ) ) ;
488499 }
489500 }
490501
491502 if pred. conjunction {
492- Ok ( num_matches == pred. statements . len ( ) )
493- } else {
494- Ok ( num_matches > 0 )
503+ if num_matches != pred. statements . len ( ) {
504+ return Err ( Error :: unsatisfied_custom_predicate_conjunction (
505+ pred. clone ( ) ,
506+ ) ) ;
507+ }
508+ } else if num_matches == 0 {
509+ return Err ( Error :: unsatisfied_custom_predicate_disjunction (
510+ pred. clone ( ) ,
511+ ) ) ;
495512 }
513+
514+ Ok ( ( ) )
496515}
497516
498517impl ToFields for Operation {
0 commit comments