Skip to content

Commit 26548cf

Browse files
authored
Fix for #413 (#415)
1 parent 1d14338 commit 26548cf

File tree

2 files changed

+106
-47
lines changed

2 files changed

+106
-47
lines changed

src/backends/plonky2/mainpod/mod.rs

Lines changed: 68 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
pub mod operation;
2-
use crate::middleware::PodType;
2+
use crate::middleware::{wildcard_values_from_op_st, PodType};
33
pub mod statement;
44
use std::{iter, sync::Arc};
55

6-
use itertools::Itertools;
6+
use itertools::{zip_eq, Itertools};
77
use num_bigint::BigUint;
88
pub use operation::*;
99
use plonky2::{hash::poseidon::PoseidonHash, plonk::config::Hasher};
@@ -37,9 +37,9 @@ use crate::{
3737
serialize_proof, serialize_verifier_only,
3838
},
3939
middleware::{
40-
self, resolve_wildcard_values, value_from_op, CustomPredicateBatch,
41-
Error as MiddlewareError, Hash, MainPodInputs, MainPodProver, NativeOperation,
42-
OperationType, Params, Pod, RawValue, StatementArg, ToFields, VDSet,
40+
self, value_from_op, CustomPredicateBatch, Error as MiddlewareError, Hash, MainPodInputs,
41+
MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg,
42+
ToFields, VDSet,
4343
},
4444
timed,
4545
};
@@ -97,28 +97,35 @@ pub(crate) fn extract_custom_predicate_verifications(
9797
params: &Params,
9898
aux_list: &mut [OperationAux],
9999
operations: &[middleware::Operation],
100+
statements: &[middleware::Statement],
100101
custom_predicate_batches: &[Arc<CustomPredicateBatch>],
101102
) -> Result<Vec<CustomPredicateVerification>> {
102103
let mut table = Vec::new();
103-
for (i, op) in operations.iter().enumerate() {
104+
for (i, (op, st)) in zip_eq(operations.iter(), statements.iter()).enumerate() {
104105
if let middleware::Operation::Custom(cpr, sts) = op {
105-
let wildcard_values =
106-
resolve_wildcard_values(params, cpr.predicate(), sts).expect("resolved wildcards");
107-
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
108-
let batch_index = custom_predicate_batches
109-
.iter()
110-
.enumerate()
111-
.find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i))
112-
.expect("find the custom predicate from the extracted unique list");
113-
let custom_predicate_table_index =
114-
batch_index * params.max_custom_batch_size + cpr.index;
115-
aux_list[i] = OperationAux::CustomPredVerifyIndex(table.len());
116-
table.push(CustomPredicateVerification {
117-
custom_predicate_table_index,
118-
custom_predicate: cpr.clone(),
119-
args: wildcard_values,
120-
op_args: sts,
121-
});
106+
if let middleware::Statement::Custom(st_cpr, st_args) = st {
107+
assert_eq!(cpr, st_cpr);
108+
let wildcard_values =
109+
wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args)
110+
.expect("resolved wildcards");
111+
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
112+
let batch_index = custom_predicate_batches
113+
.iter()
114+
.enumerate()
115+
.find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i))
116+
.expect("find the custom predicate from the extracted unique list");
117+
let custom_predicate_table_index =
118+
batch_index * params.max_custom_batch_size + cpr.index;
119+
aux_list[i] = OperationAux::CustomPredVerifyIndex(table.len());
120+
table.push(CustomPredicateVerification {
121+
custom_predicate_table_index,
122+
custom_predicate: cpr.clone(),
123+
args: wildcard_values,
124+
op_args: sts,
125+
});
126+
} else {
127+
panic!("Custom operation paired with non-custom statement");
128+
}
122129
}
123130
}
124131

@@ -499,6 +506,7 @@ impl MainPodProver for Prover {
499506
params,
500507
&mut aux_list,
501508
inputs.operations,
509+
inputs.statements,
502510
&custom_predicate_batches,
503511
)?;
504512
let public_key_of_sks =
@@ -823,6 +831,7 @@ pub mod tests {
823831
frontend::{
824832
self, literal, CustomPredicateBatchBuilder, MainPodBuilder, StatementTmplBuilder as STB,
825833
},
834+
lang::parse,
826835
middleware::{
827836
self, containers::Set, CustomPredicateRef, NativePredicate as NP, Signer as _,
828837
DEFAULT_VD_LIST, DEFAULT_VD_SET,
@@ -1154,4 +1163,40 @@ pub mod tests {
11541163
builder.prove(&prover)?;
11551164
Ok(())
11561165
}
1166+
1167+
#[test]
1168+
fn test_undetermined_values() {
1169+
let params = Default::default();
1170+
let batch = parse(
1171+
r#"
1172+
two_equal(x,y,z) = OR(
1173+
Equal(x,y)
1174+
Equal(y,z)
1175+
Equal(x,z)
1176+
)
1177+
"#,
1178+
&params,
1179+
&[],
1180+
)
1181+
.unwrap()
1182+
.custom_batch;
1183+
let mut builder = MainPodBuilder::new(&params, &DEFAULT_VD_SET);
1184+
let cpr = CustomPredicateRef { batch, index: 0 };
1185+
let eq_st = builder.priv_op(frontend::Operation::eq(1, 1)).unwrap();
1186+
let op = frontend::Operation::custom(
1187+
cpr.clone(),
1188+
[
1189+
eq_st,
1190+
middleware::Statement::None,
1191+
middleware::Statement::None,
1192+
],
1193+
);
1194+
let st = middleware::Statement::Custom(
1195+
cpr,
1196+
[1, 1, 2].into_iter().map(middleware::Value::from).collect(),
1197+
);
1198+
builder.insert(true, (st, op)).unwrap();
1199+
let prover = Prover {};
1200+
builder.prove(&prover).unwrap();
1201+
}
11571202
}

src/middleware/operation.rs

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::{fmt, iter};
22

3+
use itertools::Itertools;
34
use log::error;
45
use plonky2::field::types::Field;
56
use serde::{Deserialize, Serialize};
@@ -14,8 +15,8 @@ use crate::{
1415
},
1516
middleware::{
1617
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key,
17-
NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmpl,
18-
StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F,
18+
MiddlewareInnerError, NativePredicate, Params, Predicate, Result, Statement, StatementArg,
19+
StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F,
1920
},
2021
};
2122

@@ -613,27 +614,37 @@ pub fn check_st_tmpl(
613614
}
614615
}
615616

616-
pub fn resolve_wildcard_values(
617-
params: &Params,
617+
pub fn fill_wildcard_values(
618618
pred: &CustomPredicate,
619619
args: &[Statement],
620-
) -> Result<Vec<Value>> {
621-
// Check that all wildcard have consistent values as assigned in the statements while storing a
622-
// map of their values.
623-
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
624-
// disjunctions we expect Statement::None for the unused statements.
625-
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
620+
wildcard_map: &mut [Option<Value>],
621+
) -> Result<()> {
626622
for (st_tmpl, st) in pred.statements.iter().zip(args) {
627623
let st_args = st.args();
628624
st_tmpl
629625
.args
630626
.iter()
631627
.zip(&st_args)
632628
.try_for_each(|(st_tmpl_arg, st_arg)| {
633-
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
629+
check_st_tmpl(st_tmpl_arg, st_arg, wildcard_map)
634630
})?;
635631
}
632+
Ok(())
633+
}
636634

635+
pub fn wildcard_values_from_op_st(
636+
params: &Params,
637+
pred: &CustomPredicate,
638+
op_args: &[Statement],
639+
st_args: &[Value],
640+
) -> Result<Vec<Value>> {
641+
let mut wildcard_map = st_args
642+
.iter()
643+
.map(|v| Some(v.clone()))
644+
.chain(core::iter::repeat(None))
645+
.take(params.max_custom_predicate_wildcards)
646+
.collect_vec();
647+
fill_wildcard_values(pred, op_args, &mut wildcard_map)?;
637648
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
638649
// they are beyond the number of used wildcards in this custom predicate, or they could be
639650
// private arguments that are unused in a particular disjunction.
@@ -717,21 +728,24 @@ pub(crate) fn check_custom_pred(
717728
));
718729
}
719730

720-
let wildcard_map = resolve_wildcard_values(params, pred, args)?;
721-
722731
// Check that the resolved wildcards match the statement arguments.
723-
for (arg_index, (s_arg, wc_value)) in s_args.iter().zip(wildcard_map.iter()).enumerate() {
724-
if *wc_value != *s_arg {
725-
return Err(Error::mismatched_wildcard_value_and_statement_arg(
726-
wc_value.clone(),
727-
s_arg.clone(),
728-
arg_index,
729-
pred.clone(),
730-
));
731-
}
732+
match wildcard_values_from_op_st(params, pred, args, s_args) {
733+
Ok(_) => Ok(()),
734+
Err(Error::Inner { inner, backtrace }) => match *inner {
735+
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
736+
if wc.index <= s_args.len() =>
737+
{
738+
Err(Error::mismatched_wildcard_value_and_statement_arg(
739+
v,
740+
prev,
741+
wc.index,
742+
pred.clone(),
743+
))
744+
}
745+
_ => Err(Error::Inner { inner, backtrace }),
746+
},
747+
_ => unreachable!(),
732748
}
733-
734-
Ok(())
735749
}
736750

737751
impl ToFields for Operation {

0 commit comments

Comments
 (0)