|
1 | 1 | pub mod operation; |
2 | | -use crate::middleware::PodType; |
| 2 | +use crate::middleware::{wildcard_values_from_op_st, PodType}; |
3 | 3 | pub mod statement; |
4 | 4 | use std::{iter, sync::Arc}; |
5 | 5 |
|
6 | | -use itertools::Itertools; |
| 6 | +use itertools::{zip_eq, Itertools}; |
7 | 7 | use num_bigint::BigUint; |
8 | 8 | pub use operation::*; |
9 | 9 | use plonky2::{hash::poseidon::PoseidonHash, plonk::config::Hasher}; |
@@ -37,9 +37,9 @@ use crate::{ |
37 | 37 | serialize_proof, serialize_verifier_only, |
38 | 38 | }, |
39 | 39 | middleware::{ |
40 | | - self, resolve_wildcard_values, value_from_op, CustomPredicateBatch, |
41 | | - Error as MiddlewareError, Hash, MainPodInputs, MainPodProver, NativeOperation, |
42 | | - OperationType, Params, Pod, RawValue, StatementArg, ToFields, VDSet, |
| 40 | + self, value_from_op, CustomPredicateBatch, Error as MiddlewareError, Hash, MainPodInputs, |
| 41 | + MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg, |
| 42 | + ToFields, VDSet, |
43 | 43 | }, |
44 | 44 | timed, |
45 | 45 | }; |
@@ -97,28 +97,35 @@ pub(crate) fn extract_custom_predicate_verifications( |
97 | 97 | params: &Params, |
98 | 98 | aux_list: &mut [OperationAux], |
99 | 99 | operations: &[middleware::Operation], |
| 100 | + statements: &[middleware::Statement], |
100 | 101 | custom_predicate_batches: &[Arc<CustomPredicateBatch>], |
101 | 102 | ) -> Result<Vec<CustomPredicateVerification>> { |
102 | 103 | let mut table = Vec::new(); |
103 | | - for (i, op) in operations.iter().enumerate() { |
| 104 | + for (i, (op, st)) in zip_eq(operations.iter(), statements.iter()).enumerate() { |
104 | 105 | if let middleware::Operation::Custom(cpr, sts) = op { |
105 | | - let wildcard_values = |
106 | | - resolve_wildcard_values(params, cpr.predicate(), sts).expect("resolved wildcards"); |
107 | | - let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); |
108 | | - let batch_index = custom_predicate_batches |
109 | | - .iter() |
110 | | - .enumerate() |
111 | | - .find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i)) |
112 | | - .expect("find the custom predicate from the extracted unique list"); |
113 | | - let custom_predicate_table_index = |
114 | | - batch_index * params.max_custom_batch_size + cpr.index; |
115 | | - aux_list[i] = OperationAux::CustomPredVerifyIndex(table.len()); |
116 | | - table.push(CustomPredicateVerification { |
117 | | - custom_predicate_table_index, |
118 | | - custom_predicate: cpr.clone(), |
119 | | - args: wildcard_values, |
120 | | - op_args: sts, |
121 | | - }); |
| 106 | + if let middleware::Statement::Custom(st_cpr, st_args) = st { |
| 107 | + assert_eq!(cpr, st_cpr); |
| 108 | + let wildcard_values = |
| 109 | + wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args) |
| 110 | + .expect("resolved wildcards"); |
| 111 | + let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); |
| 112 | + let batch_index = custom_predicate_batches |
| 113 | + .iter() |
| 114 | + .enumerate() |
| 115 | + .find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i)) |
| 116 | + .expect("find the custom predicate from the extracted unique list"); |
| 117 | + let custom_predicate_table_index = |
| 118 | + batch_index * params.max_custom_batch_size + cpr.index; |
| 119 | + aux_list[i] = OperationAux::CustomPredVerifyIndex(table.len()); |
| 120 | + table.push(CustomPredicateVerification { |
| 121 | + custom_predicate_table_index, |
| 122 | + custom_predicate: cpr.clone(), |
| 123 | + args: wildcard_values, |
| 124 | + op_args: sts, |
| 125 | + }); |
| 126 | + } else { |
| 127 | + panic!("Custom operation paired with non-custom statement"); |
| 128 | + } |
122 | 129 | } |
123 | 130 | } |
124 | 131 |
|
@@ -499,6 +506,7 @@ impl MainPodProver for Prover { |
499 | 506 | params, |
500 | 507 | &mut aux_list, |
501 | 508 | inputs.operations, |
| 509 | + inputs.statements, |
502 | 510 | &custom_predicate_batches, |
503 | 511 | )?; |
504 | 512 | let public_key_of_sks = |
@@ -823,6 +831,7 @@ pub mod tests { |
823 | 831 | frontend::{ |
824 | 832 | self, literal, CustomPredicateBatchBuilder, MainPodBuilder, StatementTmplBuilder as STB, |
825 | 833 | }, |
| 834 | + lang::parse, |
826 | 835 | middleware::{ |
827 | 836 | self, containers::Set, CustomPredicateRef, NativePredicate as NP, Signer as _, |
828 | 837 | DEFAULT_VD_LIST, DEFAULT_VD_SET, |
@@ -1154,4 +1163,40 @@ pub mod tests { |
1154 | 1163 | builder.prove(&prover)?; |
1155 | 1164 | Ok(()) |
1156 | 1165 | } |
| 1166 | + |
| 1167 | + #[test] |
| 1168 | + fn test_undetermined_values() { |
| 1169 | + let params = Default::default(); |
| 1170 | + let batch = parse( |
| 1171 | + r#" |
| 1172 | + two_equal(x,y,z) = OR( |
| 1173 | + Equal(x,y) |
| 1174 | + Equal(y,z) |
| 1175 | + Equal(x,z) |
| 1176 | + ) |
| 1177 | + "#, |
| 1178 | + ¶ms, |
| 1179 | + &[], |
| 1180 | + ) |
| 1181 | + .unwrap() |
| 1182 | + .custom_batch; |
| 1183 | + let mut builder = MainPodBuilder::new(¶ms, &DEFAULT_VD_SET); |
| 1184 | + let cpr = CustomPredicateRef { batch, index: 0 }; |
| 1185 | + let eq_st = builder.priv_op(frontend::Operation::eq(1, 1)).unwrap(); |
| 1186 | + let op = frontend::Operation::custom( |
| 1187 | + cpr.clone(), |
| 1188 | + [ |
| 1189 | + eq_st, |
| 1190 | + middleware::Statement::None, |
| 1191 | + middleware::Statement::None, |
| 1192 | + ], |
| 1193 | + ); |
| 1194 | + let st = middleware::Statement::Custom( |
| 1195 | + cpr, |
| 1196 | + [1, 1, 2].into_iter().map(middleware::Value::from).collect(), |
| 1197 | + ); |
| 1198 | + builder.insert(true, (st, op)).unwrap(); |
| 1199 | + let prover = Prover {}; |
| 1200 | + builder.prove(&prover).unwrap(); |
| 1201 | + } |
1157 | 1202 | } |
0 commit comments