Skip to content

Commit b4436e0

Browse files
committed
Add extra test for unknown batches
1 parent 64110ec commit b4436e0

File tree

2 files changed

+100
-71
lines changed

2 files changed

+100
-71
lines changed

src/lang/mod.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ pub use parser::{parse_podlang, Pairs, ParseError, Rule};
99
pub use processor::process_pest_tree;
1010
use processor::PodlangOutput;
1111

12-
use crate::middleware::{CustomPredicateBatch, Params};
12+
use crate::{
13+
lang::error::ProcessorError,
14+
middleware::{CustomPredicateBatch, Params},
15+
};
1316

1417
pub fn parse(
1518
input: &str,
@@ -882,4 +885,33 @@ mod tests {
882885

883886
Ok(())
884887
}
888+
889+
#[test]
890+
fn test_e2e_use_unknown_batch() {
891+
let params = Params::default();
892+
let available_batches = &[];
893+
894+
let unknown_batch_id = format!("0x{}", "a".repeat(64));
895+
896+
let input = format!(
897+
r#"
898+
use some_pred from {}
899+
"#,
900+
unknown_batch_id
901+
);
902+
903+
let result = parse(&input, &params, available_batches);
904+
905+
assert!(result.is_err());
906+
907+
match result.err().unwrap() {
908+
LangError::Processor(e) => match *e {
909+
ProcessorError::BatchNotFound { id, .. } => {
910+
assert_eq!(id, unknown_batch_id);
911+
}
912+
_ => panic!("Expected BatchNotFound error, but got {:?}", e),
913+
},
914+
e => panic!("Expected LangError::Processor, but got {:?}", e),
915+
}
916+
}
885917
}

src/lang/processor.rs

Lines changed: 67 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,15 @@ fn process_use_statement(
260260
Ok(())
261261
}
262262

263+
enum StatementContext<'a> {
264+
CustomPredicate,
265+
Request {
266+
custom_batch: &'a Arc<CustomPredicateBatch>,
267+
wildcard_names: &'a mut Vec<String>,
268+
defined_wildcards: &'a mut HashSet<String>,
269+
},
270+
}
271+
263272
fn second_pass(ctx: &mut ProcessingContext) -> Result<PodlangOutput, ProcessorError> {
264273
let mut cpb_builder =
265274
CustomPredicateBatchBuilder::new(ctx.params.clone(), "PodlangBatch".to_string());
@@ -271,7 +280,7 @@ fn second_pass(ctx: &mut ProcessingContext) -> Result<PodlangOutput, ProcessorEr
271280
let custom_batch = cpb_builder.finish();
272281

273282
let request_templates = if let Some(req_pair) = &ctx.request_pair {
274-
process_request_def(req_pair, ctx)?
283+
process_request_def(req_pair, ctx, &custom_batch)?
275284
} else {
276285
Vec::new()
277286
};
@@ -560,38 +569,10 @@ fn process_and_add_custom_predicate_to_batch(
560569
.into_inner()
561570
.filter(|p| p.as_rule() == Rule::statement)
562571
{
563-
let mut inner_stmt_pairs = stmt_pair.clone().into_inner();
564-
let stmt_name_pair = inner_stmt_pairs
565-
.find(|p| p.as_rule() == Rule::identifier)
566-
.unwrap_or_else(|| unreachable!("statement name must be present in statement"));
567-
let stmt_name_str = stmt_name_pair.as_str();
568-
569-
let builder_args = parse_statement_args(&stmt_pair)?;
570-
571-
let middleware_predicate_type =
572-
if let Some(native_pred) = native_predicate_from_string(stmt_name_str) {
573-
Predicate::Native(native_pred)
574-
} else if let Some(custom_ref) = processing_ctx.imported_predicates.get(stmt_name_str) {
575-
Predicate::Custom(custom_ref.clone())
576-
} else if let Some((pred_index, _expected_arity)) = processing_ctx
577-
.custom_predicate_signatures
578-
.get(stmt_name_str)
579-
{
580-
Predicate::BatchSelf(*pred_index)
581-
} else {
582-
return Err(ProcessorError::UndefinedIdentifier {
583-
name: stmt_name_str.to_string(),
584-
span: Some(get_span(&stmt_name_pair)),
585-
});
586-
};
587-
588-
let stb = validate_and_build_statement_template(
589-
stmt_name_str,
590-
&middleware_predicate_type,
591-
builder_args,
572+
let stb = process_statement_template(
573+
&stmt_pair,
592574
processing_ctx,
593-
get_span(&stmt_pair),
594-
get_span(&stmt_name_pair),
575+
StatementContext::CustomPredicate,
595576
)?;
596577
statement_builders.push(stb);
597578
}
@@ -612,6 +593,7 @@ fn process_and_add_custom_predicate_to_batch(
612593
fn process_request_def(
613594
req_def_pair: &Pair<Rule>,
614595
processing_ctx: &ProcessingContext,
596+
custom_batch: &Arc<CustomPredicateBatch>,
615597
) -> Result<Vec<StatementTmpl>, ProcessorError> {
616598
let mut request_wildcard_names: Vec<String> = Vec::new();
617599
let mut defined_request_wildcards: HashSet<String> = HashSet::new();
@@ -627,11 +609,14 @@ fn process_request_def(
627609
.into_inner()
628610
.filter(|p| p.as_rule() == Rule::statement)
629611
{
630-
let built_stb = process_proof_request_statement_template(
612+
let built_stb = process_statement_template(
631613
&stmt_pair,
632614
processing_ctx,
633-
&mut request_wildcard_names,
634-
&mut defined_request_wildcards,
615+
StatementContext::Request {
616+
custom_batch,
617+
wildcard_names: &mut request_wildcard_names,
618+
defined_wildcards: &mut defined_request_wildcards,
619+
},
635620
)?;
636621
request_statement_builders.push(built_stb);
637622
}
@@ -648,11 +633,10 @@ fn process_request_def(
648633
Ok(request_templates)
649634
}
650635

651-
fn process_proof_request_statement_template(
636+
fn process_statement_template(
652637
stmt_pair: &Pair<Rule>,
653638
processing_ctx: &ProcessingContext,
654-
request_wildcard_names: &mut Vec<String>,
655-
defined_request_wildcards: &mut HashSet<String>,
639+
mut context: StatementContext,
656640
) -> Result<StatementTmplBuilder, ProcessorError> {
657641
let mut inner_stmt_pairs = stmt_pair.clone().into_inner();
658642
let name_pair = inner_stmt_pairs
@@ -661,45 +645,58 @@ fn process_proof_request_statement_template(
661645
let stmt_name_str = name_pair.as_str();
662646

663647
let builder_args = parse_statement_args(stmt_pair)?;
664-
let mut temp_stmt_wildcard_names: Vec<String> = Vec::new();
665648

666-
for arg in &builder_args {
667-
match arg {
668-
BuilderArg::WildcardLiteral(name) => temp_stmt_wildcard_names.push(name.clone()),
669-
BuilderArg::Key(pod_id_str, key_wc_str) => {
670-
if let SelfOrWildcardStr::Wildcard(name) = pod_id_str {
671-
temp_stmt_wildcard_names.push(name.clone());
672-
}
673-
if let KeyOrWildcardStr::Wildcard(key_wc_name) = key_wc_str {
674-
temp_stmt_wildcard_names.push(key_wc_name.clone());
649+
if let StatementContext::Request {
650+
wildcard_names,
651+
defined_wildcards,
652+
..
653+
} = &mut context
654+
{
655+
let mut temp_stmt_wildcard_names: Vec<String> = Vec::new();
656+
for arg in &builder_args {
657+
match arg {
658+
BuilderArg::WildcardLiteral(name) => temp_stmt_wildcard_names.push(name.clone()),
659+
BuilderArg::Key(pod_id_str, key_wc_str) => {
660+
if let SelfOrWildcardStr::Wildcard(name) = pod_id_str {
661+
temp_stmt_wildcard_names.push(name.clone());
662+
}
663+
if let KeyOrWildcardStr::Wildcard(key_wc_name) = key_wc_str {
664+
temp_stmt_wildcard_names.push(key_wc_name.clone());
665+
}
675666
}
667+
_ => {}
676668
}
677-
_ => {}
678669
}
679-
}
680-
681-
for name in temp_stmt_wildcard_names {
682-
if defined_request_wildcards.insert(name.clone()) {
683-
request_wildcard_names.push(name);
670+
for name in temp_stmt_wildcard_names {
671+
if defined_wildcards.insert(name.clone()) {
672+
wildcard_names.push(name);
673+
}
684674
}
685675
}
686676

687-
let middleware_predicate_type =
688-
if let Some(native_pred) = native_predicate_from_string(stmt_name_str) {
689-
Predicate::Native(native_pred)
690-
} else if let Some(custom_ref) = processing_ctx.imported_predicates.get(stmt_name_str) {
691-
Predicate::Custom(custom_ref.clone())
692-
} else if let Some((pred_index, _expected_arity)) = processing_ctx
693-
.custom_predicate_signatures
694-
.get(stmt_name_str)
695-
{
696-
Predicate::BatchSelf(*pred_index)
697-
} else {
698-
return Err(ProcessorError::UndefinedIdentifier {
699-
name: stmt_name_str.to_string(),
700-
span: Some(get_span(&name_pair)),
701-
});
702-
};
677+
let middleware_predicate_type = if let Some(native_pred) =
678+
native_predicate_from_string(stmt_name_str)
679+
{
680+
Predicate::Native(native_pred)
681+
} else if let Some(custom_ref) = processing_ctx.imported_predicates.get(stmt_name_str) {
682+
Predicate::Custom(custom_ref.clone())
683+
} else if let Some((pred_index, _expected_arity)) = processing_ctx
684+
.custom_predicate_signatures
685+
.get(stmt_name_str)
686+
{
687+
match context {
688+
StatementContext::CustomPredicate => Predicate::BatchSelf(*pred_index),
689+
StatementContext::Request { custom_batch, .. } => {
690+
let custom_pred_ref = CustomPredicateRef::new(custom_batch.clone(), *pred_index);
691+
Predicate::Custom(custom_pred_ref)
692+
}
693+
}
694+
} else {
695+
return Err(ProcessorError::UndefinedIdentifier {
696+
name: stmt_name_str.to_string(),
697+
span: Some(get_span(&name_pair)),
698+
});
699+
};
703700

704701
let stb = validate_and_build_statement_template(
705702
stmt_name_str,

0 commit comments

Comments
 (0)