Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/frontend/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ fn display_wc_map(wc_map: &[Option<Value>]) -> String {
pub enum InnerError {
#[error("{0} {1} is over the limit {2}")]
MaxLength(String, usize, usize),
#[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}", map=display_wc_map(.2))]
StatementsDontMatch(Statement, StatementTmpl, Vec<Option<Value>>),
#[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))]
StatementsDontMatch(
Statement,
StatementTmpl,
Vec<Option<Value>>,
crate::middleware::Error,
),
#[error("invalid arguments to {0} operation")]
OpInvalidArgs(String),
// Other
Expand Down Expand Up @@ -76,8 +81,9 @@ impl Error {
s0: Statement,
s1: StatementTmpl,
wc_map: Vec<Option<Value>>,
mid_error: crate::middleware::Error,
) -> Self {
new!(StatementsDontMatch(s0, s1, wc_map))
new!(StatementsDontMatch(s0, s1, wc_map, mid_error))
}
pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self {
new!(MaxLength(obj, found, expect))
Expand Down
5 changes: 4 additions & 1 deletion src/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,11 +472,14 @@ impl MainPodBuilder {
for (st_tmpl, st) in pred.statements.iter().zip(args.iter()) {
let st_args = st.args();
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
if let Err(st_tmpl_check_error) =
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
{
return Err(Error::statements_dont_match(
st.clone(),
st_tmpl.clone(),
wildcard_map,
st_tmpl_check_error,
));
}
}
Expand Down
63 changes: 62 additions & 1 deletion src/middleware/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

use std::{backtrace::Backtrace, fmt::Debug};

use crate::middleware::{Operation, Statement, StatementArg};
use crate::middleware::{
CustomPredicate, Key, Operation, PodId, Statement, StatementArg, StatementTmplArg, Value,
Wildcard,
};

pub type Result<T, E = Error> = core::result::Result<T, E>;

Expand All @@ -18,6 +21,22 @@ pub enum MiddlewareInnerError {
MaxLength(String, usize, usize),
#[error("{0} amount of {1} should be {1} but it's {2}")]
DiffAmount(String, String, usize, usize),
#[error("{0} should be assigned the value {1} but has previously been assigned {2}")]
InvalidWildcardAssignment(Wildcard, Value, Value),
#[error("{0} matches POD ID {1}, yet the template key {2} does not match {3}")]
MismatchedAnchoredKeyInStatementTmplArg(Wildcard, PodId, Key, Key),
#[error("{0} does not match against {1}")]
MismatchedStatementTmplArg(StatementTmplArg, StatementArg),
#[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")]
MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate),
#[error(
"Not all statement templates of the following custom predicate have been matched:\n{0}"
)]
UnsatisfiedCustomPredicateConjunction(CustomPredicate),
#[error(
"None of the statement templates of the following custom predicate have been matched:\n{0}"
)]
UnsatisfiedCustomPredicateDisjunction(CustomPredicate),
// Other
#[error("{0}")]
Custom(String),
Expand Down Expand Up @@ -65,6 +84,48 @@ impl Error {
pub(crate) fn diff_amount(obj: String, unit: String, expect: usize, found: usize) -> Self {
new!(DiffAmount(obj, unit, expect, found))
}
pub(crate) fn invalid_wildcard_assignment(
wildcard: Wildcard,
value: Value,
prev_value: Value,
) -> Self {
new!(InvalidWildcardAssignment(wildcard, value, prev_value))
}
pub(crate) fn mismatched_anchored_key_in_statement_tmpl_arg(
pod_id_wildcard: Wildcard,
pod_id: PodId,
key_tmpl: Key,
key: Key,
) -> Self {
new!(MismatchedAnchoredKeyInStatementTmplArg(
pod_id_wildcard,
pod_id,
key_tmpl,
key
))
}
pub(crate) fn mismatched_statement_tmpl_arg(
st_tmpl_arg: StatementTmplArg,
st_arg: StatementArg,
) -> Self {
new!(MismatchedStatementTmplArg(st_tmpl_arg, st_arg))
}
pub(crate) fn mismatched_wildcard_value_and_statement_arg(
wc_value: Value,
st_arg: Value,
arg_index: usize,
pred: CustomPredicate,
) -> Self {
new!(MismatchedWildcardValueAndStatementArg(
wc_value, st_arg, arg_index, pred
))
}
pub(crate) fn unsatisfied_custom_predicate_conjunction(pred: CustomPredicate) -> Self {
new!(UnsatisfiedCustomPredicateConjunction(pred))
}
pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self {
new!(UnsatisfiedCustomPredicateDisjunction(pred))
}
pub(crate) fn custom(s: String) -> Self {
new!(Custom(s))
}
Expand Down
97 changes: 58 additions & 39 deletions src/middleware/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ impl Operation {
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
if batch == &cpr.batch && index == &cpr.index =>
{
check_custom_pred(params, cpr, args, s_args)?
check_custom_pred(params, cpr, args, s_args).map(|_| true)?
}
_ => return Err(deduction_err()),
};
Expand All @@ -370,79 +370,88 @@ pub fn check_st_tmpl(
st_arg: &StatementArg,
// Map from wildcards to values that we have seen so far.
wildcard_map: &mut [Option<Value>],
) -> bool {
) -> Result<()> {
// Check that the value `v` at wildcard `wc` exists in the map or set it.
fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> bool {
fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> Result<()> {
if let Some(prev) = &wildcard_map[wc.index] {
if *prev != v {
// TODO: Return nice error
return false;
return Err(Error::invalid_wildcard_assignment(
wc.clone(),
v,
prev.clone(),
));
}
} else {
wildcard_map[wc.index] = Some(v);
}
true
Ok(())
}

match (st_tmpl_arg, st_arg) {
(StatementTmplArg::None, StatementArg::None) => true,
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true,
(StatementTmplArg::None, StatementArg::None) => Ok(()),
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => Ok(()),
(
StatementTmplArg::AnchoredKey(pod_id_wc, key_tmpl),
StatementArg::Key(AnchoredKey { pod_id, key }),
) => {
let pod_id_ok = check_or_set(Value::from(*pod_id), pod_id_wc, wildcard_map);
pod_id_ok && (key_tmpl == key)
pod_id_ok.and_then(|_| {
(key_tmpl == key).then_some(()).ok_or(
Error::mismatched_anchored_key_in_statement_tmpl_arg(
pod_id_wc.clone(),
*pod_id,
key_tmpl.clone(),
key.clone(),
),
)
})
}
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
check_or_set(v.clone(), wc, wildcard_map)
}
_ => {
println!("DBG {:?} {:?}", st_tmpl_arg, st_arg);
false
}
_ => Err(Error::mismatched_statement_tmpl_arg(
st_tmpl_arg.clone(),
st_arg.clone(),
)),
}
}

pub fn resolve_wildcard_values(
params: &Params,
pred: &CustomPredicate,
args: &[Statement],
) -> Option<Vec<Value>> {
) -> Result<Vec<Value>> {
// Check that all wildcard have consistent values as assigned in the statements while storing a
// map of their values.
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
// disjunctions we expect Statement::None for the unused statements.
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
for (st_tmpl, st) in pred.statements.iter().zip(args) {
let st_args = st.args();
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
// TODO: Better errors. Example:
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
// println!("{} doesn't match {}", st, st_tmpl);
return None;
}
}
st_tmpl
.args
.iter()
.zip(&st_args)
.try_for_each(|(st_tmpl_arg, st_arg)| {
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
})?;
}

// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
// they are beyond the number of used wildcards in this custom predicate, or they could be
// private arguments that are unused in a particular disjunction.
Some(
wildcard_map
.into_iter()
.map(|opt| opt.unwrap_or(Value::from(0)))
.collect(),
)
Ok(wildcard_map
.into_iter()
.map(|opt| opt.unwrap_or(Value::from(0)))
.collect())
}

fn check_custom_pred(
params: &Params,
custom_pred_ref: &CustomPredicateRef,
args: &[Statement],
s_args: &[Value],
) -> Result<bool> {
) -> Result<()> {
let pred = custom_pred_ref.predicate();
if pred.statements.len() != args.len() {
return Err(Error::diff_amount(
Expand Down Expand Up @@ -476,23 +485,33 @@ fn check_custom_pred(
}
}

let wildcard_map = match resolve_wildcard_values(params, pred, args) {
Some(wc_map) => wc_map,
None => return Ok(false),
};
let wildcard_map = resolve_wildcard_values(params, pred, args)?;

// Check that the resolved wildcard match the statement arguments.
for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) {
// Check that the resolved wildcards match the statement arguments.
for (arg_index, (s_arg, wc_value)) in s_args.iter().zip(wildcard_map.iter()).enumerate() {
if *wc_value != *s_arg {
return Ok(false);
return Err(Error::mismatched_wildcard_value_and_statement_arg(
wc_value.clone(),
s_arg.clone(),
arg_index,
pred.clone(),
));
}
}

if pred.conjunction {
Ok(num_matches == pred.statements.len())
} else {
Ok(num_matches > 0)
if num_matches != pred.statements.len() {
return Err(Error::unsatisfied_custom_predicate_conjunction(
pred.clone(),
));
}
} else if num_matches == 0 {
return Err(Error::unsatisfied_custom_predicate_disjunction(
pred.clone(),
));
}

Ok(())
}

impl ToFields for Operation {
Expand Down
Loading