Skip to content

Commit e8468d7

Browse files
authored
chore(middleware): additional error reporting for custom predicates (#330)
* Additional error reporting for custom predicates * Code review * Typo
1 parent aeedf55 commit e8468d7

File tree

4 files changed

+133
-44
lines changed

4 files changed

+133
-44
lines changed

src/frontend/error.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@ fn display_wc_map(wc_map: &[Option<Value>]) -> String {
2222
pub enum InnerError {
2323
#[error("{0} {1} is over the limit {2}")]
2424
MaxLength(String, usize, usize),
25-
#[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}", map=display_wc_map(.2))]
26-
StatementsDontMatch(Statement, StatementTmpl, Vec<Option<Value>>),
25+
#[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))]
26+
StatementsDontMatch(
27+
Statement,
28+
StatementTmpl,
29+
Vec<Option<Value>>,
30+
crate::middleware::Error,
31+
),
2732
#[error("invalid arguments to {0} operation")]
2833
OpInvalidArgs(String),
2934
// Other
@@ -76,8 +81,9 @@ impl Error {
7681
s0: Statement,
7782
s1: StatementTmpl,
7883
wc_map: Vec<Option<Value>>,
84+
mid_error: crate::middleware::Error,
7985
) -> Self {
80-
new!(StatementsDontMatch(s0, s1, wc_map))
86+
new!(StatementsDontMatch(s0, s1, wc_map, mid_error))
8187
}
8288
pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self {
8389
new!(MaxLength(obj, found, expect))

src/frontend/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,11 +471,14 @@ impl MainPodBuilder {
471471
for (st_tmpl, st) in pred.statements.iter().zip(args.iter()) {
472472
let st_args = st.args();
473473
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
474-
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
474+
if let Err(st_tmpl_check_error) =
475+
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
476+
{
475477
return Err(Error::statements_dont_match(
476478
st.clone(),
477479
st_tmpl.clone(),
478480
wildcard_map,
481+
st_tmpl_check_error,
479482
));
480483
}
481484
}

src/middleware/error.rs

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
33
use std::{backtrace::Backtrace, fmt::Debug};
44

5-
use crate::middleware::{Operation, Statement, StatementArg};
5+
use crate::middleware::{
6+
CustomPredicate, Key, Operation, PodId, Statement, StatementArg, StatementTmplArg, Value,
7+
Wildcard,
8+
};
69

710
pub type Result<T, E = Error> = core::result::Result<T, E>;
811

@@ -18,6 +21,22 @@ pub enum MiddlewareInnerError {
1821
MaxLength(String, usize, usize),
1922
#[error("{0} amount of {1} should be {1} but it's {2}")]
2023
DiffAmount(String, String, usize, usize),
24+
#[error("{0} should be assigned the value {1} but has previously been assigned {2}")]
25+
InvalidWildcardAssignment(Wildcard, Value, Value),
26+
#[error("{0} matches POD ID {1}, yet the template key {2} does not match {3}")]
27+
MismatchedAnchoredKeyInStatementTmplArg(Wildcard, PodId, Key, Key),
28+
#[error("{0} does not match against {1}")]
29+
MismatchedStatementTmplArg(StatementTmplArg, StatementArg),
30+
#[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")]
31+
MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate),
32+
#[error(
33+
"Not all statement templates of the following custom predicate have been matched:\n{0}"
34+
)]
35+
UnsatisfiedCustomPredicateConjunction(CustomPredicate),
36+
#[error(
37+
"None of the statement templates of the following custom predicate have been matched:\n{0}"
38+
)]
39+
UnsatisfiedCustomPredicateDisjunction(CustomPredicate),
2140
// Other
2241
#[error("{0}")]
2342
Custom(String),
@@ -65,6 +84,48 @@ impl Error {
6584
pub(crate) fn diff_amount(obj: String, unit: String, expect: usize, found: usize) -> Self {
6685
new!(DiffAmount(obj, unit, expect, found))
6786
}
87+
pub(crate) fn invalid_wildcard_assignment(
88+
wildcard: Wildcard,
89+
value: Value,
90+
prev_value: Value,
91+
) -> Self {
92+
new!(InvalidWildcardAssignment(wildcard, value, prev_value))
93+
}
94+
pub(crate) fn mismatched_anchored_key_in_statement_tmpl_arg(
95+
pod_id_wildcard: Wildcard,
96+
pod_id: PodId,
97+
key_tmpl: Key,
98+
key: Key,
99+
) -> Self {
100+
new!(MismatchedAnchoredKeyInStatementTmplArg(
101+
pod_id_wildcard,
102+
pod_id,
103+
key_tmpl,
104+
key
105+
))
106+
}
107+
pub(crate) fn mismatched_statement_tmpl_arg(
108+
st_tmpl_arg: StatementTmplArg,
109+
st_arg: StatementArg,
110+
) -> Self {
111+
new!(MismatchedStatementTmplArg(st_tmpl_arg, st_arg))
112+
}
113+
pub(crate) fn mismatched_wildcard_value_and_statement_arg(
114+
wc_value: Value,
115+
st_arg: Value,
116+
arg_index: usize,
117+
pred: CustomPredicate,
118+
) -> Self {
119+
new!(MismatchedWildcardValueAndStatementArg(
120+
wc_value, st_arg, arg_index, pred
121+
))
122+
}
123+
pub(crate) fn unsatisfied_custom_predicate_conjunction(pred: CustomPredicate) -> Self {
124+
new!(UnsatisfiedCustomPredicateConjunction(pred))
125+
}
126+
pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self {
127+
new!(UnsatisfiedCustomPredicateDisjunction(pred))
128+
}
68129
pub(crate) fn custom(s: String) -> Self {
69130
new!(Custom(s))
70131
}

src/middleware/operation.rs

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ impl Operation {
355355
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
356356
if batch == &cpr.batch && index == &cpr.index =>
357357
{
358-
check_custom_pred(params, cpr, args, s_args)?
358+
check_custom_pred(params, cpr, args, s_args).map(|_| true)?
359359
}
360360
_ => return Err(deduction_err()),
361361
};
@@ -370,79 +370,88 @@ pub fn check_st_tmpl(
370370
st_arg: &StatementArg,
371371
// Map from wildcards to values that we have seen so far.
372372
wildcard_map: &mut [Option<Value>],
373-
) -> bool {
373+
) -> Result<()> {
374374
// Check that the value `v` at wildcard `wc` exists in the map or set it.
375-
fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> bool {
375+
fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> Result<()> {
376376
if let Some(prev) = &wildcard_map[wc.index] {
377377
if *prev != v {
378-
// TODO: Return nice error
379-
return false;
378+
return Err(Error::invalid_wildcard_assignment(
379+
wc.clone(),
380+
v,
381+
prev.clone(),
382+
));
380383
}
381384
} else {
382385
wildcard_map[wc.index] = Some(v);
383386
}
384-
true
387+
Ok(())
385388
}
386389

387390
match (st_tmpl_arg, st_arg) {
388-
(StatementTmplArg::None, StatementArg::None) => true,
389-
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true,
391+
(StatementTmplArg::None, StatementArg::None) => Ok(()),
392+
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => Ok(()),
390393
(
391394
StatementTmplArg::AnchoredKey(pod_id_wc, key_tmpl),
392395
StatementArg::Key(AnchoredKey { pod_id, key }),
393396
) => {
394397
let pod_id_ok = check_or_set(Value::from(*pod_id), pod_id_wc, wildcard_map);
395-
pod_id_ok && (key_tmpl == key)
398+
pod_id_ok.and_then(|_| {
399+
(key_tmpl == key).then_some(()).ok_or(
400+
Error::mismatched_anchored_key_in_statement_tmpl_arg(
401+
pod_id_wc.clone(),
402+
*pod_id,
403+
key_tmpl.clone(),
404+
key.clone(),
405+
),
406+
)
407+
})
396408
}
397409
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
398410
check_or_set(v.clone(), wc, wildcard_map)
399411
}
400-
_ => {
401-
println!("DBG {:?} {:?}", st_tmpl_arg, st_arg);
402-
false
403-
}
412+
_ => Err(Error::mismatched_statement_tmpl_arg(
413+
st_tmpl_arg.clone(),
414+
st_arg.clone(),
415+
)),
404416
}
405417
}
406418

407419
pub fn resolve_wildcard_values(
408420
params: &Params,
409421
pred: &CustomPredicate,
410422
args: &[Statement],
411-
) -> Option<Vec<Value>> {
423+
) -> Result<Vec<Value>> {
412424
// Check that all wildcard have consistent values as assigned in the statements while storing a
413425
// map of their values.
414426
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
415427
// disjunctions we expect Statement::None for the unused statements.
416428
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
417429
for (st_tmpl, st) in pred.statements.iter().zip(args) {
418430
let st_args = st.args();
419-
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
420-
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
421-
// TODO: Better errors. Example:
422-
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
423-
// println!("{} doesn't match {}", st, st_tmpl);
424-
return None;
425-
}
426-
}
431+
st_tmpl
432+
.args
433+
.iter()
434+
.zip(&st_args)
435+
.try_for_each(|(st_tmpl_arg, st_arg)| {
436+
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
437+
})?;
427438
}
428439

429440
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
430441
// they are beyond the number of used wildcards in this custom predicate, or they could be
431442
// private arguments that are unused in a particular disjunction.
432-
Some(
433-
wildcard_map
434-
.into_iter()
435-
.map(|opt| opt.unwrap_or(Value::from(0)))
436-
.collect(),
437-
)
443+
Ok(wildcard_map
444+
.into_iter()
445+
.map(|opt| opt.unwrap_or(Value::from(0)))
446+
.collect())
438447
}
439448

440449
fn check_custom_pred(
441450
params: &Params,
442451
custom_pred_ref: &CustomPredicateRef,
443452
args: &[Statement],
444453
s_args: &[Value],
445-
) -> Result<bool> {
454+
) -> Result<()> {
446455
let pred = custom_pred_ref.predicate();
447456
if pred.statements.len() != args.len() {
448457
return Err(Error::diff_amount(
@@ -476,23 +485,33 @@ fn check_custom_pred(
476485
}
477486
}
478487

479-
let wildcard_map = match resolve_wildcard_values(params, pred, args) {
480-
Some(wc_map) => wc_map,
481-
None => return Ok(false),
482-
};
488+
let wildcard_map = resolve_wildcard_values(params, pred, args)?;
483489

484-
// Check that the resolved wildcard match the statement arguments.
485-
for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) {
490+
// Check that the resolved wildcards match the statement arguments.
491+
for (arg_index, (s_arg, wc_value)) in s_args.iter().zip(wildcard_map.iter()).enumerate() {
486492
if *wc_value != *s_arg {
487-
return Ok(false);
493+
return Err(Error::mismatched_wildcard_value_and_statement_arg(
494+
wc_value.clone(),
495+
s_arg.clone(),
496+
arg_index,
497+
pred.clone(),
498+
));
488499
}
489500
}
490501

491502
if pred.conjunction {
492-
Ok(num_matches == pred.statements.len())
493-
} else {
494-
Ok(num_matches > 0)
503+
if num_matches != pred.statements.len() {
504+
return Err(Error::unsatisfied_custom_predicate_conjunction(
505+
pred.clone(),
506+
));
507+
}
508+
} else if num_matches == 0 {
509+
return Err(Error::unsatisfied_custom_predicate_disjunction(
510+
pred.clone(),
511+
));
495512
}
513+
514+
Ok(())
496515
}
497516

498517
impl ToFields for Operation {

0 commit comments

Comments
 (0)