diff --git a/book/src/anchoredkeys.md b/book/src/anchoredkeys.md index b6875184..fcdbb4fc 100644 --- a/book/src/anchoredkeys.md +++ b/book/src/anchoredkeys.md @@ -6,6 +6,8 @@ type AnchoredKey = (Origin, Key) type Key = String ``` +FIXME: This description is incorrect. We don't have *gadget ID*. And we don't think of *origin* as a triple, we just see it as a single value that encodes the *pod ID* or SELF. + An *origin* is a triple consisting of a numeric identifier called the *origin ID*, a string called the *origin name* (omitted in the backend) and another numeric identifier called the *gadget ID*, which identifies the means by which the value corresponding to a given key is produced. The origin ID is defined to be 0 for 'no origin' and 1 for 'self origin', otherwise it is the content ID[^content-id] of the POD to which it refers. The origin name is not cryptographically significant and is merely a convenience for the frontend. @@ -19,4 +21,4 @@ The gadget ID takes on the values in the following table: | 2 | `MainPOD` gadget: The key-value pair was produced in the construction of a `MainPOD`. | For example, a gadget ID of 1 implies that the key-value pair in question was produced in the process of constructing a `SignedPOD`. -[^content-id]: TODO Refer to this when it is documented. \ No newline at end of file +[^content-id]: TODO Refer to this when it is documented. diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 2b6187ad..239b2d0f 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -579,10 +579,22 @@ pub struct CustomPredicateVerifyEntryTarget { pub custom_predicate_table_index: Target, pub custom_predicate: CustomPredicateEntryTarget, pub args: Vec, - pub query: CustomPredicateVerifyQueryTarget, + pub op_args: Vec, } impl CustomPredicateVerifyEntryTarget { + pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { + CustomPredicateVerifyEntryTarget { + custom_predicate_table_index: builder.add_virtual_target(), + custom_predicate: builder.add_virtual_custom_predicate_entry(params), + args: (0..params.max_custom_predicate_wildcards) + .map(|_| builder.add_virtual_value()) + .collect(), + op_args: (0..params.max_operation_args) + .map(|_| builder.add_virtual_statement(params)) + .collect(), + } + } pub fn set_targets( &self, pw: &mut PartialWitness, @@ -606,7 +618,7 @@ impl CustomPredicateVerifyEntryTarget { arg_target.set_targets(pw, &Value::from(arg.raw()))?; } let pad_op_arg = Statement(Predicate::Native(NativePredicate::None), vec![]); - for (op_arg_target, op_arg) in self.query.op_args.iter().zip_eq( + for (op_arg_target, op_arg) in self.op_args.iter().zip_eq( cpv.op_args .iter() .chain(iter::repeat(&pad_op_arg)) diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index d740365b..2d1a7fa4 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -1,6 +1,6 @@ use std::{array, iter, sync::Arc}; -use itertools::{zip_eq, Itertools}; +use itertools::{izip, zip_eq, Itertools}; use plonky2::{ field::types::Field, hash::{ @@ -26,13 +26,13 @@ use crate::{ OperationTypeTarget, PredicateTarget, StatementArgTarget, StatementTarget, StatementTmplArgTarget, StatementTmplTarget, ValueTarget, }, - signedpod::{SignedPodVerifyGadget, SignedPodVerifyTarget}, + signedpod::{verify_signed_pod_circuit, SignedPodVerifyTarget}, }, emptypod::{EmptyPod, STANDARD_EMPTY_POD_DATA}, error::Result, mainpod::{self, pad_statement}, primitives::merkletree::{ - MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProofGadget, + verify_merkle_proof_circuit, MerkleClaimAndProof, MerkleClaimAndProofTarget, }, recursion::{InnerCircuit, VerifiedProofTarget}, signedpod::SignedPod, @@ -56,10 +56,6 @@ pub const PI_OFFSET_VDSROOT: usize = 4; pub const NUM_PUBLIC_INPUTS: usize = 8; -struct OperationVerifyGadget { - params: Params, -} - const MAX_VALUE_ARGS: usize = 3; struct StatementArgCache { @@ -140,660 +136,659 @@ impl StatementCache { } } -impl OperationVerifyGadget { - #[allow(clippy::too_many_arguments)] - fn eval( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op: &OperationTarget, - prev_statements: &[StatementTarget], - input_statements_offset: usize, - merkle_claims: &[MerkleClaimTarget], - custom_predicate_verification_table: &[HashOutTarget], - ) -> Result<()> { - let measure = measure_gates_begin!(builder, "OpVerify"); - let _true = builder._true(); - let _false = builder._false(); - - // Verify that the operation `op` correctly generates the statement `st`. The operation - // can reference any of the `prev_statements`. - // TODO: Clean this up. - let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); - let cache = StatementCache::new(&self.params, builder, op, st, prev_statements); - measure_gates_end!(builder, measure_resolve_op_args); - // TODO: Can we have a single table with merkel claims and verified custom predicates - // together (with an identifying prefix) and then we only need one random access instead of - // two? - // Currently we use one slot of aux for the index to merkle claim and another slot of aux - // for the index to the verified custom predicate. We can't use the same slot because then - // if one table is different size the random access to the smaller one may use an index - // that is too big and not pass the constraints. Possible solutions to use a single slot - // are: - // - a. Use a single table (mux both tables) - // - b. select the index or 0 by checking the operation type here; but that breaks the - // current abstraction a little bit. - - // Certain operations (Contains/NotContains) will refer to one - // of the provided Merkle proofs (if any). These proofs have already - // been verified, so we need only look up the claim. - let measure_resolve_merkle_claim = measure_gates_begin!(builder, "ResolveMerkleClaim"); - let resolved_merkle_claim = (!merkle_claims.is_empty()) - .then(|| builder.vec_ref(&self.params, merkle_claims, op.aux[0])); - measure_gates_end!(builder, measure_resolve_merkle_claim); - - // Operations from custom statements will refer to one - // of the provided custom predicates verifications (if any). These operations have already - // been verified, so we need only look up the entry. - let measure_resolve_custom_pred_verification = - measure_gates_begin!(builder, "ResolveCustomPredVerification"); - let resolved_custom_pred_verification = (!custom_predicate_verification_table.is_empty()) - .then(|| builder.vec_ref(&self.params, custom_predicate_verification_table, op.aux[1])); - measure_gates_end!(builder, measure_resolve_custom_pred_verification); - - // The verification may require aux data which needs to be stored in the - // `OperationVerifyTarget` so that we can set during witness generation. - - // For now only support native operations - // Op checks to carry out. Each 'eval_X' should - // be thought of as 'eval' restricted to the op of type X, - // where the returned target is `false` if the input targets - // lie outside of the domain. - let op_checks = [ +#[allow(clippy::too_many_arguments)] +fn verify_operation_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op: &OperationTarget, + prev_statements: &[StatementTarget], + input_statements_offset: usize, + merkle_claims: &[MerkleClaimTarget], + custom_predicate_verification_table: &[HashOutTarget], +) -> Result<()> { + let measure = measure_gates_begin!(builder, "OpVerify"); + let _true = builder._true(); + let _false = builder._false(); + + // Verify that the operation `op` correctly generates the statement `st`. The operation + // can reference any of the `prev_statements`. + // TODO: Clean this up. + let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); + let cache = StatementCache::new(params, builder, op, st, prev_statements); + measure_gates_end!(builder, measure_resolve_op_args); + // TODO: Can we have a single table with merkel claims and verified custom predicates + // together (with an identifying prefix) and then we only need one random access instead of + // two? + // Currently we use one slot of aux for the index to merkle claim and another slot of aux + // for the index to the verified custom predicate. We can't use the same slot because then + // if one table is different size the random access to the smaller one may use an index + // that is too big and not pass the constraints. Possible solutions to use a single slot + // are: + // - a. Use a single table (mux both tables) + // - b. select the index or 0 by checking the operation type here; but that breaks the + // current abstraction a little bit. + + // Certain operations (Contains/NotContains) will refer to one + // of the provided Merkle proofs (if any). These proofs have already + // been verified, so we need only look up the claim. + let measure_resolve_merkle_claim = measure_gates_begin!(builder, "ResolveMerkleClaim"); + let resolved_merkle_claim = + (!merkle_claims.is_empty()).then(|| builder.vec_ref(params, merkle_claims, op.aux[0])); + measure_gates_end!(builder, measure_resolve_merkle_claim); + + // Operations from custom statements will refer to one + // of the provided custom predicates verifications (if any). These operations have already + // been verified, so we need only look up the entry. + let measure_resolve_custom_pred_verification = + measure_gates_begin!(builder, "ResolveCustomPredVerification"); + let resolved_custom_pred_verification = (!custom_predicate_verification_table.is_empty()) + .then(|| builder.vec_ref(params, custom_predicate_verification_table, op.aux[1])); + measure_gates_end!(builder, measure_resolve_custom_pred_verification); + + // The verification may require aux data which needs to be stored in the + // `OperationVerifyTarget` so that we can set during witness generation. + + // For now only support native operations + // Op checks to carry out. Each 'eval_X' should + // be thought of as 'eval' restricted to the op of type X, + // where the returned target is `false` if the input targets + // lie outside of the domain. + let op_checks = [ + vec![ + verify_none_circuit(params, builder, st, &op.op_type), + verify_new_entry_circuit( + params, + builder, + st, + &op.op_type, + prev_statements, + input_statements_offset, + ), + ], + // Skip these if there are no resolved op args + if cache.op_args.is_empty() { + vec![] + } else { vec![ - self.eval_none(builder, st, &op.op_type), - self.eval_new_entry( + verify_copy_circuit(builder, st, &op.op_type, &cache.op_args), + verify_eq_neq_from_entries_circuit(params, builder, st, &op.op_type, &cache), + verify_lt_lteq_from_entries_circuit(params, builder, st, &op.op_type, &cache), + verify_transitive_eq_circuit(params, builder, st, &op.op_type, &cache.op_args), + verify_lt_to_neq_circuit(params, builder, st, &op.op_type, &cache.op_args), + verify_hash_of_circuit(params, builder, st, &op.op_type, &cache), + verify_sum_of_circuit(params, builder, st, &op.op_type, &cache), + verify_product_of_circuit(params, builder, st, &op.op_type, &cache), + verify_max_of_circuit(params, builder, st, &op.op_type, &cache), + ] + }, + // Skip these if there are no resolved Merkle claims + if let Some(resolved_merkle_claim) = resolved_merkle_claim { + vec![ + verify_contains_from_entries_circuit( + params, builder, st, &op.op_type, - prev_statements, - input_statements_offset, + resolved_merkle_claim, + &cache, ), - ], - // Skip these if there are no resolved op args - if cache.op_args.is_empty() { - vec![] - } else { - vec![ - self.eval_copy(builder, st, &op.op_type, &cache.op_args), - self.eval_eq_neq_from_entries(builder, st, &op.op_type, &cache), - self.eval_lt_lteq_from_entries(builder, st, &op.op_type, &cache), - self.eval_transitive_eq(builder, st, &op.op_type, &cache.op_args), - self.eval_lt_to_neq(builder, st, &op.op_type, &cache.op_args), - self.eval_hash_of(builder, st, &op.op_type, &cache), - self.eval_sum_of(builder, st, &op.op_type, &cache), - self.eval_product_of(builder, st, &op.op_type, &cache), - self.eval_max_of(builder, st, &op.op_type, &cache), - ] - }, - // Skip these if there are no resolved Merkle claims - if let Some(resolved_merkle_claim) = resolved_merkle_claim { - vec![ - self.eval_contains_from_entries( - builder, - st, - &op.op_type, - resolved_merkle_claim, - &cache, - ), - self.eval_not_contains_from_entries( - builder, - st, - &op.op_type, - resolved_merkle_claim, - &cache, - ), - ] - } else { - vec![] - }, - // Skip these if there are no resolved custom predicate verifications - if let Some(resolved_custom_pred_verification) = resolved_custom_pred_verification { - vec![self.eval_custom( + verify_not_contains_from_entries_circuit( + params, builder, st, &op.op_type, - resolved_custom_pred_verification, - &cache.op_args, - )] - } else { - vec![] - }, - ] - .concat(); - - let ok = builder.any(op_checks); - builder.assert_one(ok.target); + resolved_merkle_claim, + &cache, + ), + ] + } else { + vec![] + }, + // Skip these if there are no resolved custom predicate verifications + if let Some(resolved_custom_pred_verification) = resolved_custom_pred_verification { + vec![verify_custom_circuit( + builder, + st, + &op.op_type, + resolved_custom_pred_verification, + &cache.op_args, + )] + } else { + vec![] + }, + ] + .concat(); - measure_gates_end!(builder, measure); - Ok(()) - } + let ok = builder.any(op_checks); + builder.assert_one(ok.target); - fn eval_contains_from_entries( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - resolved_merkle_claim: MerkleClaimTarget, - cache: &StatementCache, - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpContainsFromEntries"); - let op_code_ok = op_type.has_native(builder, NativeOperation::ContainsFromEntries); - - let (arg_types_ok, [merkle_root_value, key_value, value_value]) = - cache.first_n_args_as_values(); - - // Check Merkle proof (verified elsewhere) against op args. - let merkle_proof_checks = [ - /* The supplied Merkle proof must be enabled. */ - resolved_merkle_claim.enabled, - /* ...and it must be an existence proof. */ - resolved_merkle_claim.existence, - /* ...for the root-key-value triple in the resolved op args. */ - builder.is_equal_slice( - &merkle_root_value.elements, - &resolved_merkle_claim.root.elements, - ), - builder.is_equal_slice(&key_value.elements, &resolved_merkle_claim.key.elements), - builder.is_equal_slice(&value_value.elements, &resolved_merkle_claim.value.elements), - ]; - - let merkle_proof_ok = builder.all(merkle_proof_checks); + measure_gates_end!(builder, measure); + Ok(()) +} - // Check output statement - let arg1_expected = cache.equations[0].lhs.clone(); - let arg2_expected = cache.equations[1].lhs.clone(); - let arg3_expected = cache.equations[2].lhs.clone(); - let expected_statement = StatementTarget::new_native( - builder, - &self.params, - NativePredicate::Contains, - &[arg1_expected, arg2_expected, arg3_expected], - ); - let st_ok = builder.is_equal_flattenable(st, &expected_statement); +// +// Native operation constraints +// - let ok = builder.all([op_code_ok, arg_types_ok, merkle_proof_ok, st_ok]); - measure_gates_end!(builder, measure); - ok - } +fn verify_contains_from_entries_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + resolved_merkle_claim: MerkleClaimTarget, + cache: &StatementCache, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpContainsFromEntries"); + let op_code_ok = op_type.has_native(builder, NativeOperation::ContainsFromEntries); + + let (arg_types_ok, [merkle_root_value, key_value, value_value]) = + cache.first_n_args_as_values(); + + // Check Merkle proof (verified elsewhere) against op args. + let merkle_proof_checks = [ + /* The supplied Merkle proof must be enabled. */ + resolved_merkle_claim.enabled, + /* ...and it must be an existence proof. */ + resolved_merkle_claim.existence, + /* ...for the root-key-value triple in the resolved op args. */ + builder.is_equal_slice( + &merkle_root_value.elements, + &resolved_merkle_claim.root.elements, + ), + builder.is_equal_slice(&key_value.elements, &resolved_merkle_claim.key.elements), + builder.is_equal_slice(&value_value.elements, &resolved_merkle_claim.value.elements), + ]; + + let merkle_proof_ok = builder.all(merkle_proof_checks); + + // Check output statement + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); + let arg3_expected = cache.equations[2].lhs.clone(); + let expected_statement = StatementTarget::new_native( + builder, + params, + NativePredicate::Contains, + &[arg1_expected, arg2_expected, arg3_expected], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); - fn eval_not_contains_from_entries( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - resolved_merkle_claim: MerkleClaimTarget, - cache: &StatementCache, - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpNotContainsFromEntries"); - let op_code_ok = op_type.has_native(builder, NativeOperation::NotContainsFromEntries); - - let (arg_types_ok, [merkle_root_value, key_value]) = cache.first_n_args_as_values(); - - // Check Merkle proof (verified elsewhere) against op args. - let merkle_proof_checks = [ - /* The supplied Merkle proof must be enabled. */ - resolved_merkle_claim.enabled, - /* ...and it must be a nonexistence proof. */ - builder.not(resolved_merkle_claim.existence), - /* ...for the root-key pair in the resolved op args. */ - builder.is_equal_slice( - &merkle_root_value.elements, - &resolved_merkle_claim.root.elements, - ), - builder.is_equal_slice(&key_value.elements, &resolved_merkle_claim.key.elements), - ]; + let ok = builder.all([op_code_ok, arg_types_ok, merkle_proof_ok, st_ok]); + measure_gates_end!(builder, measure); + ok +} - let merkle_proof_ok = builder.all(merkle_proof_checks); +fn verify_not_contains_from_entries_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + resolved_merkle_claim: MerkleClaimTarget, + cache: &StatementCache, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpNotContainsFromEntries"); + let op_code_ok = op_type.has_native(builder, NativeOperation::NotContainsFromEntries); + + let (arg_types_ok, [merkle_root_value, key_value]) = cache.first_n_args_as_values(); + + // Check Merkle proof (verified elsewhere) against op args. + let merkle_proof_checks = [ + /* The supplied Merkle proof must be enabled. */ + resolved_merkle_claim.enabled, + /* ...and it must be a nonexistence proof. */ + builder.not(resolved_merkle_claim.existence), + /* ...for the root-key pair in the resolved op args. */ + builder.is_equal_slice( + &merkle_root_value.elements, + &resolved_merkle_claim.root.elements, + ), + builder.is_equal_slice(&key_value.elements, &resolved_merkle_claim.key.elements), + ]; + + let merkle_proof_ok = builder.all(merkle_proof_checks); + + // Check output statement + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); + let expected_statement = StatementTarget::new_native( + builder, + params, + NativePredicate::NotContains, + &[arg1_expected, arg2_expected], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); - // Check output statement - let arg1_expected = cache.equations[0].lhs.clone(); - let arg2_expected = cache.equations[1].lhs.clone(); - let expected_statement = StatementTarget::new_native( - builder, - &self.params, - NativePredicate::NotContains, - &[arg1_expected, arg2_expected], - ); - let st_ok = builder.is_equal_flattenable(st, &expected_statement); + let ok = builder.all([op_code_ok, arg_types_ok, merkle_proof_ok, st_ok]); + measure_gates_end!(builder, measure); + ok +} - let ok = builder.all([op_code_ok, arg_types_ok, merkle_proof_ok, st_ok]); - measure_gates_end!(builder, measure); - ok - } +fn verify_custom_circuit( + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + resolved_custom_pred_verification: HashOutTarget, + resolved_op_args: &[StatementTarget], +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpCustom"); + let query = CustomPredicateVerifyQueryTarget { + statement: st.clone(), + op_type: op_type.clone(), + op_args: resolved_op_args.to_vec(), + }; + let out_query_hash = query.hash(builder); + let ok = builder.is_equal_slice( + &resolved_custom_pred_verification.elements, + &out_query_hash.elements, + ); + measure_gates_end!(builder, measure); + ok +} - fn eval_custom( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - resolved_custom_pred_verification: HashOutTarget, - resolved_op_args: &[StatementTarget], - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpCustom"); - let query = CustomPredicateVerifyQueryTarget { - statement: st.clone(), - op_type: op_type.clone(), - op_args: resolved_op_args.to_vec(), - }; - let out_query_hash = query.hash(builder); - let ok = builder.is_equal_slice( - &resolved_custom_pred_verification.elements, - &out_query_hash.elements, - ); - measure_gates_end!(builder, measure); - ok - } +/// Carries out the checks necessary for EqualFromEntries and +/// NotEqualFromEntries. +fn verify_eq_neq_from_entries_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + cache: &StatementCache, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpEqNeqFromEntries"); + let eq_op_st_code_ok = { + let op_code_ok = op_type.has_native(builder, NativeOperation::EqualFromEntries); + let st_code_ok = st.has_native_type(builder, params, NativePredicate::Equal); + builder.and(op_code_ok, st_code_ok) + }; + let neq_op_st_code_ok = { + let op_code_ok = op_type.has_native(builder, NativeOperation::NotEqualFromEntries); + let st_code_ok = st.has_native_type(builder, params, NativePredicate::NotEqual); + builder.and(op_code_ok, st_code_ok) + }; + let op_st_code_ok = builder.or(eq_op_st_code_ok, neq_op_st_code_ok); - /// Carries out the checks necessary for EqualFromEntries and - /// NotEqualFromEntries. - fn eval_eq_neq_from_entries( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - cache: &StatementCache, - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpEqNeqFromEntries"); - let eq_op_st_code_ok = { - let op_code_ok = op_type.has_native(builder, NativeOperation::EqualFromEntries); - let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::Equal); - builder.and(op_code_ok, st_code_ok) - }; - let neq_op_st_code_ok = { - let op_code_ok = op_type.has_native(builder, NativeOperation::NotEqualFromEntries); - let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::NotEqual); - builder.and(op_code_ok, st_code_ok) - }; - let op_st_code_ok = builder.or(eq_op_st_code_ok, neq_op_st_code_ok); + let (arg_types_ok, [arg1_value, arg2_value]) = cache.first_n_args_as_values(); - let (arg_types_ok, [arg1_value, arg2_value]) = cache.first_n_args_as_values(); + let op_args_eq = builder.is_equal_slice(&arg1_value.elements, &arg2_value.elements); + let op_args_ok = builder.is_equal(op_args_eq.target, eq_op_st_code_ok.target); - let op_args_eq = builder.is_equal_slice(&arg1_value.elements, &arg2_value.elements); - let op_args_ok = builder.is_equal(op_args_eq.target, eq_op_st_code_ok.target); + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); - let arg1_expected = cache.equations[0].lhs.clone(); - let arg2_expected = cache.equations[1].lhs.clone(); + let expected_st_args: Vec<_> = [arg1_expected, arg2_expected] + .into_iter() + .chain(std::iter::repeat_with(|| StatementArgTarget::none(builder))) + .take(params.max_statement_args) + .flat_map(|arg| arg.elements) + .collect(); - let expected_st_args: Vec<_> = [arg1_expected, arg2_expected] - .into_iter() - .chain(std::iter::repeat_with(|| StatementArgTarget::none(builder))) - .take(self.params.max_statement_args) + let st_args_ok = builder.is_equal_slice( + &expected_st_args, + &st.args + .iter() .flat_map(|arg| arg.elements) - .collect(); + .collect::>(), + ); - let st_args_ok = builder.is_equal_slice( - &expected_st_args, - &st.args - .iter() - .flat_map(|arg| arg.elements) - .collect::>(), - ); + let ok = builder.all([op_st_code_ok, arg_types_ok, op_args_ok, st_args_ok]); + measure_gates_end!(builder, measure); + ok +} - let ok = builder.all([op_st_code_ok, arg_types_ok, op_args_ok, st_args_ok]); - measure_gates_end!(builder, measure); - ok - } +/// Carries out the checks necessary for LtFromEntries and +/// LtEqFromEntries. +fn verify_lt_lteq_from_entries_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + cache: &StatementCache, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpLtLteqFromEntries"); + let zero = ValueTarget::zero(builder); + let one = ValueTarget::one(builder); + + let lt_op_st_code_ok = { + let op_code_ok = op_type.has_native(builder, NativeOperation::LtFromEntries); + let st_code_ok = st.has_native_type(builder, params, NativePredicate::Lt); + builder.and(op_code_ok, st_code_ok) + }; + let lteq_op_st_code_ok = { + let op_code_ok = op_type.has_native(builder, NativeOperation::LtEqFromEntries); + let st_code_ok = st.has_native_type(builder, params, NativePredicate::LtEq); + builder.and(op_code_ok, st_code_ok) + }; + let op_st_code_ok = builder.or(lt_op_st_code_ok, lteq_op_st_code_ok); - /// Carries out the checks necessary for LtFromEntries and - /// LtEqFromEntries. - fn eval_lt_lteq_from_entries( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - cache: &StatementCache, - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpLtLteqFromEntries"); - let zero = ValueTarget::zero(builder); - let one = ValueTarget::one(builder); - - let lt_op_st_code_ok = { - let op_code_ok = op_type.has_native(builder, NativeOperation::LtFromEntries); - let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::Lt); - builder.and(op_code_ok, st_code_ok) - }; - let lteq_op_st_code_ok = { - let op_code_ok = op_type.has_native(builder, NativeOperation::LtEqFromEntries); - let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::LtEq); - builder.and(op_code_ok, st_code_ok) - }; - let op_st_code_ok = builder.or(lt_op_st_code_ok, lteq_op_st_code_ok); + let (arg_types_ok, [arg1_value, arg2_value]) = cache.first_n_args_as_values(); - let (arg_types_ok, [arg1_value, arg2_value]) = cache.first_n_args_as_values(); + // If we are not dealing with the right op & statement types, + // replace args with dummy values in the following checks. + let value1 = builder.select_value(op_st_code_ok, arg1_value, zero); + let value2 = builder.select_value(op_st_code_ok, arg2_value, one); - // If we are not dealing with the right op & statement types, - // replace args with dummy values in the following checks. - let value1 = builder.select_value(op_st_code_ok, arg1_value, zero); - let value2 = builder.select_value(op_st_code_ok, arg2_value, one); + // Range check + builder.assert_i64(value1); + builder.assert_i64(value2); - // Range check - builder.assert_i64(value1); - builder.assert_i64(value2); + // Check for equality. + let args_equal = builder.is_equal_slice(&value1.elements, &value2.elements); - // Check for equality. - let args_equal = builder.is_equal_slice(&value1.elements, &value2.elements); + // Check < if applicable. + let lt_check_flag = { + let not_args_equal = builder.not(args_equal); + let lteq_eq_case = builder.and(lteq_op_st_code_ok, not_args_equal); + builder.or(lt_op_st_code_ok, lteq_eq_case) + }; + builder.assert_i64_less_if(lt_check_flag, value1, value2); - // Check < if applicable. - let lt_check_flag = { - let not_args_equal = builder.not(args_equal); - let lteq_eq_case = builder.and(lteq_op_st_code_ok, not_args_equal); - builder.or(lt_op_st_code_ok, lteq_eq_case) - }; - builder.assert_i64_less_if(lt_check_flag, value1, value2); + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); - let arg1_expected = cache.equations[0].lhs.clone(); - let arg2_expected = cache.equations[1].lhs.clone(); + let expected_st_args: Vec<_> = [arg1_expected, arg2_expected] + .into_iter() + .chain(std::iter::repeat_with(|| StatementArgTarget::none(builder))) + .take(params.max_statement_args) + .flat_map(|arg| arg.elements) + .collect(); - let expected_st_args: Vec<_> = [arg1_expected, arg2_expected] - .into_iter() - .chain(std::iter::repeat_with(|| StatementArgTarget::none(builder))) - .take(self.params.max_statement_args) + let st_args_ok = builder.is_equal_slice( + &expected_st_args, + &st.args + .iter() .flat_map(|arg| arg.elements) - .collect(); - - let st_args_ok = builder.is_equal_slice( - &expected_st_args, - &st.args - .iter() - .flat_map(|arg| arg.elements) - .collect::>(), - ); - - let ok = builder.all([op_st_code_ok, arg_types_ok, st_args_ok]); - measure_gates_end!(builder, measure); - ok - } - - fn eval_hash_of( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - cache: &StatementCache, - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpHashOf"); - let op_code_ok = op_type.has_native(builder, NativeOperation::HashOf); - - let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); - - let expected_hash_value = builder.hash_values(arg2_value, arg3_value); - - let hash_value_ok = - builder.is_equal_slice(&arg1_value.elements, &expected_hash_value.elements); - - let arg1_expected = cache.equations[0].lhs.clone(); - let arg2_expected = cache.equations[1].lhs.clone(); - let arg3_expected = cache.equations[2].lhs.clone(); - let expected_statement = StatementTarget::new_native( - builder, - &self.params, - NativePredicate::HashOf, - &[arg1_expected, arg2_expected, arg3_expected], - ); - let st_ok = builder.is_equal_flattenable(st, &expected_statement); - - let ok = builder.all([op_code_ok, arg_types_ok, hash_value_ok, st_ok]); - measure_gates_end!(builder, measure); - ok - } - - fn eval_sum_of( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - cache: &StatementCache, - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpSumOf"); - let value_zero = ValueTarget::zero(builder); - - let op_code_ok = op_type.has_native(builder, NativeOperation::SumOf); - - let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); - - // Select to avoid overflow. - let summand1 = builder.select_value(op_code_ok, arg2_value, value_zero); - let summand2 = builder.select_value(op_code_ok, arg3_value, value_zero); - - let expected_sum = builder.i64_add(summand1, summand2); - - let sum_ok = builder.is_equal_slice(&arg1_value.elements, &expected_sum.elements); - - let arg1_expected = cache.equations[0].lhs.clone(); - let arg2_expected = cache.equations[1].lhs.clone(); - let arg3_expected = cache.equations[2].lhs.clone(); - let expected_statement = StatementTarget::new_native( - builder, - &self.params, - NativePredicate::SumOf, - &[arg1_expected, arg2_expected, arg3_expected], - ); - let st_ok = builder.is_equal_flattenable(st, &expected_statement); - - let ok = builder.all([op_code_ok, arg_types_ok, sum_ok, st_ok]); - measure_gates_end!(builder, measure); - ok - } - - fn eval_product_of( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - cache: &StatementCache, - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpProductOf"); - let value_zero = ValueTarget::zero(builder); - - let op_code_ok = op_type.has_native(builder, NativeOperation::ProductOf); - - let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); - - // Select to avoid overflow. - let factor1 = builder.select_value(op_code_ok, arg2_value, value_zero); - let factor2 = builder.select_value(op_code_ok, arg3_value, value_zero); - - let expected_product = builder.i64_mul(factor1, factor2); - - let product_ok = builder.is_equal_slice(&arg1_value.elements, &expected_product.elements); - - let arg1_expected = cache.equations[0].lhs.clone(); - let arg2_expected = cache.equations[1].lhs.clone(); - let arg3_expected = cache.equations[2].lhs.clone(); - let expected_statement = StatementTarget::new_native( - builder, - &self.params, - NativePredicate::ProductOf, - &[arg1_expected, arg2_expected, arg3_expected], - ); - let st_ok = builder.is_equal_flattenable(st, &expected_statement); - - let ok = builder.all([op_code_ok, arg_types_ok, product_ok, st_ok]); - measure_gates_end!(builder, measure); - ok - } - - fn eval_max_of( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - cache: &StatementCache, - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpMaxOf"); - let op_code_ok = op_type.has_native(builder, NativeOperation::MaxOf); - - let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); - - // Check that arg1_value is equal to one of the other two - // values. - let arg1_eq_arg2 = builder.is_equal_slice(&arg1_value.elements, &arg2_value.elements); - let arg1_eq_arg3 = builder.is_equal_slice(&arg1_value.elements, &arg3_value.elements); - - let all_eq = builder.and(arg1_eq_arg2, arg1_eq_arg3); - let not_all_eq = builder.not(all_eq); + .collect::>(), + ); - let arg1_check = builder.or(arg1_eq_arg2, arg1_eq_arg3); + let ok = builder.all([op_st_code_ok, arg_types_ok, st_args_ok]); + measure_gates_end!(builder, measure); + ok +} - // If it is not equal to any of the other two values, it must be greater than it. - let lower_bound = builder.select_value(arg1_eq_arg2, arg3_value, arg2_value); +fn verify_hash_of_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + cache: &StatementCache, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpHashOf"); + let op_code_ok = op_type.has_native(builder, NativeOperation::HashOf); + + let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); + + let expected_hash_value = builder.hash_values(arg2_value, arg3_value); + + let hash_value_ok = builder.is_equal_slice(&arg1_value.elements, &expected_hash_value.elements); + + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); + let arg3_expected = cache.equations[2].lhs.clone(); + let expected_statement = StatementTarget::new_native( + builder, + params, + NativePredicate::HashOf, + &[arg1_expected, arg2_expected, arg3_expected], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); - // Only check lower bound if not all args are equal. - let lt_check_enabled = builder.and(not_all_eq, op_code_ok); - builder.assert_i64_less_if(lt_check_enabled, lower_bound, arg1_value); + let ok = builder.all([op_code_ok, arg_types_ok, hash_value_ok, st_ok]); + measure_gates_end!(builder, measure); + ok +} - let arg1_expected = cache.equations[0].lhs.clone(); - let arg2_expected = cache.equations[1].lhs.clone(); - let arg3_expected = cache.equations[2].lhs.clone(); - let expected_statement = StatementTarget::new_native( - builder, - &self.params, - NativePredicate::MaxOf, - &[arg1_expected, arg2_expected, arg3_expected], - ); - let st_ok = builder.is_equal_flattenable(st, &expected_statement); +fn verify_sum_of_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + cache: &StatementCache, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpSumOf"); + let value_zero = ValueTarget::zero(builder); + + let op_code_ok = op_type.has_native(builder, NativeOperation::SumOf); + + let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); + + // Select to avoid overflow. + let summand1 = builder.select_value(op_code_ok, arg2_value, value_zero); + let summand2 = builder.select_value(op_code_ok, arg3_value, value_zero); + + let expected_sum = builder.i64_add(summand1, summand2); + + let sum_ok = builder.is_equal_slice(&arg1_value.elements, &expected_sum.elements); + + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); + let arg3_expected = cache.equations[2].lhs.clone(); + let expected_statement = StatementTarget::new_native( + builder, + params, + NativePredicate::SumOf, + &[arg1_expected, arg2_expected, arg3_expected], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); - let ok = builder.all([op_code_ok, arg_types_ok, arg1_check, st_ok]); - measure_gates_end!(builder, measure); - ok - } + let ok = builder.all([op_code_ok, arg_types_ok, sum_ok, st_ok]); + measure_gates_end!(builder, measure); + ok +} - fn eval_transitive_eq( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - resolved_op_args: &[StatementTarget], - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpTransitiveEq"); - let op_code_ok = - op_type.has_native(builder, NativeOperation::TransitiveEqualFromStatements); - - let arg1_type_ok = - resolved_op_args[0].has_native_type(builder, &self.params, NativePredicate::Equal); - let arg2_type_ok = - resolved_op_args[1].has_native_type(builder, &self.params, NativePredicate::Equal); - let arg_types_ok = builder.all([arg1_type_ok, arg2_type_ok]); - - let arg1_lhs = &resolved_op_args[0].args[0]; - let arg1_rhs = &resolved_op_args[0].args[1]; - let arg2_lhs = &resolved_op_args[1].args[0]; - let arg2_rhs = &resolved_op_args[1].args[1]; - - let inner_args_match = builder.is_equal_slice(&arg1_rhs.elements, &arg2_lhs.elements); - - let expected_statement = StatementTarget::new_native( - builder, - &self.params, - NativePredicate::Equal, - &[arg1_lhs.clone(), arg2_rhs.clone()], - ); - let st_ok = builder.is_equal_flattenable(st, &expected_statement); +fn verify_product_of_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + cache: &StatementCache, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpProductOf"); + let value_zero = ValueTarget::zero(builder); + + let op_code_ok = op_type.has_native(builder, NativeOperation::ProductOf); + + let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); + + // Select to avoid overflow. + let factor1 = builder.select_value(op_code_ok, arg2_value, value_zero); + let factor2 = builder.select_value(op_code_ok, arg3_value, value_zero); + + let expected_product = builder.i64_mul(factor1, factor2); + + let product_ok = builder.is_equal_slice(&arg1_value.elements, &expected_product.elements); + + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); + let arg3_expected = cache.equations[2].lhs.clone(); + let expected_statement = StatementTarget::new_native( + builder, + params, + NativePredicate::ProductOf, + &[arg1_expected, arg2_expected, arg3_expected], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); - let ok = builder.all([op_code_ok, arg_types_ok, inner_args_match, st_ok]); - measure_gates_end!(builder, measure); - ok - } - fn eval_none( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpNone"); - let op_code_ok = op_type.has_native(builder, NativeOperation::None); - - let expected_statement = - StatementTarget::new_native(builder, &self.params, NativePredicate::None, &[]); - let st_ok = builder.is_equal_flattenable(st, &expected_statement); - - let ok = builder.all([op_code_ok, st_ok]); - measure_gates_end!(builder, measure); - ok - } + let ok = builder.all([op_code_ok, arg_types_ok, product_ok, st_ok]); + measure_gates_end!(builder, measure); + ok +} - fn eval_new_entry( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - prev_statements: &[StatementTarget], - input_statements_offset: usize, - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpNewEntry"); - let op_code_ok = op_type.has_native(builder, NativeOperation::NewEntry); - let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::Equal); - - let expected_arg_prefix = builder.constants( - &StatementArg::Key(AnchoredKey::from((SELF, ""))).to_fields(&self.params)[..VALUE_SIZE], - ); - let arg_prefix_ok = - builder.is_equal_slice(&st.args[0].elements[..VALUE_SIZE], &expected_arg_prefix); +fn verify_max_of_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + cache: &StatementCache, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpMaxOf"); + let op_code_ok = op_type.has_native(builder, NativeOperation::MaxOf); + + let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); + + // Check that arg1_value is equal to one of the other two + // values. + let arg1_eq_arg2 = builder.is_equal_slice(&arg1_value.elements, &arg2_value.elements); + let arg1_eq_arg3 = builder.is_equal_slice(&arg1_value.elements, &arg3_value.elements); + + let all_eq = builder.and(arg1_eq_arg2, arg1_eq_arg3); + let not_all_eq = builder.not(all_eq); + + let arg1_check = builder.or(arg1_eq_arg2, arg1_eq_arg3); + + // If it is not equal to any of the other two values, it must be greater than it. + let lower_bound = builder.select_value(arg1_eq_arg2, arg3_value, arg2_value); + + // Only check lower bound if not all args are equal. + let lt_check_enabled = builder.and(not_all_eq, op_code_ok); + builder.assert_i64_less_if(lt_check_enabled, lower_bound, arg1_value); + + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); + let arg3_expected = cache.equations[2].lhs.clone(); + let expected_statement = StatementTarget::new_native( + builder, + params, + NativePredicate::MaxOf, + &[arg1_expected, arg2_expected, arg3_expected], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); - let input_statements = &prev_statements[input_statements_offset..]; - let individual_dupe_checks = input_statements - .iter() - .map(|ps| builder.is_equal_slice(&st.args[0].elements, &ps.args[0].elements)) - .collect::>(); - let dupe_check = builder.any(individual_dupe_checks); - let no_dupes_ok = builder.not(dupe_check); - - let ok = builder.all([op_code_ok, st_code_ok, arg_prefix_ok, no_dupes_ok]); - measure_gates_end!(builder, measure); - ok - } + let ok = builder.all([op_code_ok, arg_types_ok, arg1_check, st_ok]); + measure_gates_end!(builder, measure); + ok +} - fn eval_lt_to_neq( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - resolved_op_args: &[StatementTarget], - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpLtToNeq"); - let op_code_ok = op_type.has_native(builder, NativeOperation::LtToNotEqual); +fn verify_transitive_eq_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + resolved_op_args: &[StatementTarget], +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpTransitiveEq"); + let op_code_ok = op_type.has_native(builder, NativeOperation::TransitiveEqualFromStatements); + + let arg1_type_ok = resolved_op_args[0].has_native_type(builder, params, NativePredicate::Equal); + let arg2_type_ok = resolved_op_args[1].has_native_type(builder, params, NativePredicate::Equal); + let arg_types_ok = builder.all([arg1_type_ok, arg2_type_ok]); + + let arg1_lhs = &resolved_op_args[0].args[0]; + let arg1_rhs = &resolved_op_args[0].args[1]; + let arg2_lhs = &resolved_op_args[1].args[0]; + let arg2_rhs = &resolved_op_args[1].args[1]; + + let inner_args_match = builder.is_equal_slice(&arg1_rhs.elements, &arg2_lhs.elements); + + let expected_statement = StatementTarget::new_native( + builder, + params, + NativePredicate::Equal, + &[arg1_lhs.clone(), arg2_rhs.clone()], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); - let arg_type_ok = - resolved_op_args[0].has_native_type(builder, &self.params, NativePredicate::Lt); + let ok = builder.all([op_code_ok, arg_types_ok, inner_args_match, st_ok]); + measure_gates_end!(builder, measure); + ok +} - let arg1_expected = resolved_op_args[0].args[0].clone(); - let arg2_expected = resolved_op_args[0].args[1].clone(); +fn verify_none_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpNone"); + let op_code_ok = op_type.has_native(builder, NativeOperation::None); + + let expected_statement = + StatementTarget::new_native(builder, params, NativePredicate::None, &[]); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); + + let ok = builder.all([op_code_ok, st_ok]); + measure_gates_end!(builder, measure); + ok +} - let expected_statement = StatementTarget::new_native( - builder, - &self.params, - NativePredicate::NotEqual, - &[arg1_expected, arg2_expected], - ); - let st_ok = builder.is_equal_flattenable(st, &expected_statement); +fn verify_new_entry_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + prev_statements: &[StatementTarget], + input_statements_offset: usize, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpNewEntry"); + let op_code_ok = op_type.has_native(builder, NativeOperation::NewEntry); + let st_code_ok = st.has_native_type(builder, params, NativePredicate::Equal); + + let expected_arg_prefix = builder.constants( + &StatementArg::Key(AnchoredKey::from((SELF, ""))).to_fields(params)[..VALUE_SIZE], + ); + let arg_prefix_ok = + builder.is_equal_slice(&st.args[0].elements[..VALUE_SIZE], &expected_arg_prefix); + + let input_statements = &prev_statements[input_statements_offset..]; + let individual_dupe_checks = input_statements + .iter() + .map(|ps| builder.is_equal_slice(&st.args[0].elements, &ps.args[0].elements)) + .collect::>(); + let dupe_check = builder.any(individual_dupe_checks); + let no_dupes_ok = builder.not(dupe_check); + + let ok = builder.all([op_code_ok, st_code_ok, arg_prefix_ok, no_dupes_ok]); + measure_gates_end!(builder, measure); + ok +} - let ok = builder.all([op_code_ok, arg_type_ok, st_ok]); - measure_gates_end!(builder, measure); - ok - } +fn verify_lt_to_neq_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + resolved_op_args: &[StatementTarget], +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpLtToNeq"); + let op_code_ok = op_type.has_native(builder, NativeOperation::LtToNotEqual); + + let arg_type_ok = resolved_op_args[0].has_native_type(builder, params, NativePredicate::Lt); + + let arg1_expected = resolved_op_args[0].args[0].clone(); + let arg2_expected = resolved_op_args[0].args[1].clone(); + + let expected_statement = StatementTarget::new_native( + builder, + params, + NativePredicate::NotEqual, + &[arg1_expected, arg2_expected], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); - fn eval_copy( - &self, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - resolved_op_args: &[StatementTarget], - ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpCopy"); - let op_code_ok = op_type.has_native(builder, NativeOperation::CopyStatement); - - let expected_statement = &resolved_op_args[0]; - let st_ok = builder.is_equal_flattenable(st, expected_statement); - - let ok = builder.all([op_code_ok, st_ok]); - measure_gates_end!(builder, measure); - ok - } + let ok = builder.all([op_code_ok, arg_type_ok, st_ok]); + measure_gates_end!(builder, measure); + ok } -struct CustomOperationVerifyGadget { - params: Params, +// +// Custom Predicate constraints +// + +fn verify_copy_circuit( + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + resolved_op_args: &[StatementTarget], +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpCopy"); + let op_code_ok = op_type.has_native(builder, NativeOperation::CopyStatement); + + let expected_statement = &resolved_op_args[0]; + let st_ok = builder.is_equal_flattenable(st, expected_statement); + + let ok = builder.all([op_code_ok, st_ok]); + measure_gates_end!(builder, measure); + ok } // NOTE: This is a bit messy. The target types are defined in `common.rs` because they are used in @@ -801,609 +796,570 @@ struct CustomOperationVerifyGadget { // here. Maybe we want to move everything related to custom predicates to its own module, but then // should we add a new trait for the `add_virtual_foo` methods so that everything is contained in a // module? -impl CustomOperationVerifyGadget { - fn statement_arg_from_template( - &self, - builder: &mut CircuitBuilder, - st_tmpl_arg: &StatementTmplArgTarget, - args: &[ValueTarget], - ) -> StatementArgTarget { - let zero = builder.zero(); - let (is_literal, value_literal) = st_tmpl_arg.as_literal(builder); - let (is_ak, ak_id_wc_index, ak_key_lit_or_wc) = st_tmpl_arg.as_anchored_key(builder); - let (is_wc_literal, wc_index) = st_tmpl_arg.as_wildcard_literal(builder); - - let ((_is_ak_key_lit, ak_key_lit), (is_ak_key_wc, ak_key_wc_index)) = - ak_key_lit_or_wc.cases(builder); - - // optimization: ak_id_wc_index and wc_index use the same signals, so we only need to do one - // random access to resolve both of them - assert_eq!(ak_id_wc_index, wc_index); - // optimization: the wildcard indices have an offset of +1. This allows us to set a fixed - // SELF in args[0] to resolve SelfOrWildcard::SELF encoded as a wildcard at index 0. - let value_self = ValueTarget::from_slice(&builder.constants(&SELF.0 .0)); - let args = iter::once(value_self) - .chain(args.iter().cloned()) - .collect_vec(); - // If the index is not used, use a 0 instead to still pass the range constraints from - // vec_ref - let first_index = ak_id_wc_index; - let is_first_index_valid = builder.or(is_ak, is_wc_literal); - let first_index = builder.select(is_first_index_valid, first_index, zero); - let resolved_ak_id = builder.vec_ref(&self.params, &args, first_index); - let resolved_wc = resolved_ak_id; - - // If the index is not used, use a 0 instead to still pass the range constraints from - // vec_ref - let second_index = ak_key_wc_index; - let is_second_index_valid = builder.and(is_ak, is_ak_key_wc); - let second_index = builder.select(is_second_index_valid, second_index, zero); - let resolved_ak_key = builder.vec_ref(&self.params, &args, second_index); - - let ak_key = ak_key_lit; // is_ak_key_lit - let ak_key = - builder.select_flattenable(&self.params, is_ak_key_wc, &resolved_ak_key, &ak_key); - - let first = ValueTarget::zero(builder); // is_none - let first = builder.select_flattenable(&self.params, is_literal, &value_literal, &first); - let first = builder.select_flattenable(&self.params, is_ak, &resolved_ak_id, &first); - let first = builder.select_flattenable(&self.params, is_wc_literal, &resolved_wc, &first); - - let second = ValueTarget::zero(builder); // is_none or is_literal or is_wc_literal - let second = builder.select_flattenable(&self.params, is_ak, &ak_key, &second); - - StatementArgTarget::new(first, second) - } +fn make_statement_arg_from_template_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st_tmpl_arg: &StatementTmplArgTarget, + args: &[ValueTarget], +) -> StatementArgTarget { + let zero = builder.zero(); + let (is_literal, value_literal) = st_tmpl_arg.as_literal(builder); + let (is_ak, ak_id_wc_index, ak_key_lit_or_wc) = st_tmpl_arg.as_anchored_key(builder); + let (is_wc_literal, wc_index) = st_tmpl_arg.as_wildcard_literal(builder); + + let ((_is_ak_key_lit, ak_key_lit), (is_ak_key_wc, ak_key_wc_index)) = + ak_key_lit_or_wc.cases(builder); + + // optimization: ak_id_wc_index and wc_index use the same signals, so we only need to do one + // random access to resolve both of them + assert_eq!(ak_id_wc_index, wc_index); + // optimization: the wildcard indices have an offset of +1. This allows us to set a fixed + // SELF in args[0] to resolve SelfOrWildcard::SELF encoded as a wildcard at index 0. + let value_self = ValueTarget::from_slice(&builder.constants(&SELF.0 .0)); + let args = iter::once(value_self) + .chain(args.iter().cloned()) + .collect_vec(); + // If the index is not used, use a 0 instead to still pass the range constraints from + // vec_ref + let first_index = ak_id_wc_index; + let is_first_index_valid = builder.or(is_ak, is_wc_literal); + let first_index = builder.select(is_first_index_valid, first_index, zero); + let resolved_ak_id = builder.vec_ref(params, &args, first_index); + let resolved_wc = resolved_ak_id; + + // If the index is not used, use a 0 instead to still pass the range constraints from + // vec_ref + let second_index = ak_key_wc_index; + let is_second_index_valid = builder.and(is_ak, is_ak_key_wc); + let second_index = builder.select(is_second_index_valid, second_index, zero); + let resolved_ak_key = builder.vec_ref(params, &args, second_index); + + let ak_key = ak_key_lit; // is_ak_key_lit + let ak_key = builder.select_flattenable(params, is_ak_key_wc, &resolved_ak_key, &ak_key); + + let first = ValueTarget::zero(builder); // is_none + let first = builder.select_flattenable(params, is_literal, &value_literal, &first); + let first = builder.select_flattenable(params, is_ak, &resolved_ak_id, &first); + let first = builder.select_flattenable(params, is_wc_literal, &resolved_wc, &first); + + let second = ValueTarget::zero(builder); // is_none or is_literal or is_wc_literal + let second = builder.select_flattenable(params, is_ak, &ak_key, &second); + + StatementArgTarget::new(first, second) +} - fn statement_from_template( - &self, - builder: &mut CircuitBuilder, - st_tmpl: &StatementTmplTarget, - args: &[ValueTarget], - ) -> StatementTarget { - let measure = measure_gates_begin!(builder, "StArgFromTmpl"); - let args = st_tmpl - .args - .iter() - .map(|st_tmpl_arg| self.statement_arg_from_template(builder, st_tmpl_arg, args)) - .collect(); - measure_gates_end!(builder, measure); - StatementTarget { - predicate: st_tmpl.pred.clone(), - args, - } +fn make_statement_from_template_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st_tmpl: &StatementTmplTarget, + args: &[ValueTarget], +) -> StatementTarget { + let measure = measure_gates_begin!(builder, "StArgFromTmpl"); + let args = st_tmpl + .args + .iter() + .map(|st_tmpl_arg| { + make_statement_arg_from_template_circuit(params, builder, st_tmpl_arg, args) + }) + .collect(); + measure_gates_end!(builder, measure); + StatementTarget { + predicate: st_tmpl.pred.clone(), + args, } +} - /// Given a custom predicate, a list of operation arguments (statements) and a list of wildcard - /// values (args): - /// - Verify that the custom predicate is satisfied with the given statements - /// - Build the output statement - /// - Build the expected operation type - fn eval( - &self, - builder: &mut CircuitBuilder, - custom_predicate: &CustomPredicateEntryTarget, - op_args: &[StatementTarget], - args: &[ValueTarget], // arguments to the custom predicate, public and private - ) -> Result<(StatementTarget, OperationTypeTarget)> { - let measure = measure_gates_begin!(builder, "CustomOpVerify"); - // Some sanity checks - assert_eq!(self.params.max_operation_args, op_args.len()); - assert_eq!(self.params.max_custom_predicate_wildcards, args.len()); - - let (batch_id, index) = (custom_predicate.id, custom_predicate.index); - let op_type = OperationTypeTarget::new_custom(builder, batch_id, index); - - // Build the statement - let st_predicate = PredicateTarget::new_custom(builder, batch_id, index); - let arg_none = ValueTarget::zero(builder); - let lt_mask = builder.lt_mask( - self.params.max_statement_args, - custom_predicate.predicate.args_len, - ); - let st_args = (0..self.params.max_statement_args) - .map(|i| { - let v = builder.select_flattenable(&self.params, lt_mask[i], &args[i], &arg_none); - StatementArgTarget::wildcard_literal(builder, &v) - }) - .collect(); - let statement = StatementTarget { - predicate: st_predicate, - args: st_args, - }; - - // Check the operation arguments - // From each statement template we generate an expected statement using replacing the - // wildcards by the arguments. Then we compare the expected statement with the operation - // argument. - let expected_sts: Vec<_> = custom_predicate - .predicate - .statements - .iter() - .map(|st_tmpl| self.statement_from_template(builder, st_tmpl, args)) - .collect(); - // expected_sts.len() == self.params.max_custom_predicate_arity - // op_args.len() == self.params.max_operation_args; - assert!(self.params.max_custom_predicate_arity <= self.params.max_operation_args); +/// Given a custom predicate, a list of operation arguments (statements) and a list of wildcard +/// values (args): +/// - Verify that the custom predicate is satisfied with the given statements +/// - Build the output statement +/// - Build the expected operation type +fn make_custom_statement_circuit( + params: &Params, + builder: &mut CircuitBuilder, + custom_predicate: &CustomPredicateEntryTarget, + op_args: &[StatementTarget], + args: &[ValueTarget], // arguments to the custom predicate, public and private +) -> Result<(StatementTarget, OperationTypeTarget)> { + let measure = measure_gates_begin!(builder, "CustomOpVerify"); + // Some sanity checks + assert_eq!(params.max_operation_args, op_args.len()); + assert_eq!(params.max_custom_predicate_wildcards, args.len()); + + let (batch_id, index) = (custom_predicate.id, custom_predicate.index); + let op_type = OperationTypeTarget::new_custom(builder, batch_id, index); + + // Build the statement + let st_predicate = PredicateTarget::new_custom(builder, batch_id, index); + let arg_none = ValueTarget::zero(builder); + let lt_mask = builder.lt_mask( + params.max_statement_args, + custom_predicate.predicate.args_len, + ); + let st_args = (0..params.max_statement_args) + .map(|i| { + let v = builder.select_flattenable(params, lt_mask[i], &args[i], &arg_none); + StatementArgTarget::wildcard_literal(builder, &v) + }) + .collect(); + let statement = StatementTarget { + predicate: st_predicate, + args: st_args, + }; - let sts_eq: Vec<_> = expected_sts - .iter() - .zip(op_args.iter()) - .map(|(expected_st, st)| builder.is_equal_flattenable(expected_st, st)) - .collect(); - let all_st_eq = builder.all(sts_eq.clone()); - let some_st_eq = builder.any(sts_eq); - // NOTE: This BoolTarget is safe because both inputs to the select are safe - let is_op_args_ok = BoolTarget::new_unsafe(builder.select( - custom_predicate.predicate.conjunction, - all_st_eq.target, - some_st_eq.target, - )); - - builder.assert_one(is_op_args_ok.target); - measure_gates_end!(builder, measure); - Ok((statement, op_type)) - } + // Check the operation arguments + // From each statement template we generate an expected statement using replacing the + // wildcards by the arguments. Then we compare the expected statement with the operation + // argument. + let expected_sts: Vec<_> = custom_predicate + .predicate + .statements + .iter() + .map(|st_tmpl| make_statement_from_template_circuit(params, builder, st_tmpl, args)) + .collect(); + // expected_sts.len() == params.max_custom_predicate_arity + // op_args.len() == params.max_operation_args; + assert!(params.max_custom_predicate_arity <= params.max_operation_args); + + let sts_eq: Vec<_> = expected_sts + .iter() + .zip(op_args.iter()) + .map(|(expected_st, st)| builder.is_equal_flattenable(expected_st, st)) + .collect(); + let all_st_eq = builder.all(sts_eq.clone()); + let some_st_eq = builder.any(sts_eq); + // NOTE: This BoolTarget is safe because both inputs to the select are safe + let is_op_args_ok = BoolTarget::new_unsafe(builder.select( + custom_predicate.predicate.conjunction, + all_st_eq.target, + some_st_eq.target, + )); + + builder.assert_one(is_op_args_ok.target); + measure_gates_end!(builder, measure); + Ok((statement, op_type)) } /// Replace references to SELF by `self_id` in a statement. -struct NormalizeStatementGadget { - params: Params, +fn normalize_statement_circuit( + params: &Params, + builder: &mut CircuitBuilder, + statement: &StatementTarget, + self_id: &ValueTarget, +) -> StatementTarget { + let zero_value = builder.constant_value(EMPTY_VALUE); + let self_value = builder.constant_value(SELF.0.into()); + let args = statement + .args + .iter() + .map(|arg| { + let first = ValueTarget::from_slice(&arg.elements[..VALUE_SIZE]); + let second = ValueTarget::from_slice(&arg.elements[VALUE_SIZE..]); + let is_not_ak = builder.is_equal_flattenable(&zero_value, &second); + let is_ak = builder.not(is_not_ak); + let is_self = builder.is_equal_flattenable(&self_value, &first); + let normalize = builder.and(is_ak, is_self); + let first_normalized = builder.select_flattenable(params, normalize, self_id, &first); + StatementArgTarget::new(first_normalized, second) + }) + .collect_vec(); + StatementTarget { + predicate: statement.predicate.clone(), + args, + } } -impl NormalizeStatementGadget { - fn eval( - &self, - builder: &mut CircuitBuilder, - statement: &StatementTarget, - self_id: &ValueTarget, - ) -> StatementTarget { - let zero_value = builder.constant_value(EMPTY_VALUE); - let self_value = builder.constant_value(SELF.0.into()); - let args = statement - .args - .iter() - .map(|arg| { - let first = ValueTarget::from_slice(&arg.elements[..VALUE_SIZE]); - let second = ValueTarget::from_slice(&arg.elements[VALUE_SIZE..]); - let is_not_ak = builder.is_equal_flattenable(&zero_value, &second); - let is_ak = builder.not(is_not_ak); - let is_self = builder.is_equal_flattenable(&self_value, &first); - let normalize = builder.and(is_ak, is_self); - let first_normalized = - builder.select_flattenable(&self.params, normalize, self_id, &first); - StatementArgTarget::new(first_normalized, second) - }) - .collect_vec(); - StatementTarget { - predicate: statement.predicate.clone(), - args, - } +/// Precompute the hash state by absorbing all full chunks from `inputs` and return the reminder +/// elements that didn't fit into a chunk. +fn precompute_hash_state>(inputs: &[F]) -> (P, &[F]) { + let (inputs, inputs_rem) = inputs.split_at((inputs.len() / P::RATE) * P::RATE); + let mut perm = P::new(core::iter::repeat(F::ZERO)); + + // Absorb all inputs up to the biggest multiple of RATE. + for input_chunk in inputs.chunks(P::RATE) { + perm.set_from_slice(input_chunk, 0); + perm.permute(); } -} -pub struct CalculateIdGadget { - /// `params.num_public_statements_id` is the total number of statements that will be hashed. - /// The id is calculated with front-padded none-statements and then the input statements - /// reversed. The part of the hash from the front-padded none-statements is precomputed. - pub params: Params, + (perm, inputs_rem) } -impl CalculateIdGadget { - /// Precompute the hash state by absorbing all full chunks from `inputs` and return the reminder - /// elements that didn't fit into a chunk. - fn precompute_hash_state>(inputs: &[F]) -> (P, &[F]) { - let (inputs, inputs_rem) = inputs.split_at((inputs.len() / P::RATE) * P::RATE); - let mut perm = P::new(core::iter::repeat(F::ZERO)); - - // Absorb all inputs up to the biggest multiple of RATE. - for input_chunk in inputs.chunks(P::RATE) { - perm.set_from_slice(input_chunk, 0); - perm.permute(); - } - - (perm, inputs_rem) +/// Hash `inputs` starting from a circuit-constant `perm` state. +fn hash_from_state_circuit, P: PlonkyPermutation>( + builder: &mut CircuitBuilder, + perm: P, + inputs: &[Target], +) -> HashOutTarget { + let mut state = + H::AlgebraicPermutation::new(perm.as_ref().iter().map(|v| builder.constant(*v))); + + // Absorb all input chunks. + for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) { + // Overwrite the first r elements with the inputs. This differs from a standard sponge, + // where we would xor or add in the inputs. This is a well-known variant, though, + // sometimes called "overwrite mode". + state.set_from_slice(input_chunk, 0); + state = builder.permute::(state); } - /// Hash `inputs` starting from a circuit-constant `perm` state. - fn hash_from_state, P: PlonkyPermutation>( - builder: &mut CircuitBuilder, - perm: P, - inputs: &[Target], - ) -> HashOutTarget { - let mut state = - H::AlgebraicPermutation::new(perm.as_ref().iter().map(|v| builder.constant(*v))); - - // Absorb all input chunks. - for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) { - // Overwrite the first r elements with the inputs. This differs from a standard sponge, - // where we would xor or add in the inputs. This is a well-known variant, though, - // sometimes called "overwrite mode". - state.set_from_slice(input_chunk, 0); - state = builder.permute::(state); - } - - let num_outputs = NUM_HASH_OUT_ELTS; - // Squeeze until we have the desired number of outputs. - let mut outputs = Vec::with_capacity(num_outputs); - loop { - for &s in state.squeeze() { - outputs.push(s); - if outputs.len() == num_outputs { - return HashOutTarget::from_vec(outputs); - } + let num_outputs = NUM_HASH_OUT_ELTS; + // Squeeze until we have the desired number of outputs. + let mut outputs = Vec::with_capacity(num_outputs); + loop { + for &s in state.squeeze() { + outputs.push(s); + if outputs.len() == num_outputs { + return HashOutTarget::from_vec(outputs); } - state = builder.permute::(state); } - } - - pub fn eval( - &self, - builder: &mut CircuitBuilder, - // These statements will be padded to reach `self.num_statements` - statements: &[StatementTarget], - ) -> HashOutTarget { - assert!(statements.len() <= self.params.num_public_statements_id); - let measure = measure_gates_begin!(builder, "CalculateId"); - let statements_rev_flattened = statements.iter().rev().flat_map(|s| s.flatten()); - let mut none_st = mainpod::Statement::from(Statement::None); - pad_statement(&self.params, &mut none_st); - let front_pad_elts = iter::repeat(&none_st) - .take(self.params.num_public_statements_id - statements.len()) - .flat_map(|s| s.to_fields(&self.params)) - .collect_vec(); - let (perm, front_pad_elts_rem) = - Self::precompute_hash_state::>(&front_pad_elts); - - // Precompute the Poseidon state for the initial padding chunks - let inputs = front_pad_elts_rem - .iter() - .map(|v| builder.constant(*v)) - .chain(statements_rev_flattened) - .collect_vec(); - let id = - Self::hash_from_state::>(builder, perm, &inputs); - - measure_gates_end!(builder, measure); - id + state = builder.permute::(state); } } -struct MainPodVerifyGadget { - params: Params, +/// `params.num_public_statements_id` is the total number of statements that will be hashed. +/// The id is calculated with front-padded none-statements and then the input statements +/// reversed. The part of the hash from the front-padded none-statements is precomputed. +pub fn calculate_id_circuit( + params: &Params, + builder: &mut CircuitBuilder, + // These statements will be padded to reach `num_statements` + statements: &[StatementTarget], +) -> HashOutTarget { + assert!(statements.len() <= params.num_public_statements_id); + let measure = measure_gates_begin!(builder, "CalculateId"); + let statements_rev_flattened = statements.iter().rev().flat_map(|s| s.flatten()); + let mut none_st = mainpod::Statement::from(Statement::None); + pad_statement(params, &mut none_st); + let front_pad_elts = iter::repeat(&none_st) + .take(params.num_public_statements_id - statements.len()) + .flat_map(|s| s.to_fields(params)) + .collect_vec(); + let (perm, front_pad_elts_rem) = + precompute_hash_state::>(&front_pad_elts); + + // Precompute the Poseidon state for the initial padding chunks + let inputs = front_pad_elts_rem + .iter() + .map(|v| builder.constant(*v)) + .chain(statements_rev_flattened) + .collect_vec(); + let id = + hash_from_state_circuit::>(builder, perm, &inputs); + + measure_gates_end!(builder, measure); + id } -impl MainPodVerifyGadget { - // Replace predicates of batch-self with the corresponding global custom predicate batch_id and - // index - fn normalize_st_tmpl( - &self, - builder: &mut CircuitBuilder, - st_tmpl: &StatementTmplTarget, - id: HashOutTarget, - ) -> StatementTmplTarget { - let params = &self.params; - let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf)); - let is_batch_self = builder.is_equal(st_tmpl.pred.elements[0], prefix_batch_self); - let pred_index = st_tmpl.pred.elements[1]; - let custom_pred = PredicateTarget::new_custom(builder, id, pred_index); - let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, &st_tmpl.pred); - StatementTmplTarget { - pred, - args: st_tmpl.args.clone(), - } - } - /// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as - /// hash([batch_id, custom_predicate_index, custom_predicate]). While building the table we - /// calculate the id of each batch. - fn build_custom_predicate_table( - &self, - builder: &mut CircuitBuilder, - ) -> Result<(Vec, Vec)> { - let measure = measure_gates_begin!(builder, "BuildCustomPredTbl"); - let params = &self.params; - let mut custom_predicate_table = - Vec::with_capacity(params.max_custom_predicate_batches * params.max_custom_batch_size); - let mut custom_predicate_batches = Vec::with_capacity(params.max_custom_predicate_batches); - for _ in 0..params.max_custom_predicate_batches { - let cpb = builder.add_virtual_custom_predicate_batch(params); - let id = cpb.id(builder); // constrain the id - for (index, cp) in cpb.predicates.iter().enumerate() { - let statements = cp - .statements - .iter() - .map(|st_tmpl| self.normalize_st_tmpl(builder, st_tmpl, id)) - .collect_vec(); - let cp = CustomPredicateTarget { - conjunction: cp.conjunction, - statements, - args_len: cp.args_len, - }; - let entry = CustomPredicateEntryTarget { - id, // output - index: builder.constant(F::from_canonical_usize(index)), // constant - predicate: cp.clone(), // input - }; - - let in_query_hash = entry.hash(builder); - custom_predicate_table.push(in_query_hash); - } - custom_predicate_batches.push(cpb); // We keep this for witness assignment - } - measure_gates_end!(builder, measure); - Ok((custom_predicate_table, custom_predicate_batches)) +// Replace predicates of batch-self with the corresponding global custom predicate batch_id and +// index +fn normalize_st_tmpl_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st_tmpl: &StatementTmplTarget, + id: HashOutTarget, +) -> StatementTmplTarget { + let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf)); + let is_batch_self = builder.is_equal(st_tmpl.pred.elements[0], prefix_batch_self); + let pred_index = st_tmpl.pred.elements[1]; + let custom_pred = PredicateTarget::new_custom(builder, id, pred_index); + let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, &st_tmpl.pred); + StatementTmplTarget { + pred, + args: st_tmpl.args.clone(), } +} - /// Build table of [batch_id, custom_predicate_index, custom_predicate, args, st, op, op_args] - /// with queryable part as hash([st, op, op_args]). While building the table we verify each - /// custom predicate against the operation and statement. - fn build_custom_predicate_verification_table( - &self, - builder: &mut CircuitBuilder, - custom_predicate_table: &[HashOutTarget], - ) -> Result<(Vec, Vec)> { - let measure = measure_gates_begin!(builder, "BuildCustomPredVerifyTbl"); - let params = &self.params; - let mut custom_predicate_verifications = - Vec::with_capacity(params.max_custom_predicate_verifications); - let mut custom_predicate_verification_table = - Vec::with_capacity(params.max_custom_predicate_verifications); - for _ in 0..params.max_custom_predicate_verifications { - let custom_predicate_table_index = builder.add_virtual_target(); - let custom_predicate = builder.add_virtual_custom_predicate_entry(params); - let args = (0..params.max_custom_predicate_wildcards) - .map(|_| builder.add_virtual_value()) - .collect_vec(); - let op_args = (0..params.max_operation_args) - .map(|_| builder.add_virtual_statement(params)) +/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as +/// hash([batch_id, custom_predicate_index, custom_predicate]). While building the table we +/// calculate the id of each batch. Return the hash of each table entry. +fn build_custom_predicate_table_circuit( + params: &Params, + builder: &mut CircuitBuilder, + custom_predicate_batches: &[CustomPredicateBatchTarget], +) -> Result> { + let measure = measure_gates_begin!(builder, "BuildCustomPredTbl"); + let mut custom_predicate_table = + Vec::with_capacity(params.max_custom_predicate_batches * params.max_custom_batch_size); + for cpb in custom_predicate_batches { + let id = cpb.id(builder); // constrain the id + for (index, cp) in cpb.predicates.iter().enumerate() { + let statements = cp + .statements + .iter() + .map(|st_tmpl| normalize_st_tmpl_circuit(params, builder, st_tmpl, id)) .collect_vec(); - - // Verify the custom predicate operation - let (statement, op_type) = CustomOperationVerifyGadget { - params: params.clone(), - } - .eval(builder, &custom_predicate, &op_args, &args)?; - - // Check that the batch id is correct by querying the custom predicate batches table - let table_query_hash = - builder.vec_ref(params, custom_predicate_table, custom_predicate_table_index); - let out_query_hash = custom_predicate.hash(builder); - builder.connect_array(table_query_hash.elements, out_query_hash.elements); - - let entry = CustomPredicateVerifyEntryTarget { - custom_predicate_table_index, // input - custom_predicate, // input - args, // input - query: CustomPredicateVerifyQueryTarget { - statement, // output - op_type, // output - op_args, // input - }, + let cp = CustomPredicateTarget { + conjunction: cp.conjunction, + statements, + args_len: cp.args_len, + }; + let entry = CustomPredicateEntryTarget { + id, // output + index: builder.constant(F::from_canonical_usize(index)), // constant + predicate: cp.clone(), // input }; - let in_query_hash = entry.query.hash(builder); - custom_predicate_verification_table.push(in_query_hash); - custom_predicate_verifications.push(entry); // We keep this for witness assignment + + let in_query_hash = entry.hash(builder); + custom_predicate_table.push(in_query_hash); } - measure_gates_end!(builder, measure); - Ok(( - custom_predicate_verification_table, - custom_predicate_verifications, - )) } + measure_gates_end!(builder, measure); + Ok(custom_predicate_table) +} - fn eval( - &self, - builder: &mut CircuitBuilder, - verified_proofs: &[VerifiedProofTarget], - ) -> Result { - assert_eq!(self.params.max_input_recursive_pods, verified_proofs.len()); - - let measure = measure_gates_begin!(builder, "MainPodVerify"); - let params = &self.params; - // 1a. Verify all input signed pods - let mut signed_pods = Vec::new(); - for _ in 0..params.max_input_signed_pods { - let signed_pod = SignedPodVerifyGadget { - params: params.clone(), - } - .eval(builder)?; - builder.assert_one(signed_pod.signature.enabled.target); - signed_pods.push(signed_pod); - } +/// Build table of [batch_id, custom_predicate_index, custom_predicate, args, st, op, op_args] +/// with queryable part as hash([st, op, op_args]). While building the table we verify each +/// custom predicate against the operation and statement. Return the hash of each table "query" +/// sub-entry. +fn build_custom_predicate_verification_table_circuit( + params: &Params, + builder: &mut CircuitBuilder, + custom_predicate_table: &[HashOutTarget], + custom_predicate_verifications: &[CustomPredicateVerifyEntryTarget], +) -> Result> { + let measure = measure_gates_begin!(builder, "BuildCustomPredVerifyTbl"); + let mut custom_predicate_verification_table = + Vec::with_capacity(params.max_custom_predicate_verifications); + for entry in custom_predicate_verifications { + // Verify the custom predicate operation + let (statement, op_type) = make_custom_statement_circuit( + params, + builder, + &entry.custom_predicate, + &entry.op_args, + &entry.args, + )?; - // Build the statement array - let mut statements = Vec::new(); - // Statement at index 0 is always None to be used for padding operation arguments in custom - // predicate statements - let st_none = - StatementTarget::new_native(builder, &self.params, NativePredicate::None, &[]); - statements.push(st_none); - for signed_pod in &signed_pods { - statements.extend_from_slice(signed_pod.pub_statements(builder, false).as_slice()); - } - debug_assert_eq!( - statements.len(), - 1 + self.params.max_input_signed_pods * self.params.max_signed_pod_values + // Check that the batch id is correct by querying the custom predicate batches table + let table_query_hash = builder.vec_ref( + params, + custom_predicate_table, + entry.custom_predicate_table_index, ); + let out_query_hash = entry.custom_predicate.hash(builder); + builder.connect_array(table_query_hash.elements, out_query_hash.elements); - // 1b. Verify all input recursive pods - let id_gadget = CalculateIdGadget { - params: params.clone(), - }; - let normalize_statement_gadget = NormalizeStatementGadget { - params: self.params.clone(), + let query = CustomPredicateVerifyQueryTarget { + statement, // output + op_type, // output + op_args: entry.op_args.clone(), // input }; - let vds_root = builder.add_virtual_hash(); - let mut input_pods_self_statements: Vec> = Vec::new(); - let mut vd_mt_proofs: Vec = Vec::new(); - - for verified_proof in verified_proofs { - let measure_in_pod = measure_gates_begin!(builder, "VerifyInPod"); - - // - // Verify id from the statements - // - let expected_id = HashOutTarget::try_from( - &verified_proof.public_inputs[PI_OFFSET_ID..PI_OFFSET_ID + HASH_SIZE], - ) - .expect("4 elements"); - let id_value = ValueTarget { - elements: expected_id.elements, - }; - - let mut input_pod_self_statements = Vec::new(); - for _ in 0..self.params.max_input_pods_public_statements { - let self_st = builder.add_virtual_statement(params); - let normalized_st = normalize_statement_gadget.eval(builder, &self_st, &id_value); - input_pod_self_statements.push(self_st); - statements.push(normalized_st); - } - let id = id_gadget.eval(builder, &input_pod_self_statements); - builder.connect_hashes(expected_id, id); - input_pods_self_statements.push(input_pod_self_statements); - - // - // Verify that all input pod proofs use verifier data from the public input VD - // array. This requires merkle proofs - // - - // add target for the vd_mt_proof - let vd_mt_proof = MerkleProofGadget { - max_depth: params.max_depth_mt_vds, - } - .eval(builder); - - // ensure that mt_proof is enabled - let true_targ = builder._true(); - builder.connect(vd_mt_proof.enabled.target, true_targ.target); - // connect the vd_mt_proof's root to the actual vds_root, to ensure that the mt proof - // verifies against the vds_root - builder.connect_hashes(vds_root, vd_mt_proof.root); - // connect vd_mt_proof's value with the verified_proof.verifier_data_hash - builder.connect_hashes( - verified_proof.verifier_data_hash, - HashOutTarget::from_vec(vd_mt_proof.value.elements.to_vec()), - ); - vd_mt_proofs.push(vd_mt_proof); - - // - // Verify that VD array that input pod uses is the same we use now. - // - let verified_proof_vds_root = HashOutTarget::try_from( - &verified_proof.public_inputs[PI_OFFSET_VDSROOT..PI_OFFSET_VDSROOT + HASH_SIZE], - ) - .expect("4 elements"); - builder.connect_hashes(vds_root, verified_proof_vds_root); - - measure_gates_end!(builder, measure_in_pod); - } + let in_query_hash = query.hash(builder); + custom_predicate_verification_table.push(in_query_hash); + } + measure_gates_end!(builder, measure); + Ok(custom_predicate_verification_table) +} - // Add the input (private and public) statements and corresponding operations - let mut operations = Vec::new(); - let input_statements_offset = statements.len(); - for _ in 0..params.max_statements { - statements.push(builder.add_virtual_statement(params)); - operations.push(builder.add_virtual_operation(params)); - } +fn verify_main_pod_circuit( + builder: &mut CircuitBuilder, + main_pod: &MainPodVerifyTarget, + verified_proofs: &[VerifiedProofTarget], +) -> Result { + let params = &main_pod.params; + assert_eq!(params.max_input_recursive_pods, verified_proofs.len()); + + let measure = measure_gates_begin!(builder, "MainPodVerify"); + // 1a. Verify all input signed pods + for signed_pod in &main_pod.signed_pods { + verify_signed_pod_circuit(builder, signed_pod)?; + builder.assert_one(signed_pod.signature.enabled.target); + } - let input_statements = &statements[input_statements_offset..]; - let pub_statements = - &input_statements[input_statements.len() - params.max_public_statements..]; + // Build the statement array + let mut statements = Vec::new(); + // Statement at index 0 is always None to be used for padding operation arguments in custom + // predicate statements + let st_none = StatementTarget::new_native(builder, params, NativePredicate::None, &[]); + statements.push(st_none); + for signed_pod in &main_pod.signed_pods { + statements.extend_from_slice(signed_pod.pub_statements(builder, false).as_slice()); + } + debug_assert_eq!( + statements.len(), + 1 + params.max_input_signed_pods * params.max_signed_pod_values + ); - // Add Merkle claim/proof targets - let mp_gadget = MerkleProofGadget { - max_depth: params.max_depth_mt_containers, + // 1b. Verify all input recursive pods + for (verified_proof, vd_mt_proof, input_pod_self_statements) in izip!( + verified_proofs, + &main_pod.vd_mt_proofs, + &main_pod.input_pods_self_statements + ) { + let measure_in_pod = measure_gates_begin!(builder, "VerifyInPod"); + + // + // Verify id from the statements + // + let expected_id = HashOutTarget::try_from( + &verified_proof.public_inputs[PI_OFFSET_ID..PI_OFFSET_ID + HASH_SIZE], + ) + .expect("4 elements"); + let id_value = ValueTarget { + elements: expected_id.elements, }; - let merkle_proofs: Vec<_> = (0..params.max_merkle_proofs_containers) - .map(|_| mp_gadget.eval(builder)) - .collect(); - let merkle_claims: Vec<_> = merkle_proofs - .clone() - .into_iter() - .map(|pf| pf.into()) - .collect(); - - // Table of custom predicate batches with batch_id calculation - let (custom_predicate_table, custom_predicate_batches) = - self.build_custom_predicate_table(builder)?; - // Table of custom predicate statements verification against operations - let (custom_predicate_verification_table, custom_predicate_verifications) = - self.build_custom_predicate_verification_table(builder, &custom_predicate_table)?; - - // 2. Calculate the Pod Id from the public statements - let id = CalculateIdGadget { - params: self.params.clone(), + for self_st in input_pod_self_statements { + let normalized_st = normalize_statement_circuit(params, builder, self_st, &id_value); + statements.push(normalized_st); } - .eval(builder, pub_statements); - - // 4. Verify type - let type_statement = &pub_statements[0]; - // TODO: Store this hash in a global static with lazy init so that we don't have to - // compute it every time. - let expected_type_statement = StatementTarget::from_flattened( - &self.params, - &builder.constants( - &Statement::equal( - ValueRef::Key(AnchoredKey::from((SELF, KEY_TYPE))), - ValueRef::Literal(Value::from(PodType::Main)), - ) - .to_fields(params), - ), + let id = calculate_id_circuit(params, builder, input_pod_self_statements); + builder.connect_hashes(expected_id, id); + + // + // Verify that all input pod proofs use verifier data from the public input VD + // array. This requires merkle proofs + // + + verify_merkle_proof_circuit(builder, vd_mt_proof); + + // ensure that mt_proof is enabled + let true_targ = builder._true(); + builder.connect(vd_mt_proof.enabled.target, true_targ.target); + // connect the vd_mt_proof's root to the actual vds_root, to ensure that the mt proof + // verifies against the vds_root + builder.connect_hashes(main_pod.vds_root, vd_mt_proof.root); + // connect vd_mt_proof's value with the verified_proof.verifier_data_hash + builder.connect_hashes( + verified_proof.verifier_data_hash, + HashOutTarget::from_vec(vd_mt_proof.value.elements.to_vec()), ); - builder.connect_flattenable(type_statement, &expected_type_statement); - - // 3. check that all `input_statements` of type `ValueOf` with origin=SELF have unique keys - // (no duplicates). We do this in the verification of NewEntry operation. - // 5. Verify input statements - for (i, (st, op)) in input_statements.iter().zip(operations.iter()).enumerate() { - let prev_statements = &statements[..input_statements_offset + i]; - OperationVerifyGadget { - params: params.clone(), - } - .eval( - builder, - st, - op, - prev_statements, - input_statements_offset, - &merkle_claims, - &custom_predicate_verification_table, - )?; - } - measure_gates_end!(builder, measure); - Ok(MainPodVerifyTarget { - params: params.clone(), - vds_root, - vd_mt_proofs, - id, - signed_pods, - input_pods_self_statements, - statements: input_statements.to_vec(), - operations, - merkle_proofs, - custom_predicate_batches, - custom_predicate_verifications, + // + // Verify that VD array that input pod uses is the same we use now. + // + let verified_proof_vds_root = HashOutTarget::try_from( + &verified_proof.public_inputs[PI_OFFSET_VDSROOT..PI_OFFSET_VDSROOT + HASH_SIZE], + ) + .expect("4 elements"); + builder.connect_hashes(main_pod.vds_root, verified_proof_vds_root); + + measure_gates_end!(builder, measure_in_pod); + } + + let input_statements_offset = statements.len(); + // Add the input (private and public) statements + for statement in &main_pod.input_statements { + statements.push(statement.clone()); + } + let pub_statements = &main_pod.input_statements + [main_pod.input_statements.len() - params.max_public_statements..]; + + // Verify Merkle claim/proof targets + let merkle_claims = main_pod + .merkle_proofs + .iter() + .map(|mt_proof| { + verify_merkle_proof_circuit(builder, mt_proof); + MerkleClaimTarget::from(mt_proof.clone()) }) + .collect_vec(); + + // Table of custom predicate batches with batch_id calculation + let custom_predicate_table = + build_custom_predicate_table_circuit(params, builder, &main_pod.custom_predicate_batches)?; + + // Table of custom predicate statements verification against operations + let custom_predicate_verification_table = build_custom_predicate_verification_table_circuit( + params, + builder, + &custom_predicate_table, + &main_pod.custom_predicate_verifications, + )?; + + // 2. Calculate the Pod Id from the public statements + let id = calculate_id_circuit(params, builder, pub_statements); + + // 4. Verify type + let type_statement = &pub_statements[0]; + // TODO: Store this hash in a global static with lazy init so that we don't have to + // compute it every time. + let expected_type_statement = StatementTarget::from_flattened( + params, + &builder.constants( + &Statement::equal( + ValueRef::Key(AnchoredKey::from((SELF, KEY_TYPE))), + ValueRef::Literal(Value::from(PodType::Main)), + ) + .to_fields(params), + ), + ); + builder.connect_flattenable(type_statement, &expected_type_statement); + + // 3. check that all `input_statements` of type `ValueOf` with origin=SELF have unique keys + // (no duplicates). We do this in the verification of NewEntry operation. + // 5. Verify input statements + for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() { + let prev_statements = &statements[..input_statements_offset + i]; + verify_operation_circuit( + params, + builder, + st, + op, + prev_statements, + input_statements_offset, + &merkle_claims, + &custom_predicate_verification_table, + )?; } + + measure_gates_end!(builder, measure); + Ok(id) } pub struct MainPodVerifyTarget { params: Params, vds_root: HashOutTarget, vd_mt_proofs: Vec, - id: HashOutTarget, signed_pods: Vec, input_pods_self_statements: Vec>, // The KEY_TYPE statement must be the first public one - statements: Vec, + input_statements: Vec, operations: Vec, merkle_proofs: Vec, custom_predicate_batches: Vec, custom_predicate_verifications: Vec, } +impl MainPodVerifyTarget { + pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { + MainPodVerifyTarget { + params: params.clone(), + vds_root: builder.add_virtual_hash(), + vd_mt_proofs: (0..params.max_input_recursive_pods) + .map(|_| MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_vds, builder)) + .collect(), + signed_pods: (0..params.max_input_signed_pods) + .map(|_| SignedPodVerifyTarget::new_virtual(params, builder)) + .collect(), + input_pods_self_statements: (0..params.max_input_recursive_pods) + .map(|_| { + (0..params.max_input_pods_public_statements) + .map(|_| builder.add_virtual_statement(params)) + .collect_vec() + }) + .collect(), + input_statements: (0..params.max_statements) + .map(|_| builder.add_virtual_statement(params)) + .collect(), + operations: (0..params.max_statements) + .map(|_| builder.add_virtual_operation(params)) + .collect(), + merkle_proofs: (0..params.max_merkle_proofs_containers) + .map(|_| { + MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_containers, builder) + }) + .collect(), + custom_predicate_batches: (0..params.max_custom_predicate_batches) + .map(|_| builder.add_virtual_custom_predicate_batch(params)) + .collect(), + custom_predicate_verifications: (0..params.max_custom_predicate_verifications) + .map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder)) + .collect(), + } + } +} + pub struct CustomPredicateVerification { pub custom_predicate_table_index: usize, pub custom_predicate: CustomPredicateRef, @@ -1450,27 +1406,6 @@ fn set_targets_input_pods_self_statements( Ok(()) } -pub struct MainPodVerifyCircuit { - pub params: Params, -} - -// TODO: Remove this type and implement it's logic directly in `impl InnerCircuit for MainPodVerifyTarget` -impl MainPodVerifyCircuit { - pub fn eval( - &self, - builder: &mut CircuitBuilder, - verified_proofs: &[VerifiedProofTarget], - ) -> Result { - let main_pod = MainPodVerifyGadget { - params: self.params.clone(), - } - .eval(builder, verified_proofs)?; - builder.register_public_inputs(&main_pod.id.elements); - builder.register_public_inputs(&main_pod.vds_root.elements); - Ok(main_pod) - } -} - impl InnerCircuit for MainPodVerifyTarget { type Input = MainPodVerifyInput; type Params = Params; @@ -1480,10 +1415,11 @@ impl InnerCircuit for MainPodVerifyTarget { params: &Self::Params, verified_proofs: &[VerifiedProofTarget], ) -> Result { - MainPodVerifyCircuit { - params: params.clone(), - } - .eval(builder, verified_proofs) + let main_pod = MainPodVerifyTarget::new_virtual(params, builder); + let id = verify_main_pod_circuit(builder, &main_pod, verified_proofs)?; + builder.register_public_inputs(&id.elements); + builder.register_public_inputs(&main_pod.vds_root.elements); + Ok(main_pod) } /// assigns the values to the targets @@ -1544,7 +1480,7 @@ impl InnerCircuit for MainPodVerifyTarget { assert_eq!(input.statements.len(), self.params.max_statements); for (i, (st, op)) in zip_eq(&input.statements, &input.operations).enumerate() { - self.statements[i].set_targets(pw, &self.params, st)?; + self.input_statements[i].set_targets(pw, &self.params, st)?; self.operations[i].set_targets(pw, &self.params, op)?; } @@ -1638,9 +1574,6 @@ mod tests { max_custom_predicate_verifications: 0, ..Default::default() }; - let mp_gadget = MerkleProofGadget { - max_depth: params.max_depth_mt_containers, - }; let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); @@ -1652,7 +1585,14 @@ mod tests { .collect(); let merkle_proofs_target: Vec<_> = merkle_proofs .iter() - .map(|_| mp_gadget.eval(&mut builder)) + .map(|_| { + let mt_proof = MerkleClaimAndProofTarget::new_virtual( + params.max_depth_mt_containers, + &mut builder, + ); + verify_merkle_proof_circuit(&mut builder, &mt_proof); + mt_proof + }) .collect(); let merkle_claims_target: Vec<_> = merkle_proofs_target .clone() @@ -1661,10 +1601,8 @@ mod tests { .collect(); let custom_predicate_verification_table = vec![]; - OperationVerifyGadget { - params: params.clone(), - } - .eval( + verify_operation_circuit( + ¶ms, &mut builder, &st_target, &op_target, @@ -2517,16 +2455,17 @@ mod tests { ) -> Result<()> { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let gadget = CustomOperationVerifyGadget { - params: params.clone(), - }; let st_tmpl_arg_target = builder.add_virtual_statement_tmpl_arg(); let args_target: Vec<_> = (0..args.len()) .map(|_| builder.add_virtual_value()) .collect(); - let st_arg_target = - gadget.statement_arg_from_template(&mut builder, &st_tmpl_arg_target, &args_target); + let st_arg_target = make_statement_arg_from_template_circuit( + params, + &mut builder, + &st_tmpl_arg_target, + &args_target, + ); // TODO: Instead of connect, assign witness to result let expected_st_arg_target = builder.add_virtual_statement_arg(); builder.connect_array(expected_st_arg_target.elements, st_arg_target.elements); @@ -2589,15 +2528,17 @@ mod tests { ) -> Result<()> { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let gadget = CustomOperationVerifyGadget { - params: params.clone(), - }; let st_tmpl_target = builder.add_virtual_statement_tmpl(params); let args_target: Vec<_> = (0..args.len()) .map(|_| builder.add_virtual_value()) .collect(); - let st_target = gadget.statement_from_template(&mut builder, &st_tmpl_target, &args_target); + let st_target = make_statement_from_template_circuit( + params, + &mut builder, + &st_tmpl_target, + &args_target, + ); // TODO: Instead of connect, assign witness to result let expected_st_target = builder.add_virtual_statement(params); builder.connect_flattenable(&expected_st_target, &st_target); @@ -2650,9 +2591,6 @@ mod tests { ) -> Result<()> { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let gadget = CustomOperationVerifyGadget { - params: params.clone(), - }; let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params); let op_args_target: Vec<_> = (0..args.len()) @@ -2661,7 +2599,8 @@ mod tests { let args_target: Vec<_> = (0..args.len()) .map(|_| builder.add_virtual_value()) .collect(); - let (st_target, op_type_target) = gadget.eval( + let (st_target, op_type_target) = make_custom_statement_circuit( + params, &mut builder, &custom_predicate_target, &op_args_target, @@ -2966,14 +2905,11 @@ mod tests { fn helper_calculate_id(params: &Params, statements: &[Statement]) -> Result<()> { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let gadget = CalculateIdGadget { - params: params.clone(), - }; let statements_target = (0..params.max_public_statements) .map(|_| builder.add_virtual_statement(params)) .collect_vec(); - let id_target = gadget.eval(&mut builder, &statements_target); + let id_target = calculate_id_circuit(params, &mut builder, &statements_target); let mut pw = PartialWitness::::new(); diff --git a/src/backends/plonky2/circuits/signedpod.rs b/src/backends/plonky2/circuits/signedpod.rs index ee5a1e2d..065f884b 100644 --- a/src/backends/plonky2/circuits/signedpod.rs +++ b/src/backends/plonky2/circuits/signedpod.rs @@ -16,9 +16,10 @@ use crate::{ error::Result, primitives::{ merkletree::{ - MerkleClaimAndProof, MerkleProofExistenceGadget, MerkleProofExistenceTarget, + verify_merkle_proof_existence_circuit, MerkleClaimAndProof, + MerkleProofExistenceTarget, }, - signature::{SignatureVerifyGadget, SignatureVerifyTarget}, + signature::{verify_signature_circuit, SignatureVerifyTarget}, }, signedpod::SignedPod, }, @@ -29,53 +30,45 @@ use crate::{ }, }; -pub struct SignedPodVerifyGadget { - pub params: Params, -} - -impl SignedPodVerifyGadget { - pub fn eval(&self, builder: &mut CircuitBuilder) -> Result { - let measure = measure_gates_begin!(builder, "SignedPodVerify"); - // 1. Verify id - let id = builder.add_virtual_hash(); - let mut mt_proofs = Vec::new(); - for _ in 0..self.params.max_signed_pod_values { - let mt_proof = MerkleProofExistenceGadget { - max_depth: self.params.max_depth_mt_containers, - } - .eval(builder)?; - builder.connect_hashes(id, mt_proof.root); - mt_proofs.push(mt_proof); - } - - // 2. Verify type - let type_mt_proof = &mt_proofs[0]; - let key_type = builder.constant_value(hash_str(KEY_TYPE).into()); - builder.connect_values(type_mt_proof.key, key_type); - let value_type = builder.constant_value(Value::from(PodType::Signed).raw()); - builder.connect_values(type_mt_proof.value, value_type); - - // 3.a. Verify signature - let signature = SignatureVerifyGadget {}.eval(builder)?; - - // 3.b. Verify signer (ie. hash(signature.pk) == merkletree.signer_leaf) - let signer_mt_proof = &mt_proofs[1]; - let key_signer = builder.constant_value(Key::from(KEY_SIGNER).raw()); - let pk_hash = signature.pk.to_value(builder); - builder.connect_values(signer_mt_proof.key, key_signer); - builder.connect_values(signer_mt_proof.value, pk_hash); - - // 3.c. connect signed message to pod.id - builder.connect_values(ValueTarget::from_slice(&id.elements), signature.msg); - - measure_gates_end!(builder, measure); - Ok(SignedPodVerifyTarget { - params: self.params.clone(), - id, - mt_proofs, - signature, - }) +pub fn verify_signed_pod_circuit( + builder: &mut CircuitBuilder, + signed_pod: &SignedPodVerifyTarget, +) -> Result<()> { + let params = &signed_pod.params; + let measure = measure_gates_begin!(builder, "SignedPodVerify"); + // 1. Verify id + assert_eq!(params.max_signed_pod_values, signed_pod.mt_proofs.len()); + for mt_proof in &signed_pod.mt_proofs { + verify_merkle_proof_existence_circuit(builder, mt_proof); + builder.connect_hashes(signed_pod.id, mt_proof.root); + // mt_proofs.push(mt_proof); } + + // 2. Verify type + let type_mt_proof = &signed_pod.mt_proofs[0]; + let key_type = builder.constant_value(hash_str(KEY_TYPE).into()); + builder.connect_values(type_mt_proof.key, key_type); + let value_type = builder.constant_value(Value::from(PodType::Signed).raw()); + builder.connect_values(type_mt_proof.value, value_type); + + // 3.a. Verify signature + verify_signature_circuit(builder, &signed_pod.signature); + + // 3.b. Verify signer (ie. hash(signature.pk) == merkletree.signer_leaf) + let signer_mt_proof = &signed_pod.mt_proofs[1]; + let key_signer = builder.constant_value(Key::from(KEY_SIGNER).raw()); + let pk_hash = signed_pod.signature.pk.to_value(builder); + builder.connect_values(signer_mt_proof.key, key_signer); + builder.connect_values(signer_mt_proof.value, pk_hash); + + // 3.c. connect signed message to pod.id + builder.connect_values( + ValueTarget::from_slice(&signed_pod.id.elements), + signed_pod.signature.msg, + ); + + measure_gates_end!(builder, measure); + Ok(()) } pub struct SignedPodVerifyTarget { @@ -88,6 +81,18 @@ pub struct SignedPodVerifyTarget { } impl SignedPodVerifyTarget { + pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { + SignedPodVerifyTarget { + params: params.clone(), + id: builder.add_virtual_hash(), + mt_proofs: (0..params.max_signed_pod_values) + .map(|_| { + MerkleProofExistenceTarget::new_virtual(params.max_depth_mt_containers, builder) + }) + .collect(), + signature: SignatureVerifyTarget::new_virtual(builder), + } + } pub fn pub_statements( &self, builder: &mut CircuitBuilder, @@ -229,7 +234,8 @@ pub mod tests { let mut pw = PartialWitness::::new(); // build the circuit logic - let signed_pod_verify = SignedPodVerifyGadget { params }.eval(&mut builder)?; + let signed_pod_verify = SignedPodVerifyTarget::new_virtual(¶ms, &mut builder); + verify_signed_pod_circuit(&mut builder, &signed_pod_verify)?; // set the signed_pod as target values for the circuit signed_pod_verify.set_targets(&mut pw, &signed_pod)?; diff --git a/src/backends/plonky2/emptypod.rs b/src/backends/plonky2/emptypod.rs index 8eb44a8f..7585f31b 100644 --- a/src/backends/plonky2/emptypod.rs +++ b/src/backends/plonky2/emptypod.rs @@ -20,7 +20,7 @@ use crate::{ basetypes::{Proof, C, D}, circuits::{ common::{Flattenable, StatementTarget}, - mainpod::{CalculateIdGadget, PI_OFFSET_ID}, + mainpod::{calculate_id_circuit, PI_OFFSET_ID}, }, deserialize_proof, error::{Error, Result}, @@ -52,10 +52,7 @@ impl EmptyPodVerifyCircuit { &self.params, &builder.constants(&type_statement().to_fields(&self.params)), ); - let id = CalculateIdGadget { - params: self.params.clone(), - } - .eval(builder, &[type_statement]); + let id = calculate_id_circuit(&self.params, builder, &[type_statement]); let vds_root = builder.add_virtual_hash(); builder.register_public_inputs(&id.elements); builder.register_public_inputs(&vds_root.elements); diff --git a/src/backends/plonky2/primitives/ec/schnorr.rs b/src/backends/plonky2/primitives/ec/schnorr.rs index 5e2242e3..516a11d7 100644 --- a/src/backends/plonky2/primitives/ec/schnorr.rs +++ b/src/backends/plonky2/primitives/ec/schnorr.rs @@ -109,6 +109,10 @@ pub trait WitnessWriteSchnorr: WitnessWrite + WitnessWriteCurve impl> WitnessWriteSchnorr for W {} +// TODO: Rename this to a function `verify_signature_circuit`? I think this convention is also +// nice as an object-oriented alternative to `verb_object_circuit` methods. It's clear that this +// is constraining vs native operation because the type is `*Target`. But of course this +// convention doesn't work for all situations. impl SignatureTarget { pub fn verify( &self, diff --git a/src/backends/plonky2/primitives/merkletree/circuit.rs b/src/backends/plonky2/primitives/merkletree/circuit.rs index 693b84b0..52699abb 100644 --- a/src/backends/plonky2/primitives/merkletree/circuit.rs +++ b/src/backends/plonky2/primitives/merkletree/circuit.rs @@ -34,14 +34,6 @@ use crate::{ middleware::{EMPTY_HASH, EMPTY_VALUE, F, HASH_SIZE}, }; -/// `MerkleProofGadget` allows to verify both proofs of existence and proofs -/// non-existence with the same circuit. -/// If only proofs of existence are needed, use `MerkleProofExistenceGadget`, -/// which requires less amount of constraints. -pub struct MerkleProofGadget { - pub max_depth: usize, -} - #[derive(Clone, Debug)] pub struct MerkleClaimAndProofTarget { pub(crate) max_depth: usize, @@ -57,108 +49,103 @@ pub struct MerkleClaimAndProofTarget { pub(crate) other_value: ValueTarget, } -impl MerkleProofGadget { - /// creates the targets and defines the logic of the circuit - pub fn eval(&self, builder: &mut CircuitBuilder) -> MerkleClaimAndProofTarget { - let measure = measure_gates_begin!(builder, format!("MerkleProof_{}", self.max_depth)); - let enabled = builder.add_virtual_bool_target_safe(); - let root = builder.add_virtual_hash(); - let key = builder.add_virtual_value(); - let value = builder.add_virtual_value(); - // from proof struct: - let existence = builder.add_virtual_bool_target_safe(); - // siblings are padded till max_depth length - let siblings = builder.add_virtual_hashes(self.max_depth); - - let case_ii_selector = builder.add_virtual_bool_target_safe(); - let other_key = builder.add_virtual_value(); - let other_value = builder.add_virtual_value(); - - // We have 3 cases for when computing the Leaf's hash: - // - existence: leaf contains the given key & value - // - non-existence: - // - case i) expected leaf does not exist - // - case ii) expected leaf does exist but it has a different key - // - // The following table expresses the options with their in-circuit - // selectors: - // | existence | case_ii | leaf_hash | - // | ----------- | --------- | ---------------------------- | - // | 1 | 0 | H(key, value, 1) | - // | 0 | 0 | EMPTY_HASH | - // | 0 | 1 | H(other_key, other_value, 1) | - // | 1 | 1 | invalid combination | - - // First, ensure that both existence & case_ii are not true at the same - // time: - // 1. sum = existence + case_ii_selector - let sum = builder.add(existence.target, case_ii_selector.target); - // 2. sum * (sum-1) == 0 - builder.assert_bool(BoolTarget::new_unsafe(sum)); - - // define the case_i_selector as true when both existence and - // case_ii_selector are false: - let not_existence = builder.not(existence); - let not_case_ii_selector = builder.not(case_ii_selector); - let case_i_selector = builder.and(not_existence, not_case_ii_selector); - - // use (key,value) or (other_key, other_value) depending if it's a proof - // of existence or of non-existence, ie: - // k = key * existence + other_key * (1-existence) - // v = value * existence + other_value * (1-existence) - let k = builder.select_value(existence, key, other_key); - let v = builder.select_value(existence, value, other_value); - - // get leaf's hash for the selected k & v - let h = kv_hash_target(builder, &k, &v); - - // if we're in the case i), use leaf_hash=EMPTY_HASH, else use the - // previously computed hash h. - let empty_hash = builder.constant_hash(HashOut::from(EMPTY_HASH.0)); - let leaf_hash = HashOutTarget::from_vec( - (0..HASH_SIZE) - .map(|j| builder.select(case_i_selector, empty_hash.elements[j], h.elements[j])) - .collect(), - ); - - // get key's path - let path = keypath_target(self.max_depth, builder, &key); - - // compute the root for the given siblings and the computed leaf_hash - // (this is for the three cases (existence, non-existence case i, and - // non-existence case ii). - let obtained_root = - compute_root_from_leaf(self.max_depth, builder, &path, &leaf_hash, &siblings); - - // check that obtained_root==root (from inputs), when enabled==true - let zero = builder.zero(); - let expected_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(enabled, root.elements[j], zero)) - .collect(); - let computed_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(enabled, obtained_root.elements[j], zero)) - .collect(); - for j in 0..HASH_SIZE { - builder.connect(computed_root[j], expected_root[j]); - } - measure_gates_end!(builder, measure); - - MerkleClaimAndProofTarget { - max_depth: self.max_depth, - enabled, - existence, - root, - siblings, - key, - value, - case_ii_selector, - other_key, - other_value, - } +/// Allows to verify both proofs of existence and proofs non-existence with the same circuit. If +/// only proofs of existence are needed, use `verify_merkle_proof_existence_circuit`, which +/// requires less amount of constraints. +pub fn verify_merkle_proof_circuit( + builder: &mut CircuitBuilder, + proof: &MerkleClaimAndProofTarget, +) { + let max_depth = proof.max_depth; + let measure = measure_gates_begin!(builder, format!("MerkleProof_{}", max_depth)); + + // We have 3 cases for when computing the Leaf's hash: + // - existence: leaf contains the given key & value + // - non-existence: + // - case i) expected leaf does not exist + // - case ii) expected leaf does exist but it has a different key + // + // The following table expresses the options with their in-circuit + // selectors: + // | existence | case_ii | leaf_hash | + // | ----------- | --------- | ---------------------------- | + // | 1 | 0 | H(key, value, 1) | + // | 0 | 0 | EMPTY_HASH | + // | 0 | 1 | H(other_key, other_value, 1) | + // | 1 | 1 | invalid combination | + + // First, ensure that both existence & case_ii are not true at the same + // time: + // 1. sum = existence + case_ii_selector + let sum = builder.add(proof.existence.target, proof.case_ii_selector.target); + // 2. sum * (sum-1) == 0 + builder.assert_bool(BoolTarget::new_unsafe(sum)); + + // define the case_i_selector as true when both existence and + // case_ii_selector are false: + let not_existence = builder.not(proof.existence); + let not_case_ii_selector = builder.not(proof.case_ii_selector); + let case_i_selector = builder.and(not_existence, not_case_ii_selector); + + // use (key,value) or (other_key, other_value) depending if it's a proof + // of existence or of non-existence, ie: + // k = key * existence + other_key * (1-existence) + // v = value * existence + other_value * (1-existence) + let k = builder.select_value(proof.existence, proof.key, proof.other_key); + let v = builder.select_value(proof.existence, proof.value, proof.other_value); + + // get leaf's hash for the selected k & v + let h = kv_hash_target(builder, &k, &v); + + // if we're in the case i), use leaf_hash=EMPTY_HASH, else use the + // previously computed hash h. + let empty_hash = builder.constant_hash(HashOut::from(EMPTY_HASH.0)); + let leaf_hash = HashOutTarget::from_vec( + (0..HASH_SIZE) + .map(|j| builder.select(case_i_selector, empty_hash.elements[j], h.elements[j])) + .collect(), + ); + + // get key's path + let path = keypath_target(max_depth, builder, &proof.key); + + // compute the root for the given siblings and the computed leaf_hash + // (this is for the three cases (existence, non-existence case i, and + // non-existence case ii). + let obtained_root = + compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings); + + // check that obtained_root==root (from inputs), when enabled==true + let zero = builder.zero(); + let expected_root: Vec = (0..HASH_SIZE) + .map(|j| builder.select(proof.enabled, proof.root.elements[j], zero)) + .collect(); + let computed_root: Vec = (0..HASH_SIZE) + .map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero)) + .collect(); + for j in 0..HASH_SIZE { + builder.connect(computed_root[j], expected_root[j]); } + measure_gates_end!(builder, measure); } impl MerkleClaimAndProofTarget { + pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder) -> Self { + MerkleClaimAndProofTarget { + max_depth, + enabled: builder.add_virtual_bool_target_safe(), + root: builder.add_virtual_hash(), + key: builder.add_virtual_value(), + value: builder.add_virtual_value(), + // from proof struct: + existence: builder.add_virtual_bool_target_safe(), + // siblings are padded till max_depth length + siblings: builder.add_virtual_hashes(max_depth), + case_ii_selector: builder.add_virtual_bool_target_safe(), + other_key: builder.add_virtual_value(), + other_value: builder.add_virtual_value(), + } + } /// assigns the given values to the targets #[allow(clippy::too_many_arguments)] pub fn set_targets( @@ -205,12 +192,6 @@ impl MerkleClaimAndProofTarget { } } -/// `MerkleProofExistenceCircuit` allows to verify proofs of existence only. If -/// proofs of non-existence are needed, use `MerkleProofCircuit`. -pub struct MerkleProofExistenceGadget { - pub max_depth: usize, -} - pub struct MerkleProofExistenceTarget { max_depth: usize, // `enabled` determines if the merkleproof verification is enabled @@ -221,52 +202,51 @@ pub struct MerkleProofExistenceTarget { pub(crate) siblings: Vec, } -impl MerkleProofExistenceGadget { - /// creates the targets and defines the logic of the circuit - pub fn eval(&self, builder: &mut CircuitBuilder) -> Result { - let measure = measure_gates_begin!(builder, format!("MerkleProofExist_{}", self.max_depth)); - let enabled = builder.add_virtual_bool_target_safe(); - let root = builder.add_virtual_hash(); - let key = builder.add_virtual_value(); - let value = builder.add_virtual_value(); - // siblings are padded till max_depth length - let siblings = builder.add_virtual_hashes(self.max_depth); - - // get leaf's hash for the selected k & v - let leaf_hash = kv_hash_target(builder, &key, &value); - - // get key's path - let path = keypath_target(self.max_depth, builder, &key); - - // compute the root for the given siblings and the computed leaf_hash. - let obtained_root = - compute_root_from_leaf(self.max_depth, builder, &path, &leaf_hash, &siblings); - - // check that obtained_root==root (from inputs), when enabled==true - let zero = builder.zero(); - let expected_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(enabled, root.elements[j], zero)) - .collect(); - let computed_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(enabled, obtained_root.elements[j], zero)) - .collect(); - for j in 0..HASH_SIZE { - builder.connect(computed_root[j], expected_root[j]); - } - measure_gates_end!(builder, measure); - - Ok(MerkleProofExistenceTarget { - max_depth: self.max_depth, - enabled, - root, - siblings, - key, - value, - }) +/// Allows to verify proofs of existence only. If proofs of non-existence are needed, use +/// `verify_merkle_proof_circuit`. +pub fn verify_merkle_proof_existence_circuit( + builder: &mut CircuitBuilder, + proof: &MerkleProofExistenceTarget, +) { + let max_depth = proof.max_depth; + let measure = measure_gates_begin!(builder, format!("MerkleProofExist_{}", max_depth)); + + // get leaf's hash for the selected k & v + let leaf_hash = kv_hash_target(builder, &proof.key, &proof.value); + + // get key's path + let path = keypath_target(max_depth, builder, &proof.key); + + // compute the root for the given siblings and the computed leaf_hash. + let obtained_root = + compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings); + + // check that obtained_root==root (from inputs), when enabled==true + let zero = builder.zero(); + let expected_root: Vec = (0..HASH_SIZE) + .map(|j| builder.select(proof.enabled, proof.root.elements[j], zero)) + .collect(); + let computed_root: Vec = (0..HASH_SIZE) + .map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero)) + .collect(); + for j in 0..HASH_SIZE { + builder.connect(computed_root[j], expected_root[j]); } + measure_gates_end!(builder, measure); } impl MerkleProofExistenceTarget { + pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder) -> Self { + MerkleProofExistenceTarget { + max_depth, + enabled: builder.add_virtual_bool_target_safe(), + root: builder.add_virtual_hash(), + key: builder.add_virtual_value(), + value: builder.add_virtual_value(), + // siblings are padded till max_depth length + siblings: builder.add_virtual_hashes(max_depth), + } + } /// assigns the given values to the targets pub fn set_targets( &self, @@ -545,7 +525,8 @@ pub mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = MerkleProofGadget { max_depth }.eval(&mut builder); + let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); + verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, true, @@ -591,7 +572,8 @@ pub mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = MerkleProofExistenceGadget { max_depth }.eval(&mut builder)?; + let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); + verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, true, @@ -666,7 +648,8 @@ pub mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = MerkleProofGadget { max_depth }.eval(&mut builder); + let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); + verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, true, @@ -713,7 +696,8 @@ pub mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = MerkleProofGadget { max_depth }.eval(&mut builder); + let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); + verify_merkle_proof_circuit(&mut builder, &targets); // verification enabled & proof of existence let mp = MerkleClaimAndProof::new(tree2.root(), key, Some(value), proof); targets.set_targets(&mut pw, true, &mp)?; @@ -729,7 +713,8 @@ pub mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = MerkleProofGadget { max_depth }.eval(&mut builder); + let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); + verify_merkle_proof_circuit(&mut builder, &targets); // verification disabled & proof of existence targets.set_targets(&mut pw, false, &mp)?; diff --git a/src/backends/plonky2/primitives/signature/circuit.rs b/src/backends/plonky2/primitives/signature/circuit.rs index 91aa2651..6903be9b 100644 --- a/src/backends/plonky2/primitives/signature/circuit.rs +++ b/src/backends/plonky2/primitives/signature/circuit.rs @@ -35,8 +35,9 @@ use crate::{ middleware::{Hash, Proof, RawValue, EMPTY_HASH, EMPTY_VALUE, F, VALUE_SIZE}, }; -pub struct SignatureVerifyGadget; - +// TODO: This is a very simple wrapper over the signature verification implemented on +// `SignatureTarget`. I think we can remove this and use it directly. Also we're not using the +// `enabled` flag, so it should be straight-forward to remove this. pub struct SignatureVerifyTarget { // `enabled` determines if the signature verification is enabled pub(crate) enabled: BoolTarget, @@ -46,32 +47,34 @@ pub struct SignatureVerifyTarget { pub(crate) sig: SignatureTarget, } -impl SignatureVerifyGadget { - /// creates the targets and defines the logic of the circuit - pub fn eval(&self, builder: &mut CircuitBuilder) -> Result { - let measure = measure_gates_begin!(builder, "SignatureVerify"); - let enabled = builder.add_virtual_bool_target_safe(); - let pk = builder.add_virtual_point_target(); - let msg = builder.add_virtual_value(); - let sig = builder.add_virtual_schnorr_signature_target(); - - let verified = sig.verify(builder, HashOutTarget::from(msg.elements), &pk); - - let result = builder.mul_sub(enabled.target, verified.target, enabled.target); - - builder.assert_zero(result); - - measure_gates_end!(builder, measure); - Ok(SignatureVerifyTarget { - enabled, - pk, - msg, - sig, - }) - } +pub fn verify_signature_circuit( + builder: &mut CircuitBuilder, + signature: &SignatureVerifyTarget, +) { + let measure = measure_gates_begin!(builder, "SignatureVerify"); + let verified = signature.sig.verify( + builder, + HashOutTarget::from(signature.msg.elements), + &signature.pk, + ); + let result = builder.mul_sub( + signature.enabled.target, + verified.target, + signature.enabled.target, + ); + builder.assert_zero(result); + measure_gates_end!(builder, measure); } impl SignatureVerifyTarget { + pub fn new_virtual(builder: &mut CircuitBuilder) -> Self { + SignatureVerifyTarget { + enabled: builder.add_virtual_bool_target_safe(), + pk: builder.add_virtual_point_target(), + msg: builder.add_virtual_value(), + sig: builder.add_virtual_schnorr_signature_target(), + } + } /// assigns the given values to the targets pub fn set_targets( &self, @@ -115,7 +118,8 @@ pub mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = SignatureVerifyGadget {}.eval(&mut builder)?; + let targets = SignatureVerifyTarget::new_virtual(&mut builder); + verify_signature_circuit(&mut builder, &targets); targets.set_targets(&mut pw, true, pk, msg, sig)?; // generate & verify proof @@ -147,8 +151,9 @@ pub mod tests { // circuit let config = CircuitConfig::standard_recursion_zk_config(); let mut builder = CircuitBuilder::::new(config); + let targets = SignatureVerifyTarget::new_virtual(&mut builder); + verify_signature_circuit(&mut builder, &targets); let mut pw = PartialWitness::::new(); - let targets = SignatureVerifyGadget {}.eval(&mut builder)?; targets.set_targets(&mut pw, true, pk, msg, sig.clone())?; // enabled=true // generate proof, and expect it to fail @@ -162,7 +167,8 @@ pub mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = SignatureVerifyGadget {}.eval(&mut builder)?; + let targets = SignatureVerifyTarget::new_virtual(&mut builder); + verify_signature_circuit(&mut builder, &targets); targets.set_targets(&mut pw, false, pk, msg, sig)?; // enabled=false // generate & verify proof