@@ -260,6 +260,15 @@ fn process_use_statement(
260260 Ok ( ( ) )
261261}
262262
263+ enum StatementContext < ' a > {
264+ CustomPredicate ,
265+ Request {
266+ custom_batch : & ' a Arc < CustomPredicateBatch > ,
267+ wildcard_names : & ' a mut Vec < String > ,
268+ defined_wildcards : & ' a mut HashSet < String > ,
269+ } ,
270+ }
271+
263272fn second_pass ( ctx : & mut ProcessingContext ) -> Result < PodlangOutput , ProcessorError > {
264273 let mut cpb_builder =
265274 CustomPredicateBatchBuilder :: new ( ctx. params . clone ( ) , "PodlangBatch" . to_string ( ) ) ;
@@ -271,7 +280,7 @@ fn second_pass(ctx: &mut ProcessingContext) -> Result<PodlangOutput, ProcessorEr
271280 let custom_batch = cpb_builder. finish ( ) ;
272281
273282 let request_templates = if let Some ( req_pair) = & ctx. request_pair {
274- process_request_def ( req_pair, ctx) ?
283+ process_request_def ( req_pair, ctx, & custom_batch ) ?
275284 } else {
276285 Vec :: new ( )
277286 } ;
@@ -560,38 +569,10 @@ fn process_and_add_custom_predicate_to_batch(
560569 . into_inner ( )
561570 . filter ( |p| p. as_rule ( ) == Rule :: statement)
562571 {
563- let mut inner_stmt_pairs = stmt_pair. clone ( ) . into_inner ( ) ;
564- let stmt_name_pair = inner_stmt_pairs
565- . find ( |p| p. as_rule ( ) == Rule :: identifier)
566- . unwrap_or_else ( || unreachable ! ( "statement name must be present in statement" ) ) ;
567- let stmt_name_str = stmt_name_pair. as_str ( ) ;
568-
569- let builder_args = parse_statement_args ( & stmt_pair) ?;
570-
571- let middleware_predicate_type =
572- if let Some ( native_pred) = native_predicate_from_string ( stmt_name_str) {
573- Predicate :: Native ( native_pred)
574- } else if let Some ( custom_ref) = processing_ctx. imported_predicates . get ( stmt_name_str) {
575- Predicate :: Custom ( custom_ref. clone ( ) )
576- } else if let Some ( ( pred_index, _expected_arity) ) = processing_ctx
577- . custom_predicate_signatures
578- . get ( stmt_name_str)
579- {
580- Predicate :: BatchSelf ( * pred_index)
581- } else {
582- return Err ( ProcessorError :: UndefinedIdentifier {
583- name : stmt_name_str. to_string ( ) ,
584- span : Some ( get_span ( & stmt_name_pair) ) ,
585- } ) ;
586- } ;
587-
588- let stb = validate_and_build_statement_template (
589- stmt_name_str,
590- & middleware_predicate_type,
591- builder_args,
572+ let stb = process_statement_template (
573+ & stmt_pair,
592574 processing_ctx,
593- get_span ( & stmt_pair) ,
594- get_span ( & stmt_name_pair) ,
575+ StatementContext :: CustomPredicate ,
595576 ) ?;
596577 statement_builders. push ( stb) ;
597578 }
@@ -612,6 +593,7 @@ fn process_and_add_custom_predicate_to_batch(
612593fn process_request_def (
613594 req_def_pair : & Pair < Rule > ,
614595 processing_ctx : & ProcessingContext ,
596+ custom_batch : & Arc < CustomPredicateBatch > ,
615597) -> Result < Vec < StatementTmpl > , ProcessorError > {
616598 let mut request_wildcard_names: Vec < String > = Vec :: new ( ) ;
617599 let mut defined_request_wildcards: HashSet < String > = HashSet :: new ( ) ;
@@ -627,11 +609,14 @@ fn process_request_def(
627609 . into_inner ( )
628610 . filter ( |p| p. as_rule ( ) == Rule :: statement)
629611 {
630- let built_stb = process_proof_request_statement_template (
612+ let built_stb = process_statement_template (
631613 & stmt_pair,
632614 processing_ctx,
633- & mut request_wildcard_names,
634- & mut defined_request_wildcards,
615+ StatementContext :: Request {
616+ custom_batch,
617+ wildcard_names : & mut request_wildcard_names,
618+ defined_wildcards : & mut defined_request_wildcards,
619+ } ,
635620 ) ?;
636621 request_statement_builders. push ( built_stb) ;
637622 }
@@ -648,11 +633,10 @@ fn process_request_def(
648633 Ok ( request_templates)
649634}
650635
651- fn process_proof_request_statement_template (
636+ fn process_statement_template (
652637 stmt_pair : & Pair < Rule > ,
653638 processing_ctx : & ProcessingContext ,
654- request_wildcard_names : & mut Vec < String > ,
655- defined_request_wildcards : & mut HashSet < String > ,
639+ mut context : StatementContext ,
656640) -> Result < StatementTmplBuilder , ProcessorError > {
657641 let mut inner_stmt_pairs = stmt_pair. clone ( ) . into_inner ( ) ;
658642 let name_pair = inner_stmt_pairs
@@ -661,45 +645,58 @@ fn process_proof_request_statement_template(
661645 let stmt_name_str = name_pair. as_str ( ) ;
662646
663647 let builder_args = parse_statement_args ( stmt_pair) ?;
664- let mut temp_stmt_wildcard_names: Vec < String > = Vec :: new ( ) ;
665648
666- for arg in & builder_args {
667- match arg {
668- BuilderArg :: WildcardLiteral ( name) => temp_stmt_wildcard_names. push ( name. clone ( ) ) ,
669- BuilderArg :: Key ( pod_id_str, key_wc_str) => {
670- if let SelfOrWildcardStr :: Wildcard ( name) = pod_id_str {
671- temp_stmt_wildcard_names. push ( name. clone ( ) ) ;
672- }
673- if let KeyOrWildcardStr :: Wildcard ( key_wc_name) = key_wc_str {
674- temp_stmt_wildcard_names. push ( key_wc_name. clone ( ) ) ;
649+ if let StatementContext :: Request {
650+ wildcard_names,
651+ defined_wildcards,
652+ ..
653+ } = & mut context
654+ {
655+ let mut temp_stmt_wildcard_names: Vec < String > = Vec :: new ( ) ;
656+ for arg in & builder_args {
657+ match arg {
658+ BuilderArg :: WildcardLiteral ( name) => temp_stmt_wildcard_names. push ( name. clone ( ) ) ,
659+ BuilderArg :: Key ( pod_id_str, key_wc_str) => {
660+ if let SelfOrWildcardStr :: Wildcard ( name) = pod_id_str {
661+ temp_stmt_wildcard_names. push ( name. clone ( ) ) ;
662+ }
663+ if let KeyOrWildcardStr :: Wildcard ( key_wc_name) = key_wc_str {
664+ temp_stmt_wildcard_names. push ( key_wc_name. clone ( ) ) ;
665+ }
675666 }
667+ _ => { }
676668 }
677- _ => { }
678669 }
679- }
680-
681- for name in temp_stmt_wildcard_names {
682- if defined_request_wildcards. insert ( name. clone ( ) ) {
683- request_wildcard_names. push ( name) ;
670+ for name in temp_stmt_wildcard_names {
671+ if defined_wildcards. insert ( name. clone ( ) ) {
672+ wildcard_names. push ( name) ;
673+ }
684674 }
685675 }
686676
687- let middleware_predicate_type =
688- if let Some ( native_pred) = native_predicate_from_string ( stmt_name_str) {
689- Predicate :: Native ( native_pred)
690- } else if let Some ( custom_ref) = processing_ctx. imported_predicates . get ( stmt_name_str) {
691- Predicate :: Custom ( custom_ref. clone ( ) )
692- } else if let Some ( ( pred_index, _expected_arity) ) = processing_ctx
693- . custom_predicate_signatures
694- . get ( stmt_name_str)
695- {
696- Predicate :: BatchSelf ( * pred_index)
697- } else {
698- return Err ( ProcessorError :: UndefinedIdentifier {
699- name : stmt_name_str. to_string ( ) ,
700- span : Some ( get_span ( & name_pair) ) ,
701- } ) ;
702- } ;
677+ let middleware_predicate_type = if let Some ( native_pred) =
678+ native_predicate_from_string ( stmt_name_str)
679+ {
680+ Predicate :: Native ( native_pred)
681+ } else if let Some ( custom_ref) = processing_ctx. imported_predicates . get ( stmt_name_str) {
682+ Predicate :: Custom ( custom_ref. clone ( ) )
683+ } else if let Some ( ( pred_index, _expected_arity) ) = processing_ctx
684+ . custom_predicate_signatures
685+ . get ( stmt_name_str)
686+ {
687+ match context {
688+ StatementContext :: CustomPredicate => Predicate :: BatchSelf ( * pred_index) ,
689+ StatementContext :: Request { custom_batch, .. } => {
690+ let custom_pred_ref = CustomPredicateRef :: new ( custom_batch. clone ( ) , * pred_index) ;
691+ Predicate :: Custom ( custom_pred_ref)
692+ }
693+ }
694+ } else {
695+ return Err ( ProcessorError :: UndefinedIdentifier {
696+ name : stmt_name_str. to_string ( ) ,
697+ span : Some ( get_span ( & name_pair) ) ,
698+ } ) ;
699+ } ;
703700
704701 let stb = validate_and_build_statement_template (
705702 stmt_name_str,
0 commit comments