diff --git a/Cargo.toml b/Cargo.toml index 9acc5093..793a4408 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,10 @@ serde = "1.0.219" serde_json = "1.0.140" base64 = "0.22.1" schemars = "0.8.22" +num = { version = "0.4.3", features = ["num-bigint"] } +num-bigint = { version = "0.4.6", features = ["rand"] } +# num-bigint 0.4 requires rand 0.8 +rand = "0.8.5" hashbrown = { version = "0.14.3", default-features = false, features = ["serde"] } # Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory. diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 5d303c9d..6bca77c1 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -2965,7 +2965,7 @@ mod tests { // Input let statements = statements - .into_iter() + .iter() .map(|st| { let mut st = mainpod::Statement::from(st.clone()); pad_statement(params, &mut st); diff --git a/src/backends/plonky2/circuits/signedpod.rs b/src/backends/plonky2/circuits/signedpod.rs index d80f84d2..e2889e1b 100644 --- a/src/backends/plonky2/circuits/signedpod.rs +++ b/src/backends/plonky2/circuits/signedpod.rs @@ -18,7 +18,7 @@ use crate::{ merkletree::{ MerkleClaimAndProof, MerkleProofExistenceGadget, MerkleProofExistenceTarget, }, - signature::{PublicKey, SignatureVerifyGadget, SignatureVerifyTarget}, + signature::{SignatureVerifyGadget, SignatureVerifyTarget}, }, signedpod::SignedPod, }, @@ -58,11 +58,12 @@ impl SignedPodVerifyGadget { // 3.a. Verify signature let signature = SignatureVerifyGadget {}.eval(builder)?; - // 3.b. Verify signer (ie. signature.pk == merkletree.signer_leaf) + // 3.b. Verify signer (ie. hash(signature.pk) == merkletree.signer_leaf) let signer_mt_proof = &mt_proofs[1]; let key_signer = builder.constant_value(Key::from(KEY_SIGNER).raw()); + let pk_hash = signature.pk.to_value(builder); builder.connect_values(signer_mt_proof.key, key_signer); - builder.connect_values(signer_mt_proof.value, signature.pk); + builder.connect_values(signer_mt_proof.value, pk_hash); // 3.c. connect signed message to pod.id builder.connect_values(ValueTarget::from_slice(&id.elements), signature.msg); @@ -130,19 +131,17 @@ impl SignedPodVerifyTarget { // add proof verification of KEY_TYPE & KEY_SIGNER leaves let key_type_key = Key::from(KEY_TYPE); let key_signer_key = Key::from(KEY_SIGNER); - let key_signer_value = [&key_type_key, &key_signer_key] + [&key_type_key, &key_signer_key] .iter() .enumerate() - .map(|(i, k)| { + .try_for_each(|(i, k)| { let (v, proof) = pod.dict.prove(k)?; self.mt_proofs[i].set_targets( pw, true, &MerkleClaimAndProof::new(pod.dict.commitment(), k.raw(), Some(v.raw()), proof), - )?; - Ok(v) - }) - .collect::>>()?[1]; + ) + })?; // add the verification of the rest of leaves let mut curr = 2; // since we already added key_type and key_signer @@ -174,7 +173,7 @@ impl SignedPodVerifyTarget { } // get the signer pk - let pk = PublicKey(key_signer_value.raw()); + let pk = pod.signer; // the msg signed is the pod.id let msg = RawValue::from(pod.id.0); @@ -199,7 +198,7 @@ pub mod tests { use crate::{ backends::plonky2::{ basetypes::C, - primitives::signature::SecretKey, + primitives::ec::schnorr::SecretKey, signedpod::{SignedPod, Signer}, }, middleware::F, diff --git a/src/backends/plonky2/error.rs b/src/backends/plonky2/error.rs index a0ca4e18..f3a4472b 100644 --- a/src/backends/plonky2/error.rs +++ b/src/backends/plonky2/error.rs @@ -10,6 +10,8 @@ pub enum InnerError { IdNotEqual(PodId, PodId), #[error("type does not match, expected {0}, found {1}")] TypeNotEqual(PodType, Value), + #[error("signer public key does not match, expected {0}, found {1}")] + SignerNotEqual(Value, Value), // POD related #[error("invalid POD ID")] @@ -90,4 +92,7 @@ impl Error { pub fn type_not_equal(expected: PodType, found: Value) -> Self { new!(TypeNotEqual(expected, found)) } + pub(crate) fn signer_not_equal(expected: Value, found: Value) -> Self { + new!(SignerNotEqual(expected, found)) + } } diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 0746ac9d..809a4202 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -15,14 +15,12 @@ pub use statement::*; use crate::{ backends::plonky2::{ basetypes::{Proof, ProofWithPublicInputs, VerifierOnlyCircuitData, D}, - circuits::mainpod::{ - CustomPredicateVerification, MainPodVerifyInput, MainPodVerifyTarget, NUM_PUBLIC_INPUTS, - }, + circuits::mainpod::{CustomPredicateVerification, MainPodVerifyInput, MainPodVerifyTarget}, emptypod::EmptyPod, error::{Error, Result}, mock::emptypod::MockEmptyPod, primitives::merkletree::MerkleClaimAndProof, - recursion::{self, RecursiveCircuit, RecursiveParams}, + recursion::{RecursiveCircuit, RecursiveParams}, signedpod::SignedPod, STANDARD_REC_MAIN_POD_CIRCUIT_DATA, }, @@ -550,12 +548,14 @@ pub struct MainPod { fn get_common_data(params: &Params) -> Result, Error> { // TODO: Cache this somehow // https://github.com/0xPARC/pod2/issues/247 - let rec_params = recursion::new_params::( - params.max_input_recursive_pods, - NUM_PUBLIC_INPUTS, - params, - )?; - Ok(rec_params.common_data().clone()) + let rec_circuit_data = &*STANDARD_REC_MAIN_POD_CIRCUIT_DATA; + let (_, circuit_data) = + RecursiveCircuit::::target_and_circuit_data_padded( + params.max_input_recursive_pods, + &rec_circuit_data.common, + params, + )?; + Ok(circuit_data.common.clone()) } impl MainPod { @@ -682,11 +682,13 @@ impl RecursivePod for MainPod { #[cfg(test)] pub mod tests { + use num::{BigUint, One}; + use super::*; use crate::{ backends::plonky2::{ mock::mainpod::{MockMainPod, MockProver}, - primitives::signature::SecretKey, + primitives::ec::schnorr::SecretKey, signedpod::Signer, }, examples::{ @@ -698,7 +700,7 @@ pub mod tests { {self}, }, middleware, - middleware::{CustomPredicateRef, NativePredicate as NP, RawValue}, + middleware::{CustomPredicateRef, NativePredicate as NP}, op, }; @@ -716,11 +718,11 @@ pub mod tests { let (gov_id_builder, pay_stub_builder, sanction_list_builder) = zu_kyc_sign_pod_builders(¶ms); - let mut signer = Signer(SecretKey(RawValue::from(1))); + let mut signer = Signer(SecretKey(BigUint::one())); let gov_id_pod = gov_id_builder.sign(&mut signer)?; - let mut signer = Signer(SecretKey(RawValue::from(2))); + let mut signer = Signer(SecretKey(2u64.into())); let pay_stub_pod = pay_stub_builder.sign(&mut signer)?; - let mut signer = Signer(SecretKey(RawValue::from(3))); + let mut signer = Signer(SecretKey(3u64.into())); let sanction_list_pod = sanction_list_builder.sign(&mut signer)?; let kyc_builder = zu_kyc_pod_builder(¶ms, &gov_id_pod, &pay_stub_pod, &sanction_list_pod)?; @@ -749,7 +751,7 @@ pub mod tests { gov_id_builder.insert("idNumber", "4242424242"); gov_id_builder.insert("dateOfBirth", 1169909384); gov_id_builder.insert("socialSecurityNumber", "G2121210"); - let mut signer = Signer(SecretKey(RawValue::from(42))); + let mut signer = Signer(SecretKey(42u64.into())); let gov_id = gov_id_builder.sign(&mut signer).unwrap(); let now_minus_18y: i64 = 1169909388; let mut kyc_builder = frontend::MainPodBuilder::new(¶ms); @@ -831,24 +833,23 @@ pub mod tests { }; println!("{:#?}", params); - let mut alice = Signer(SecretKey(RawValue::from(1))); - let bob = Signer(SecretKey(RawValue::from(2))); - let mut charlie = Signer(SecretKey(RawValue::from(3))); + let mut alice = Signer(SecretKey(1u32.into())); + let bob = Signer(SecretKey(2u32.into())); + let mut charlie = Signer(SecretKey(3u32.into())); // Alice attests that she is ETH friends with Charlie and Charlie // attests that he is ETH friends with Bob. let alice_attestation = - eth_friend_signed_pod_builder(¶ms, charlie.public_key().0.into()) - .sign(&mut alice)?; + eth_friend_signed_pod_builder(¶ms, charlie.public_key().into()).sign(&mut alice)?; let charlie_attestation = - eth_friend_signed_pod_builder(¶ms, bob.public_key().0.into()).sign(&mut charlie)?; + eth_friend_signed_pod_builder(¶ms, bob.public_key().into()).sign(&mut charlie)?; let alice_bob_ethdos_builder = eth_dos_pod_builder( ¶ms, false, &alice_attestation, &charlie_attestation, - bob.public_key().0.into(), + bob.public_key().into(), )?; let mut prover = MockProver {}; diff --git a/src/backends/plonky2/primitives/ec/bits.rs b/src/backends/plonky2/primitives/ec/bits.rs new file mode 100644 index 00000000..b03dc9d5 --- /dev/null +++ b/src/backends/plonky2/primitives/ec/bits.rs @@ -0,0 +1,365 @@ +use std::{array, marker::PhantomData}; + +use num::BigUint; +use plonky2::{ + field::{ + extension::Extendable, + goldilocks_field::GoldilocksField, + types::{Field, Field64}, + }, + hash::hash_types::RichField, + iop::{ + generator::{GeneratedValues, SimpleGenerator}, + target::{BoolTarget, Target}, + witness::{PartitionWitness, Witness, WitnessWrite}, + }, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CommonCircuitData}, + util::serialization::{Buffer, IoResult, Read, Write}, +}; + +use crate::backends::plonky2::basetypes::{D, F}; + +#[derive(Debug)] +struct ConditionalZeroGenerator, const D: usize> { + if_zero: Target, + then_zero: Target, + quot: Target, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for ConditionalZeroGenerator +{ + fn id(&self) -> String { + "ConditionalZeroGenerator".to_string() + } + + fn dependencies(&self) -> Vec { + vec![self.if_zero, self.then_zero] + } + + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> anyhow::Result<()> { + let if_zero = witness.get_target(self.if_zero); + let then_zero = witness.get_target(self.then_zero); + if if_zero.is_zero() { + out_buffer.set_target(self.quot, F::ZERO)?; + } else { + out_buffer.set_target(self.quot, then_zero / if_zero)?; + } + + Ok(()) + } + + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_target(self.if_zero)?; + dst.write_target(self.then_zero)?; + dst.write_target(self.quot) + } + + fn deserialize( + src: &mut plonky2::util::serialization::Buffer, + _common_data: &CommonCircuitData, + ) -> IoResult + where + Self: Sized, + { + Ok(Self { + if_zero: src.read_target()?, + then_zero: src.read_target()?, + quot: src.read_target()?, + _phantom: PhantomData, + }) + } +} + +/// A big integer, represented in base `2^32` with 10 digits, in little endian +/// form. +#[derive(Clone, Debug)] +pub struct BigUInt320Target { + pub limbs: [Target; 10], + pub bits: [BoolTarget; 320], +} + +pub trait CircuitBuilderBits { + /// Enforces the constraint that `then_zero` must be zero if `if_zero` + /// is zero. + /// + /// The prover is required to exhibit a solution to the equation + /// `if_zero * x == then_zero`. If both `if_zero` and `then_zero` + /// are zero, then it chooses the solution `x = 0`. + fn conditional_zero(&mut self, if_zero: Target, then_zero: Target); + + /// Decomposes the target x as `y + 2^32 z`, where `0 < y,z < 2**32`, and + /// `y=0` if `z=2**32-1`. Note that calling [`CircuitBuilder::split_le`] + /// with `num_bits = 64` will not check the latter condition. + fn split_32_bit(&mut self, x: Target) -> [Target; 2]; + + /// Like `split_low_high` except it doesn't discard the bit decompositions. + fn split_low_high_with_bits( + &mut self, + x: Target, + n_log: usize, + num_bits: usize, + ) -> ((Target, Vec), (Target, Vec)); + + /// Interprets `arr` as an integer in base `[GoldilocksField::ORDER]`, + /// with the digits in little endian order. The length of `arr` must be at + /// most 5. + fn field_elements_to_biguint(&mut self, arr: &[Target]) -> BigUInt320Target; + + fn constant_biguint320(&mut self, n: &BigUint) -> BigUInt320Target; + fn biguint320_target_from_limbs(&mut self, x: &[Target]) -> BigUInt320Target; + fn add_virtual_biguint320_target(&mut self) -> BigUInt320Target; + fn connect_biguint320(&mut self, x: &BigUInt320Target, y: &BigUInt320Target); +} + +impl CircuitBuilderBits for CircuitBuilder { + fn conditional_zero(&mut self, if_zero: Target, then_zero: Target) { + let quot = self.add_virtual_target(); + self.add_simple_generator(ConditionalZeroGenerator { + if_zero, + then_zero, + quot, + _phantom: PhantomData, + }); + let prod = self.mul(if_zero, quot); + self.connect(prod, then_zero); + } + + fn field_elements_to_biguint(&mut self, arr: &[Target]) -> BigUInt320Target { + assert!(arr.len() <= 5); + let zero = self.zero(); + let neg_one = self.neg_one(); + let two_32 = self.constant(GoldilocksField::from_canonical_u64(1 << 32)); + // Apply Horner's method to Σarr[i]*p^i. + // First map each target to its limbs. + let arr_limbs: Vec<_> = arr + .iter() + .map(|x| (self.split_32_bit(*x).to_vec(), vec![])) + .collect(); + let (res_limbs, res_bits) = arr_limbs + .into_iter() + .rev() + .enumerate() + .reduce(|(_, res), (i, a)| { + // Compute p*res in unnormalised form, where each + // coefficient is offset so as to ensure none (except + // possibly the last) underflow. + let prod = (0..=(2 * i + 1)) + .map(|j| { + if j == 0 { + // x_0 + res.0[0] + } else if j == 1 { + // x_1 - x_0 + 2^32 + let diff = self.sub(res.0[1], res.0[0]); + self.add(diff, two_32) + } else if j < 2 * i { + // x_j + x_{j-2} - x_{j-1} + 2^32 - 1 + let diff = self.sub(res.0[j], res.0[j - 1]); + let sum = self.add(diff, res.0[j - 2]); + let sum = self.add(sum, two_32); + self.add(sum, neg_one) + } else if j == 2 * i { + // x_{2*j - 2} - x_{2*j - 1} + 2^32 + let diff = self.sub(res.0[2 * i - 2], res.0[2 * i - 1]); + let sum = self.add(diff, two_32); + self.add(sum, neg_one) + } else { + // x_{2*i - 1} - 1 + self.add(res.0[2 * i - 1], neg_one) + } + }) + .collect::>(); + // Add arr[i]. + let prod_plus_lot = prod + .into_iter() + .enumerate() + .map(|(i, x)| match i { + 0 => self.add(a.0[0], x), + 1 => self.add(a.0[1], x), + _ => x, + }) + .collect::>(); + // Normalise. + ( + i, + normalize_biguint_limbs(self, &prod_plus_lot, 34, 2 * i + 1), + ) + }) + .map(|(_, v)| v) + .unwrap_or((vec![], vec![])); + // Collect limbs, padding with 0s if necessary. + let limbs: [Target; 10] = array::from_fn(|i| { + if i < res_limbs.len() { + res_limbs[i] + } else { + zero + } + }); + // Collect bits, padding with 0s if necessary. + let bits: [BoolTarget; 320] = array::from_fn(|i| { + if i < res_bits.len() { + res_bits[i] + } else { + self._false() + } + }); + BigUInt320Target { limbs, bits } + } + + fn split_32_bit(&mut self, x: Target) -> [Target; 2] { + let (low, high) = self.split_low_high(x, 32, 64); + let max = self.constant(GoldilocksField::from_canonical_i64(0xFFFFFFFF)); + let high_minus_max = self.sub(high, max); + self.conditional_zero(high_minus_max, low); + [low, high] + } + + fn split_low_high_with_bits( + &mut self, + x: Target, + n_log: usize, + num_bits: usize, + ) -> ((Target, Vec), (Target, Vec)) { + let low = self.add_virtual_target(); + let high = self.add_virtual_target(); + + self.add_simple_generator(LowHighGenerator { + integer: x, + n_log, + low, + high, + }); + + let low_bits = self.split_le(low, n_log); + let high_bits = self.split_le(high, num_bits - n_log); + + let pow2 = self.constant(F::from_canonical_u64(1 << n_log)); + let comp_x = self.mul_add(high, pow2, low); + self.connect(x, comp_x); + + ((low, low_bits), (high, high_bits)) + } + + fn constant_biguint320(&mut self, n: &BigUint) -> BigUInt320Target { + assert!(n.bits() <= 320); + let digits = n.to_u32_digits(); + let limbs: [Target; 10] = array::from_fn(|i| { + let d = digits.get(i).copied().unwrap_or(0); + self.constant(GoldilocksField::from_canonical_u32(d)) + }); + self.biguint320_target_from_limbs(&limbs) + } + + fn biguint320_target_from_limbs(&mut self, x: &[Target]) -> BigUInt320Target { + assert!(x.len() == 10); + let limbs = array::from_fn(|i| x[i]); + let bit_vec = biguint_limbs_to_bits(self, x); + BigUInt320Target { + limbs, + bits: array::from_fn(|i| bit_vec[i]), + } + } + + fn add_virtual_biguint320_target(&mut self) -> BigUInt320Target { + let limbs: [Target; 10] = self.add_virtual_target_arr(); + self.biguint320_target_from_limbs(&limbs) + } + + fn connect_biguint320(&mut self, x: &BigUInt320Target, y: &BigUInt320Target) { + for i in 0..10 { + self.connect(x.limbs[i], y.limbs[i]); + } + } +} + +/// Normalises the limbs of a biguint assuming no overflow in the +/// field. Returns the limbs together with their bit decomposition. +fn normalize_biguint_limbs( + builder: &mut CircuitBuilder, + x: &[Target], + max_digit_bits: usize, + max_num_carries: usize, +) -> (Vec, Vec) { + let mut x = x.to_vec(); + let mut bits = Vec::with_capacity(32 * (max_num_carries + 1)); + for i in 0..max_num_carries { + let ((low, mut low_bits), (high, _)) = + builder.split_low_high_with_bits(x[i], 32, max_digit_bits); + x[i] = low; + x[i + 1] = builder.add(x[i + 1], high); + bits.append(&mut low_bits); + } + let mut final_bits = builder.split_le(x[max_num_carries], 32); + bits.append(&mut final_bits); + (x, bits) +} + +/// Converts biguint limbs to bits, checking that each limb is 32-bits +/// long. +fn biguint_limbs_to_bits(builder: &mut CircuitBuilder, limbs: &[Target]) -> Vec { + limbs + .iter() + .flat_map(|t| builder.split_le(*t, 32)) + .collect() +} + +/* +Copied from https://github.com/0xPolygonZero/plonky2/blob/82791c4809d6275682c34b926390ecdbdc2a5297/plonky2/src/gadgets/range_check.rs#L62 + */ + +#[derive(Debug, Default)] +pub struct LowHighGenerator { + integer: Target, + n_log: usize, + low: Target, + high: Target, +} + +impl, const D: usize> SimpleGenerator for LowHighGenerator { + fn id(&self) -> String { + "LowHighGenerator".to_string() + } + + fn dependencies(&self) -> Vec { + vec![self.integer] + } + + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> anyhow::Result<()> { + let integer_value = witness.get_target(self.integer).to_canonical_u64(); + let low = integer_value & ((1 << self.n_log) - 1); + let high = integer_value >> self.n_log; + + out_buffer.set_target(self.low, F::from_canonical_u64(low))?; + out_buffer.set_target(self.high, F::from_canonical_u64(high)) + } + + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_target(self.integer)?; + dst.write_usize(self.n_log)?; + dst.write_target(self.low)?; + dst.write_target(self.high) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { + let integer = src.read_target()?; + let n_log = src.read_usize()?; + let low = src.read_target()?; + let high = src.read_target()?; + Ok(Self { + integer, + n_log, + low, + high, + }) + } +} diff --git a/src/backends/plonky2/primitives/ec/curve.rs b/src/backends/plonky2/primitives/ec/curve.rs new file mode 100644 index 00000000..89f9b091 --- /dev/null +++ b/src/backends/plonky2/primitives/ec/curve.rs @@ -0,0 +1,806 @@ +//! Implementation of the elliptic curve ecGFp5. +//! +//! We roughly follow pornin/ecgfp5. +use core::ops::{Add, Mul}; +use std::{ + array, + ops::{AddAssign, Neg, Sub}, + sync::LazyLock, +}; + +use num::{bigint::BigUint, Num, One}; +use plonky2::{ + field::{ + extension::{quintic::QuinticExtension, Extendable, FieldExtension}, + goldilocks_field::GoldilocksField, + ops::Square, + types::{Field, PrimeField}, + }, + hash::poseidon::PoseidonHash, + iop::{generator::SimpleGenerator, target::BoolTarget, witness::WitnessWrite}, + plonk::circuit_builder::CircuitBuilder, + util::serialization::{Read, Write}, +}; +use serde::{Deserialize, Serialize}; + +use crate::backends::plonky2::{ + circuits::common::ValueTarget, + primitives::ec::{ + bits::BigUInt320Target, + field::{get_nnf_target, CircuitBuilderNNF, OEFTarget}, + gates::{curve::ECAddHomogOffset, generic::SimpleGate}, + }, + Error, +}; + +type ECField = QuinticExtension; + +fn ec_field_to_bytes(x: &ECField) -> Vec { + x.0.iter() + .flat_map(|f| { + f.to_canonical_biguint() + .to_bytes_le() + .into_iter() + .chain(std::iter::repeat(0u8)) + .take(8) + }) + .collect() +} + +fn ec_field_from_bytes(b: &[u8]) -> Result { + let fields: Vec<_> = b + .chunks(8) + .map(|chunk| { + GoldilocksField::from_canonical_u64( + BigUint::from_bytes_le(chunk) + .try_into() + .expect("Slice should not contain more than 8 bytes."), + ) + }) + .collect(); + + if fields.len() != 5 { + return Err(Error::custom( + "Invalid byte encoding of quintic extension field element.".to_string(), + )); + } + + Ok(QuinticExtension(array::from_fn(|i| fields[i]))) +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct Point { + pub x: ECField, + pub u: ECField, +} + +impl Point { + pub fn as_fields(&self) -> Vec { + self.x.0.iter().chain(self.u.0.iter()).cloned().collect() + } + pub fn as_bytes(&self) -> Vec { + [ec_field_to_bytes(&self.x), ec_field_to_bytes(&self.u)].concat() + } + pub fn from_bytes(b: &[u8]) -> Result { + let x_bytes = &b[..40]; + let u_bytes = &b[40..]; + ec_field_from_bytes(x_bytes) + .and_then(|x| ec_field_from_bytes(u_bytes).map(|u| Self { x, u })) + } +} + +#[derive(Clone, Copy, Debug)] +struct HomogPoint { + pub x: ECField, + pub z: ECField, + pub u: ECField, + pub t: ECField, +} + +pub(super) trait ECFieldExt: + Sized + + Copy + + Mul + + Add + + Sub + + Neg +{ + type Base: FieldExtension; + + fn to_base(self) -> [Self::Base; 5]; + fn from_base(components: [Self::Base; 5]) -> Self; + + /// Multiplies a point (viewed as an extension field element) by a + /// small factor times the field extension generator. + fn mul_field_gen(self, factor: u32) -> Self { + let in_arr = self.to_base(); + let field_factor = GoldilocksField::from_canonical_u32(factor); + let field_factor_norm = GoldilocksField::from_canonical_u32(3 * factor); + let out_arr = [ + in_arr[4].scalar_mul(field_factor_norm), + in_arr[0].scalar_mul(field_factor), + in_arr[1].scalar_mul(field_factor), + in_arr[2].scalar_mul(field_factor), + in_arr[3].scalar_mul(field_factor), + ]; + Self::from_base(out_arr) + } + + /// Adds a factor times the extension field generator to a point + /// (viewed as an extension field element). + fn add_field_gen(self, factor: GoldilocksField) -> Self { + let mut b1 = self.to_base(); + let mut b2 = b1[1].to_basefield_array(); + b2[0] += factor; + b1[1] = Self::Base::from_basefield_array(b2); + Self::from_base(b1) + } + + /// Adds a scalar (base field element) to a point (viewed as an + /// extension field element). + fn add_scalar(self, scalar: GoldilocksField) -> Self { + let mut b1 = self.to_base(); + let mut b2 = b1[0].to_basefield_array(); + b2[0] += scalar; + b1[0] = Self::Base::from_basefield_array(b2); + Self::from_base(b1) + } + + fn double(self) -> Self { + self + self + } +} + +impl ECFieldExt<1> for ECField { + type Base = GoldilocksField; + fn to_base(self) -> [Self::Base; 5] { + self.to_basefield_array() + } + fn from_base(components: [Self::Base; 5]) -> Self { + Self::from_basefield_array(components) + } +} + +pub(super) fn add_homog>(x1: F, u1: F, x2: F, u2: F) -> [F; 4] { + let t1 = x1 * x2; + let t3 = u1 * u2; + let t5 = x1 + x2; + let t6 = u1 + u2; + let t7 = t1.add_field_gen(Point::B1); + let t9 = t3 * (t5.mul_field_gen(2 * Point::B1_U32) + t7.double()); + let t10 = t3.double().add_scalar(GoldilocksField::ONE) * (t5 + t7); + let x = (t10 - t7).mul_field_gen(Point::B1_U32); + let z = t7 - t9; + let u = t6 * (-t1).add_field_gen(Point::B1); + let t = t7 + t9; + [x, z, u, t] +} + +// See CircuitBuilderEllptic::add_point for an explanation of why we need this function. +// cf. https://github.com/pornin/ecgfp5/blob/ce059c6d1e1662db437aecbf3db6bb67fe63c716/rust/src/curve.rs#L157 +pub(super) fn add_homog_offset>( + x1: F, + u1: F, + x2: F, + u2: F, +) -> [F; 4] { + let t1 = x1 * x2; + let t3 = u1 * u2; + let t5 = x1 + x2; + let t6 = u1 + u2; + let t7 = t1.add_field_gen(Point::B1); + let t9 = t3 * (t5.mul_field_gen(2 * Point::B1_U32) + t7.double()); + let t10 = t3.double().add_scalar(GoldilocksField::ONE) * (t5 + t7); + let x = (t10 - t7).mul_field_gen(Point::B1_U32); + let z = t1 - t9; + let u = t6 * (-t1).add_field_gen(Point::B1); + let t = t1 + t9; + [x, z, u, t] +} + +const GROUP_ORDER_STR: &str = "1067993516717146951041484916571792702745057740581727230159139685185762082554198619328292418486241"; +pub static GROUP_ORDER: LazyLock = + LazyLock::new(|| BigUint::from_str_radix(GROUP_ORDER_STR, 10).unwrap()); + +static FIELD_NUM_SQUARES: LazyLock = + LazyLock::new(|| (ECField::order() - BigUint::one()) >> 1); + +static GROUP_ORDER_HALF_ROUND_UP: LazyLock = + LazyLock::new(|| (&*GROUP_ORDER + BigUint::one()) >> 1); + +impl Point { + const B1_U32: u32 = 263; + const B1: GoldilocksField = GoldilocksField(Self::B1_U32 as u64); + + pub fn b() -> ECField { + ECField::from_basefield_array([ + GoldilocksField::ZERO, + Self::B1, + GoldilocksField::ZERO, + GoldilocksField::ZERO, + GoldilocksField::ZERO, + ]) + } + + const ZERO: Self = Self { + x: ECField::ZERO, + u: ECField::ZERO, + }; + + pub fn generator() -> Self { + Self { + x: ECField::from_basefield_array([ + GoldilocksField::from_canonical_u64(12883135586176881569), + GoldilocksField::from_canonical_u64(4356519642755055268), + GoldilocksField::from_canonical_u64(5248930565894896907), + GoldilocksField::from_canonical_u64(2165973894480315022), + GoldilocksField::from_canonical_u64(2448410071095648785), + ]), + u: ECField::from_canonical_u64(13835058052060938241), + } + } + + fn add_homog(self, rhs: Point) -> HomogPoint { + let [x, z, u, t] = add_homog(self.x, self.u, rhs.x, rhs.u); + HomogPoint { x, z, u, t } + } + + fn double_homog(self) -> HomogPoint { + self.add_homog(self) + /* + let [x, z, u, t] = double_homog(self.x, self.u); + HomogPoint { x, z, u, t } + */ + } + + pub fn double(self) -> Self { + self.double_homog().into() + } + + pub fn inverse(self) -> Self { + Self { + x: self.x, + u: -self.u, + } + } + + pub fn is_zero(self) -> bool { + self.x.is_zero() && self.u.is_zero() + } + + pub fn is_on_curve(self) -> bool { + self.x == self.u.square() * (self.x * (self.x + ECField::TWO) + Self::b()) + } + + pub fn is_in_subgroup(self) -> bool { + if self.is_on_curve() { + self.x.exp_biguint(&FIELD_NUM_SQUARES) != ECField::ONE + } else { + false + } + } +} + +impl From for Point { + fn from(value: HomogPoint) -> Self { + Self { + x: value.x / value.z, + u: value.u / value.t, + } + } +} + +impl Add for Point { + type Output = Self; + + fn add(self, rhs: Point) -> Self::Output { + self.add_homog(rhs).into() + } +} + +impl AddAssign for Point { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl Mul for &BigUint { + type Output = Point; + fn mul(self, rhs: Point) -> Self::Output { + let bits = self.to_radix_be(2); + bits.into_iter().fold(Point::ZERO, |prod, bit| { + let double = prod.double(); + if bit == 1 { + double + rhs + } else { + double + } + }) + } +} + +type FieldTarget = OEFTarget<5, QuinticExtension>; + +#[derive(Clone, Debug)] +pub struct PointTarget { + pub x: FieldTarget, + pub u: FieldTarget, +} + +impl PointTarget { + pub fn to_value(&self, builder: &mut CircuitBuilder) -> ValueTarget { + let hash = builder.hash_n_to_hash_no_pad::( + self.x + .components + .iter() + .chain(self.u.components.iter()) + .cloned() + .collect(), + ); + ValueTarget::from_slice(&hash.elements) + } +} + +#[derive(Clone, Debug)] +struct PointSquareRootGenerator { + pub orig: PointTarget, + pub sqrt: PointTarget, +} + +impl SimpleGenerator for PointSquareRootGenerator +where + GoldilocksField: Extendable, +{ + fn id(&self) -> String { + "PointSquareRootGenerator".to_string() + } + + fn dependencies(&self) -> Vec { + let mut deps = Vec::with_capacity(10); + deps.extend_from_slice(&self.orig.x.components); + deps.extend_from_slice(&self.orig.u.components); + deps + } + + fn run_once( + &self, + witness: &plonky2::iop::witness::PartitionWitness, + out_buffer: &mut plonky2::iop::generator::GeneratedValues, + ) -> anyhow::Result<()> { + let pt = Point { + x: get_nnf_target(witness, &self.orig.x), + u: get_nnf_target(witness, &self.orig.u), + }; + let sqrt = &*GROUP_ORDER_HALF_ROUND_UP * pt; + out_buffer.set_target_arr(&self.sqrt.x.components, &sqrt.x.0)?; + out_buffer.set_target_arr(&self.sqrt.u.components, &sqrt.u.0) + } + + fn serialize( + &self, + dst: &mut Vec, + _common_data: &plonky2::plonk::circuit_data::CommonCircuitData, + ) -> plonky2::util::serialization::IoResult<()> { + dst.write_target_array(&self.orig.x.components)?; + dst.write_target_array(&self.orig.u.components)?; + dst.write_target_array(&self.sqrt.x.components)?; + dst.write_target_array(&self.sqrt.u.components) + } + + fn deserialize( + src: &mut plonky2::util::serialization::Buffer, + _common_data: &plonky2::plonk::circuit_data::CommonCircuitData, + ) -> plonky2::util::serialization::IoResult + where + Self: Sized, + { + let orig = PointTarget { + x: FieldTarget::new(src.read_target_array()?), + u: FieldTarget::new(src.read_target_array()?), + }; + let sqrt = PointTarget { + x: FieldTarget::new(src.read_target_array()?), + u: FieldTarget::new(src.read_target_array()?), + }; + Ok(Self { orig, sqrt }) + } +} + +pub trait CircuitBuilderElliptic { + fn add_virtual_point_target(&mut self) -> PointTarget; + fn identity_point(&mut self) -> PointTarget; + fn constant_point(&mut self, p: Point) -> PointTarget; + + fn add_point(&mut self, p1: &PointTarget, p2: &PointTarget) -> PointTarget; + fn double_point(&mut self, p: &PointTarget) -> PointTarget; + fn linear_combination_points( + &mut self, + p1_scalar: &[BoolTarget; 320], + p2_scalar: &[BoolTarget; 320], + p1: &PointTarget, + p2: &PointTarget, + ) -> PointTarget; + fn if_point( + &mut self, + b: BoolTarget, + p_true: &PointTarget, + p_false: &PointTarget, + ) -> PointTarget; + + /// Check that two points are equal. This assumes that the points are + /// already known to be in the subgroup. + fn connect_point(&mut self, p1: &PointTarget, p2: &PointTarget); + fn check_point_on_curve(&mut self, p: &PointTarget); + fn check_point_in_subgroup(&mut self, p: &PointTarget); +} + +impl CircuitBuilderElliptic for CircuitBuilder { + fn add_virtual_point_target(&mut self) -> PointTarget { + let p = PointTarget { + x: self.add_virtual_nnf_target(), + u: self.add_virtual_nnf_target(), + }; + self.check_point_in_subgroup(&p); + p + } + + fn identity_point(&mut self) -> PointTarget { + self.constant_point(Point::ZERO) + } + + fn constant_point(&mut self, p: Point) -> PointTarget { + assert!(p.is_in_subgroup()); + PointTarget { + x: self.nnf_constant(&p.x), + u: self.nnf_constant(&p.u), + } + } + + fn add_point(&mut self, p1: &PointTarget, p2: &PointTarget) -> PointTarget { + let mut inputs = Vec::with_capacity(20); + inputs.extend_from_slice(&p1.x.components); + inputs.extend_from_slice(&p1.u.components); + inputs.extend_from_slice(&p2.x.components); + inputs.extend_from_slice(&p2.u.components); + let outputs = ECAddHomogOffset::apply(self, &inputs); + // plonky2 expects all gate constraints to be satisfied by the zero vector. + // So our elliptic curve addition gate computes [x,z-b,u,t-b], and we have to add the b here. + let x = FieldTarget::new(outputs[0..5].try_into().unwrap()); + let z = FieldTarget::new(outputs[5..10].try_into().unwrap()); + let u = FieldTarget::new(outputs[10..15].try_into().unwrap()); + let t = FieldTarget::new(outputs[15..20].try_into().unwrap()); + let b1 = self.constant(Point::B1); + let z = self.nnf_add_scalar_times_generator_power(b1, 1, &z); + let t = self.nnf_add_scalar_times_generator_power(b1, 1, &t); + let xq = self.nnf_div(&x, &z); + let uq = self.nnf_div(&u, &t); + PointTarget { x: xq, u: uq } + /* + let t1 = self.nnf_mul(&p1.x, &p2.x); + let t3 = self.nnf_mul(&p1.u, &p2.u); + let t5 = self.nnf_add(&p1.x, &p2.x); + let t6 = self.nnf_add(&p1.u, &p2.u); + let b1 = self.constant(GoldilocksField::from_canonical_u32(Point::B1_U32)); + let t7 = self.nnf_add_scalar_times_generator_power(b1, 1, &t1); + let t9_1 = self.nnf_mul_generator(&t5); + let t9_2 = self.nnf_mul_scalar(b1, &t9_1); + let t9_3 = self.nnf_add(&t9_2, &t7); + let t9_4 = self.nnf_add(&t9_3, &t9_3); + let t9 = self.nnf_mul(&t3, &t9_4); + let one = self.one(); + let t10_1 = self.nnf_add(&t3, &t3); + let t10_2 = self.nnf_add_scalar_times_generator_power(one, 0, &t10_1); + let t10_3 = self.nnf_add(&t5, &t7); + let t10 = self.nnf_mul(&t10_2, &t10_3); + let x_1 = self.nnf_sub(&t10, &t7); + let x_2 = self.nnf_mul_generator(&x_1); + let x = self.nnf_mul_scalar(b1, &x_2); + let z = self.nnf_sub(&t7, &t9); + let neg_one = self.neg_one(); + let u_1 = self.nnf_mul_scalar(neg_one, &t1); + let u_2 = self.nnf_add_scalar_times_generator_power(b1, 1, &u_1); + let u = self.nnf_mul(&t6, &u_2); + let t = self.nnf_add(&t7, &t9); + let xq = self.nnf_div(&x, &z); + let uq = self.nnf_div(&u, &t); + PointTarget { x: xq, u: uq } + */ + } + + fn double_point(&mut self, p: &PointTarget) -> PointTarget { + self.add_point(p, p) + /* + let t3 = self.nnf_mul(&p.u, &p.u); + let one = self.one(); + let neg_one = self.neg_one(); + let two = self.two(); + let neg_four = self.constant(GoldilocksField::from_noncanonical_i64(-4)); + let four_b = self.constant(GoldilocksField::from_canonical_u32(4 * Point::B1_U32)); + let w1_1 = self.nnf_add_scalar_times_generator_power(one, 0, &p.x); + let w1_2 = self.nnf_add(&w1_1, &w1_1); + let w1_3 = self.nnf_mul(&w1_2, &t3); + let w1_4 = self.nnf_mul_scalar(neg_one, &w1_3); + let w1 = self.nnf_add_scalar_times_generator_power(one, 0, &w1_4); + let x_1 = self.nnf_mul_scalar(four_b, &t3); + let x = self.nnf_mul_generator(&x_1); + let z = self.nnf_mul(&w1, &w1); + let u_1 = self.nnf_add(&w1, &p.u); + let u_2 = self.nnf_mul(&u_1, &u_1); + let u_3 = self.nnf_sub(&u_2, &t3); + let u = self.nnf_sub(&u_3, &z); + let t_1 = self.nnf_mul_scalar(neg_four, &t3); + let t_2 = self.nnf_add_scalar_times_generator_power(two, 0, &t_1); + let t = self.nnf_sub(&t_2, &z); + let xq = self.nnf_div(&x, &z); + let uq = self.nnf_div(&u, &t); + PointTarget { x: xq, u: uq } + */ + } + + fn linear_combination_points( + &mut self, + p1_scalar: &[BoolTarget; 320], + p2_scalar: &[BoolTarget; 320], + p1: &PointTarget, + p2: &PointTarget, + ) -> PointTarget { + let zero = self.identity_point(); + let sum = self.add_point(p1, p2); + let mut ans = zero.clone(); + for i in (0..320).rev() { + ans = self.double_point(&ans); + let maybe_p1 = self.if_point(p1_scalar[i], p1, &zero); + let p2_maybe_p1 = self.if_point(p1_scalar[i], &sum, p2); + let p = self.if_point(p2_scalar[i], &p2_maybe_p1, &maybe_p1); + ans = self.add_point(&ans, &p); + } + ans + } + + fn if_point( + &mut self, + b: BoolTarget, + p_true: &PointTarget, + p_false: &PointTarget, + ) -> PointTarget { + PointTarget { + x: self.nnf_if(b, &p_true.x, &p_false.x), + u: self.nnf_if(b, &p_true.u, &p_false.u), + } + } + + fn connect_point(&mut self, p1: &PointTarget, p2: &PointTarget) { + // The elements of the subgroup have distinct u-coordinates. So it + // is not necessary to connect the x-coordinates. + // Explanation: If a point has u-coordinate lambda: + // If lambda is nonzero, then the other two points on the line x = lambda y + // are the origin (which has u=0 rather than lambda) and a point that's not + // in our subgroup (it differs from an element of our subgroup by + // a 2-torsion point). + // If lambda is zero, then the line x = 0 is tangent to the origin and also + // passes through the point at infinity (which is not in our subgroup). + self.nnf_connect(&p1.u, &p2.u); + } + + fn check_point_on_curve(&mut self, p: &PointTarget) { + let t1 = self.nnf_mul(&p.u, &p.u); + let two = self.two(); + let t2 = self.nnf_add_scalar_times_generator_power(two, 0, &p.x); + let t3 = self.nnf_mul(&p.x, &t2); + let b1 = self.constant(Point::B1); + let t4 = self.nnf_add_scalar_times_generator_power(b1, 1, &t3); + let t5 = self.nnf_mul(&t1, &t4); + self.nnf_connect(&p.x, &t5); + } + + fn check_point_in_subgroup(&mut self, p: &PointTarget) { + // In order to be in the subgroup, the point needs to be a multiple + // of two. + let sqrt = PointTarget { + x: self.add_virtual_nnf_target(), + u: self.add_virtual_nnf_target(), + }; + self.check_point_on_curve(&sqrt); + let doubled = self.double_point(&sqrt); + // connect_point assumes that the point is already known to be in the + // subgroup, so connect the coordinates instead + self.nnf_connect(&doubled.x, &p.x); + self.nnf_connect(&doubled.u, &p.u); + self.add_simple_generator(PointSquareRootGenerator { + orig: p.clone(), + sqrt, + }); + } +} + +pub trait WitnessWriteCurve: WitnessWrite { + fn set_field_target(&mut self, target: &FieldTarget, value: &ECField) -> anyhow::Result<()> { + self.set_target_arr(&target.components, &value.0) + } + fn set_point_target(&mut self, target: &PointTarget, value: &Point) -> anyhow::Result<()> { + self.set_field_target(&target.x, &value.x)?; + self.set_field_target(&target.u, &value.u) + } + fn set_biguint320_target( + &mut self, + target: &BigUInt320Target, + value: &BigUint, + ) -> anyhow::Result<()> { + assert!(value.bits() <= 320); + let digits = value.to_u32_digits(); + for i in 0..10 { + let d = digits.get(i).copied().unwrap_or(0); + self.set_target(target.limbs[i], GoldilocksField::from_canonical_u32(d))?; + } + Ok(()) + } +} + +impl> WitnessWriteCurve for W {} + +#[cfg(test)] +mod test { + use num::{BigUint, FromPrimitive}; + use num_bigint::RandBigInt; + use plonky2::{ + field::{goldilocks_field::GoldilocksField, types::Field}, + iop::witness::PartialWitness, + plonk::{ + circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, + config::PoseidonGoldilocksConfig, + }, + }; + use rand::rngs::OsRng; + + use crate::backends::plonky2::primitives::ec::{ + bits::CircuitBuilderBits, + curve::{CircuitBuilderElliptic, ECField, Point, WitnessWriteCurve, GROUP_ORDER}, + }; + + #[test] + fn test_double() { + let g = Point::generator(); + let p1 = g + g; + let p2 = g.double(); + assert_eq!(p1, p2); + } + + #[test] + fn test_id() { + let p1 = Point::generator(); + let p2 = p1 + Point::ZERO; + assert_eq!(p1, p2); + } + + #[test] + fn test_triple() { + let g = Point::generator(); + let p1 = g + g + g; + let p2 = g + g.double(); + let three = BigUint::from_u64(3).unwrap(); + let p3 = (&three) * g; + assert_eq!(p1, p2); + assert_eq!(p2, p3); + } + + #[test] + fn test_associativity() { + let g = Point::generator(); + let n1 = OsRng.gen_biguint_below(&GROUP_ORDER); + let n2 = OsRng.gen_biguint_below(&GROUP_ORDER); + let prod = (&n1 * &n2) % &*GROUP_ORDER; + assert_eq!(&prod * g, &n1 * (&n2 * g)); + } + + #[test] + fn test_distributivity() { + let g = Point::generator(); + let n1 = OsRng.gen_biguint_below(&GROUP_ORDER); + let n2 = OsRng.gen_biguint_below(&GROUP_ORDER); + let sum = (&n1 + &n2) % &*GROUP_ORDER; + let p1 = &n1 * g; + let p2 = &n2 * g; + let psum = &sum * g; + assert_eq!(p1 + p2, psum); + } + + #[test] + fn test_in_subgroup() { + let g = Point::generator(); + assert!(g.is_in_subgroup()); + let n = OsRng.gen_biguint_below(&GROUP_ORDER); + assert!((&n * g).is_in_subgroup()); + let fake = Point { + x: ECField::ONE, + u: ECField::ONE, + }; + assert!(!fake.is_on_curve()); + let not_sub = Point { + x: Point::b() / g.x, + u: g.u, + }; + assert!(not_sub.is_on_curve()); + assert!(!not_sub.is_in_subgroup()); + } + + #[test] + fn test_double_circuit() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let g = Point::generator(); + let n = OsRng.gen_biguint_below(&GROUP_ORDER); + let p = (&n) * g; + let a = builder.constant_point(p); + let b = builder.double_point(&a); + let c = builder.constant_point(p.double()); + builder.connect_point(&b, &c); + let pw = PartialWitness::new(); + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + Ok(()) + } + + #[test] + fn test_add_circuit() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let g = Point::generator(); + let n1 = OsRng.gen_biguint_below(&GROUP_ORDER); + let n2 = OsRng.gen_biguint_below(&GROUP_ORDER); + let p1 = (&n1) * g; + let p2 = (&n2) * g; + let a = builder.constant_point(p1); + let b = builder.constant_point(p2); + let c = builder.add_point(&a, &b); + let d = builder.constant_point(p1 + p2); + builder.connect_point(&c, &d); + let pw = PartialWitness::new(); + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + Ok(()) + } + + #[test] + fn test_linear_combination_circuit() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let g = Point::generator(); + let n1 = OsRng.gen_biguint_below(&GROUP_ORDER); + let n2 = OsRng.gen_biguint_below(&GROUP_ORDER); + let n3 = OsRng.gen_biguint_below(&GROUP_ORDER); + let p = (&n1) * g; + let g_tgt = builder.constant_point(g); + let p_tgt = builder.constant_point(p); + let g_scalar_bigint = builder.constant_biguint320(&n2); + let p_scalar_bigint = builder.constant_biguint320(&n3); + let g_scalar_bits = g_scalar_bigint.bits; + let p_scalar_bits = p_scalar_bigint.bits; + let e = builder.constant_point((&n2) * g + (&n3) * p); + let f = builder.linear_combination_points(&g_scalar_bits, &p_scalar_bits, &g_tgt, &p_tgt); + builder.connect_point(&e, &f); + let pw = PartialWitness::new(); + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + Ok(()) + } + + #[test] + fn test_not_in_subgroup_circuit() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let g = Point::generator(); + let not_sub = Point { + x: Point::b() / g.x, + u: g.u, + }; + let pt = builder.add_virtual_point_target(); + let mut pw = PartialWitness::new(); + pw.set_point_target(&pt, ¬_sub)?; + let data = builder.build::(); + assert!(data.prove(pw).is_err()); + Ok(()) + } +} diff --git a/src/backends/plonky2/primitives/ec/field.rs b/src/backends/plonky2/primitives/ec/field.rs new file mode 100644 index 00000000..8bdcb373 --- /dev/null +++ b/src/backends/plonky2/primitives/ec/field.rs @@ -0,0 +1,402 @@ +use std::marker::PhantomData; + +use num::BigUint; +use plonky2::{ + field::{ + extension::{Extendable, FieldExtension, OEF}, + types::Field, + }, + hash::hash_types::RichField, + iop::{ + generator::{GeneratedValues, SimpleGenerator}, + target::{BoolTarget, Target}, + witness::{PartitionWitness, Witness, WitnessWrite}, + }, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CommonCircuitData}, + util::serialization::{Buffer, IoError, Read, Write}, +}; + +//use super::gates::field::NNFMulGate; +use crate::{ + backends::plonky2::{ + basetypes::D, + primitives::ec::gates::{field::NNFMulSimple, generic::SimpleGate}, + }, + middleware::F, +}; + +/// Trait for incorporating non-native field (NNF) arithmetic into a +/// circuit. For our purposes, this is a field extension generated by +/// a single element. +pub trait CircuitBuilderNNF< + F: RichField + Extendable, + const D: usize, + NNF: Field, + NNFTarget: Clone, +> +{ + // Target 'adder' + fn add_virtual_nnf_target(&mut self) -> NNFTarget; + // Constant introducers. + fn nnf_constant(&mut self, x: &NNF) -> NNFTarget; + fn nnf_zero(&mut self) -> NNFTarget { + self.nnf_constant(&NNF::ZERO) + } + fn nnf_one(&mut self) -> NNFTarget { + self.nnf_constant(&NNF::ONE) + } + + // Field ops + fn nnf_add(&mut self, x: &NNFTarget, y: &NNFTarget) -> NNFTarget; + fn nnf_sub(&mut self, x: &NNFTarget, y: &NNFTarget) -> NNFTarget; + fn nnf_mul(&mut self, x: &NNFTarget, y: &NNFTarget) -> NNFTarget; + fn nnf_div(&mut self, x: &NNFTarget, y: &NNFTarget) -> NNFTarget; + fn nnf_inverse(&mut self, x: &NNFTarget) -> NNFTarget { + let one = self.nnf_one(); + self.nnf_div(&one, x) + } + + /// Multiplies an extension field element by the generator of the + /// extension field. + fn nnf_mul_generator(&mut self, x: &NNFTarget) -> NNFTarget; + + /// Multiplies an extension field element by a base field element. + fn nnf_mul_scalar(&mut self, x: Target, y: &NNFTarget) -> NNFTarget; + + /// Multiplies an extension field element by a base field element + /// times the generator of the extension field to a given power. + fn nnf_add_scalar_times_generator_power( + &mut self, + x: Target, + gen_power: usize, + y: &NNFTarget, + ) -> NNFTarget; + fn nnf_if(&mut self, b: BoolTarget, x_true: &NNFTarget, x_false: &NNFTarget) -> NNFTarget; + + /// Computes an extension field element to a given (biguint) + /// power. + fn nnf_exp_biguint(&mut self, base: &NNFTarget, exponent: &BigUint) -> NNFTarget; + + // Equality check and connection + fn nnf_eq(&mut self, x: &NNFTarget, y: &NNFTarget) -> BoolTarget; + fn nnf_connect(&mut self, x: &NNFTarget, y: &NNFTarget); +} + +/// Target type modelled on OEF. +#[derive(Debug, Clone)] +pub struct OEFTarget> { + pub components: [Target; DEG], + _phantom_data: PhantomData, +} + +impl> OEFTarget { + pub fn new(components: [Target; DEG]) -> Self { + Self { + components, + _phantom_data: PhantomData, + } + } +} + +impl> Default for OEFTarget { + fn default() -> Self { + Self::new([Target::default(); DEG]) + } +} + +/// Quotient generator for OEF targets. Allows us to automagically +/// generate quotients as witnesses. +#[derive(Debug, Default)] +struct QuotientGeneratorOEF> { + numerator: OEFTarget, + denominator: OEFTarget, + quotient: OEFTarget, +} + +impl< + const DEG: usize, + NNF: OEF + FieldExtension, + F: RichField + Extendable, + const D: usize, + > SimpleGenerator for QuotientGeneratorOEF +{ + fn id(&self) -> String { + "QuotientGeneratorOEF".to_string() + } + + fn dependencies(&self) -> Vec { + self.numerator + .components + .iter() + .chain(self.denominator.components.iter()) + .cloned() + .collect() + } + + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<(), anyhow::Error> { + // Dereference numerator & denominator targets to vectors + // and construct field elements. + let num_components = self + .numerator + .components + .iter() + .map(|t| witness.get_target(*t)) + .collect::>(); + let den_components = self + .denominator + .components + .iter() + .map(|t| witness.get_target(*t)) + .collect::>(); + + let num = NNF::from_basefield_array(std::array::from_fn(|i| num_components[i])); + let den = NNF::from_basefield_array(std::array::from_fn(|i| den_components[i])); + + let quotient = num / den; + + out_buffer.set_target_arr(&self.quotient.components, "ient.to_basefield_array()) + } + + fn serialize( + &self, + dst: &mut Vec, + _common_data: &CommonCircuitData, + ) -> Result<(), IoError> { + dst.write_target_array(&self.numerator.components)?; + dst.write_target_array(&self.denominator.components)?; + dst.write_target_array(&self.quotient.components) + } + + fn deserialize( + src: &mut Buffer, + _common_data: &CommonCircuitData, + ) -> Result { + let numerator = OEFTarget::new(src.read_target_array()?); + let denominator = OEFTarget::new(src.read_target_array()?); + let quotient = OEFTarget::new(src.read_target_array()?); + Ok(Self { + numerator, + denominator, + quotient, + }) + } +} + +impl + FieldExtension> + CircuitBuilderNNF> for CircuitBuilder +{ + fn add_virtual_nnf_target(&mut self) -> OEFTarget { + OEFTarget::new(self.add_virtual_target_arr()) + } + fn nnf_constant(&mut self, x: &NNF) -> OEFTarget { + let targets = x + .to_basefield_array() + .iter() + .map(|c| self.constant(*c)) + .collect::>(); + OEFTarget::new(std::array::from_fn(|i| targets[i])) + } + fn nnf_add(&mut self, x: &OEFTarget, y: &OEFTarget) -> OEFTarget { + let sum_targets = std::iter::zip(&x.components, &y.components) + .map(|(a, b)| self.add(*a, *b)) + .collect::>(); + OEFTarget::new(std::array::from_fn(|i| sum_targets[i])) + } + fn nnf_sub(&mut self, x: &OEFTarget, y: &OEFTarget) -> OEFTarget { + let sub_targets = std::iter::zip(&x.components, &y.components) + .map(|(a, b)| self.sub(*a, *b)) + .collect::>(); + OEFTarget::new(std::array::from_fn(|i| sub_targets[i])) + } + fn nnf_mul(&mut self, x: &OEFTarget, y: &OEFTarget) -> OEFTarget { + let mut inputs = Vec::with_capacity(10); + inputs.extend_from_slice(&x.components); + inputs.extend_from_slice(&y.components); + let outputs = NNFMulSimple::::apply(self, &inputs); + OEFTarget::new(outputs.try_into().unwrap()) + } + fn nnf_div(&mut self, x: &OEFTarget, y: &OEFTarget) -> OEFTarget { + let one = self.nnf_one(); + // Determine denominator inverse witness. + let y_inv = self.add_virtual_nnf_target(); + self.add_simple_generator(QuotientGeneratorOEF { + numerator: one.clone(), + denominator: y.clone(), + quotient: y_inv.clone(), + }); + + // Add constraints and generate quotient. + let maybe_one = self.nnf_mul(&y_inv, y); + self.nnf_connect(&one, &maybe_one); + self.nnf_mul(x, &y_inv) + } + fn nnf_mul_generator(&mut self, x: &OEFTarget) -> OEFTarget { + OEFTarget::new(std::array::from_fn(|i| { + if i == 0 { + self.mul_const(NNF::W, x.components[DEG - 1]) + } else { + x.components[i - 1] + } + })) + } + fn nnf_mul_scalar(&mut self, x: Target, y: &OEFTarget) -> OEFTarget { + OEFTarget::new(std::array::from_fn(|i| self.mul(x, y.components[i]))) + } + fn nnf_add_scalar_times_generator_power( + &mut self, + x: Target, + gen_power: usize, + y: &OEFTarget, + ) -> OEFTarget { + OEFTarget::new(std::array::from_fn(|i| { + if i == gen_power { + self.add(x, y.components[i]) + } else { + y.components[i] + } + })) + } + fn nnf_if( + &mut self, + b: BoolTarget, + x_true: &OEFTarget, + x_false: &OEFTarget, + ) -> OEFTarget { + OEFTarget::new(std::array::from_fn(|i| { + self._if(b, x_true.components[i], x_false.components[i]) + })) + } + fn nnf_exp_biguint( + &mut self, + base: &OEFTarget, + exponent: &BigUint, + ) -> OEFTarget { + let mut ans = self.nnf_one(); + for i in (0..exponent.bits()).rev() { + ans = self.nnf_mul(&ans, &ans); + if exponent.bit(i) { + ans = self.nnf_mul(&ans, base); + } + } + ans + } + fn nnf_eq(&mut self, x: &OEFTarget, y: &OEFTarget) -> BoolTarget { + let eq_checks = std::iter::zip(&x.components, &y.components) + .map(|(a, b)| self.is_equal(*a, *b)) + .collect::>(); + eq_checks + .into_iter() + .reduce(|check, c| self.and(check, c)) + .expect("Missing equality checks") + } + fn nnf_connect(&mut self, x: &OEFTarget, y: &OEFTarget) { + std::iter::zip(&x.components, &y.components).for_each(|(a, b)| self.connect(*a, *b)) + } +} + +pub(super) fn get_nnf_target>( + witness: &impl Witness, + tgt: &OEFTarget, +) -> NNF { + let values = tgt.components.map(|x| witness.get_target(x)); + NNF::from_basefield_array(values) +} + +#[cfg(test)] +mod test { + use plonky2::{ + field::{ + extension::quintic::QuinticExtension, + goldilocks_field::GoldilocksField, + types::{Field, Sample}, + }, + iop::witness::{PartialWitness, WitnessWrite}, + plonk::{ + circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, + config::PoseidonGoldilocksConfig, + }, + }; + + use super::{CircuitBuilderNNF, OEFTarget}; + + #[test] + fn quintic_arithmetic_check() -> Result<(), anyhow::Error> { + type QuinticGoldilocks = QuinticExtension; + + // Circuit declaration + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let zero = builder.nnf_zero(); + let one = builder.nnf_one(); + + // Let c = a * b. + let a_target: OEFTarget<5, QuinticGoldilocks> = builder.add_virtual_nnf_target(); + let b_target: OEFTarget<5, QuinticGoldilocks> = builder.add_virtual_nnf_target(); + let c_target = builder.nnf_mul(&a_target, &b_target); + + // Pick some values. + let a_value = QuinticExtension(std::array::from_fn(|_| GoldilocksField::rand())); + let b_value = { + let rand_value = QuinticExtension(std::array::from_fn(|_| GoldilocksField::rand())); + if rand_value == QuinticGoldilocks::ZERO { + QuinticGoldilocks::ONE + } else { + rand_value + } + }; + let c_value = a_value * b_value; + + // How about d = a/b? + let d_target = builder.nnf_div(&a_target, &b_target); + let d_value = a_value / b_value; + + // Also e = a - b. + let e_target = builder.nnf_sub(&a_target, &b_target); + let e_value = a_value - b_value; + + // a +- 0 == a, a * 1 == a, etc. + let a_plus_zero = builder.nnf_add(&a_target, &zero); + let a_minus_zero = builder.nnf_sub(&a_target, &zero); + let a_times_one = builder.nnf_mul(&a_target, &one); + let a_div_one = builder.nnf_div(&a_target, &one); + + builder.nnf_connect(&a_target, &a_plus_zero); + builder.nnf_connect(&a_target, &a_minus_zero); + builder.nnf_connect(&a_target, &a_times_one); + builder.nnf_connect(&a_target, &a_div_one); + + // a == a, a != a + 1 + let a_plus_one = builder.nnf_add(&a_target, &one); + let a_eq_a = builder.nnf_eq(&a_target, &a_target); + let a_eq_a_plus_one = builder.nnf_eq(&a_target, &a_plus_one); + + builder.assert_one(a_eq_a.target); + builder.assert_zero(a_eq_a_plus_one.target); + + // b * (1/b) == 1 + let one_on_b = builder.nnf_inverse(&b_target); + let b_times_one_on_b = builder.nnf_mul(&b_target, &one_on_b); + + builder.nnf_connect(&one, &b_times_one_on_b); + + // Prove + let mut pw = PartialWitness::::new(); + + pw.set_target_arr(&a_target.components, &a_value.0)?; + pw.set_target_arr(&b_target.components, &b_value.0)?; + pw.set_target_arr(&c_target.components, &c_value.0)?; + pw.set_target_arr(&d_target.components, &d_value.0)?; + pw.set_target_arr(&e_target.components, &e_value.0)?; + + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) + } +} diff --git a/src/backends/plonky2/primitives/ec/gates/curve.rs b/src/backends/plonky2/primitives/ec/gates/curve.rs new file mode 100644 index 00000000..7add65a4 --- /dev/null +++ b/src/backends/plonky2/primitives/ec/gates/curve.rs @@ -0,0 +1,83 @@ +use plonky2::field::goldilocks_field::GoldilocksField; + +use crate::backends::plonky2::primitives::ec::{ + curve::{add_homog_offset, ECFieldExt}, + gates::{field::QuinticTensor, generic::SimpleGate}, +}; + +/// Gate computing the addition of two elliptic curve points in +/// homogeneous coordinates *minus* an offset in the `z` and `t` +/// coordinates, viz. the extension field generator times `Point::B1`, +/// cf. CircuitBuilderElliptic::add_point. +#[derive(Debug, Clone)] +pub struct ECAddHomogOffset; + +impl SimpleGate for ECAddHomogOffset { + type F = GoldilocksField; + const INPUTS_PER_OP: usize = 20; + const OUTPUTS_PER_OP: usize = 20; + const DEGREE: usize = 4; + const ID: &'static str = "ECAddHomog"; + fn eval( + wires: &[>::Extension], + ) -> Vec<>::Extension> + where + Self::F: plonky2::field::extension::Extendable, + { + let mut ans = Vec::with_capacity(20); + let x1 = QuinticTensor::from_base(wires[0..5].try_into().unwrap()); + let u1 = QuinticTensor::from_base(wires[5..10].try_into().unwrap()); + let x2 = QuinticTensor::from_base(wires[10..15].try_into().unwrap()); + let u2 = QuinticTensor::from_base(wires[15..20].try_into().unwrap()); + let out = add_homog_offset(x1, u1, x2, u2); + for v in out { + ans.extend(v.to_base()); + } + ans + } +} + +#[cfg(test)] +mod test { + use plonky2::{ + gates::gate_testing::{test_eval_fns, test_low_degree}, + plonk::{circuit_data::CircuitConfig, config::PoseidonGoldilocksConfig}, + }; + + use crate::backends::plonky2::primitives::ec::gates::{ + curve::ECAddHomogOffset, generic::GateAdapter, + }; + + #[test] + fn test_recursion() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let gate = GateAdapter::::new_from_config(&config); + + test_eval_fns::<_, PoseidonGoldilocksConfig, _, 2>(gate) + } + + #[test] + fn test_low_degree_orig() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let gate = GateAdapter::::new_from_config(&config); + + test_low_degree::<_, _, 2>(gate); + Ok(()) + } + + #[test] + fn test_low_degree_recursive() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let orig_gate = GateAdapter::::new_from_config(&config); + + test_low_degree::<_, _, 2>(orig_gate.recursive_gate()); + Ok(()) + } + + #[test] + fn test_double_recursion() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let orig_gate = GateAdapter::::new_from_config(&config); + test_eval_fns::<_, PoseidonGoldilocksConfig, _, 2>(orig_gate.recursive_gate()) + } +} diff --git a/src/backends/plonky2/primitives/ec/gates/field.rs b/src/backends/plonky2/primitives/ec/gates/field.rs new file mode 100644 index 00000000..164f9684 --- /dev/null +++ b/src/backends/plonky2/primitives/ec/gates/field.rs @@ -0,0 +1,234 @@ +use std::{ + array, + marker::PhantomData, + ops::{Add, Mul, Neg, Sub}, +}; + +use plonky2::{ + field::{ + extension::{quintic::QuinticExtension, Extendable, FieldExtension, OEF}, + goldilocks_field::GoldilocksField, + types::Field, + }, + hash::hash_types::RichField, +}; + +use crate::backends::plonky2::primitives::ec::{curve::ECFieldExt, gates::generic::SimpleGate}; + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub struct TensorProduct +where + F1: OEF, + F2: FieldExtension, +{ + pub components: [F2; D1], + _phantom_data: PhantomData, +} + +impl TensorProduct +where + F1: OEF, + F2: FieldExtension, +{ + pub fn new(components: [F2; D1]) -> Self { + Self { + components, + _phantom_data: PhantomData, + } + } + + pub fn add_base_field(self, rhs: F2::BaseField) -> Self { + let mut c = self.components; + let mut c2 = c[0].to_basefield_array(); + c2[0] += rhs; + c[0] = F2::from_basefield_array(c2); + Self::new(c) + } + + pub fn add_one(self) -> Self { + self.add_base_field(F2::BaseField::ONE) + } + + pub fn mul_scalar(self, rhs: F2::BaseField) -> Self { + Self::new(self.components.map(|x| x.scalar_mul(rhs))) + } + + pub fn double(self) -> Self { + self + self + } + + pub fn is_zero(self) -> bool { + self.components.iter().all(|x| x.is_zero()) + } +} + +impl Add for TensorProduct +where + F1: OEF, + F2: FieldExtension, +{ + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + Self::new(array::from_fn(|i| self.components[i] + rhs.components[i])) + } +} + +impl Mul for TensorProduct +where + F1: OEF, + F2: FieldExtension, +{ + type Output = Self; + fn mul(self, rhs: Self) -> Self::Output { + let mut components = array::from_fn(|_| F2::ZERO); + for i in 0..D1 { + for j in 0..D1 { + let prod = self.components[i] * rhs.components[j]; + if i + j < D1 { + components[i + j] += prod; + } else { + components[i + j - D1] += prod.scalar_mul(F1::W) + } + } + } + Self::new(components) + } +} + +impl Sub for TensorProduct +where + F1: OEF, + F2: FieldExtension, +{ + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + Self::new(array::from_fn(|i| self.components[i] - rhs.components[i])) + } +} + +impl Neg for TensorProduct +where + F1: OEF, + F2: FieldExtension, +{ + type Output = Self; + fn neg(self) -> Self::Output { + Self::new(self.components.map(|x| -x)) + } +} + +pub(super) type QuinticTensor = TensorProduct< + 5, + D, + QuinticExtension, + >::Extension, +>; + +impl ECFieldExt for QuinticTensor +where + GoldilocksField: Extendable, +{ + type Base = >::Extension; + fn to_base(self) -> [Self::Base; 5] { + self.components + } + fn from_base(components: [Self::Base; 5]) -> Self { + Self::new(components) + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct NNFMulSimple> { + _phantom_data: PhantomData NNF>, +} + +impl> NNFMulSimple { + pub fn new() -> Self { + Self { + _phantom_data: PhantomData, + } + } +} + +impl SimpleGate for NNFMulSimple +where + NNF: OEF, + NNF::BaseField: RichField + Extendable<1>, +{ + type F = NNF::BaseField; + const INPUTS_PER_OP: usize = 2 * NNF_DEG; + const OUTPUTS_PER_OP: usize = NNF_DEG; + const DEGREE: usize = 2; + const ID: &'static str = "NNFSimpleGate"; + + fn eval( + wires: &[>::Extension], + ) -> Vec<>::Extension> + where + Self::F: Extendable, + { + let x: TensorProduct>::Extension> = + TensorProduct::new(array::from_fn(|i| wires[i])); + let y = TensorProduct::new(array::from_fn(|i| wires[NNF_DEG + i])); + let prod = x * y; + prod.components.into() + } +} + +#[cfg(test)] +mod test { + use plonky2::{ + field::{extension::quintic::QuinticExtension, goldilocks_field::GoldilocksField}, + gates::gate_testing::{test_eval_fns, test_low_degree}, + plonk::{circuit_data::CircuitConfig, config::PoseidonGoldilocksConfig}, + }; + + use crate::backends::plonky2::primitives::ec::gates::{ + field::NNFMulSimple, generic::GateAdapter, + }; + + #[test] + fn test_recursion() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let gate = + GateAdapter::>>::new_from_config( + &config, + ); + + test_eval_fns::<_, PoseidonGoldilocksConfig, _, 2>(gate) + } + + #[test] + fn test_low_degree_orig() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let gate = + GateAdapter::>>::new_from_config( + &config, + ); + + test_low_degree::<_, _, 2>(gate); + Ok(()) + } + + #[test] + fn test_low_degree_recursive() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let orig_gate = + GateAdapter::>>::new_from_config( + &config, + ); + + test_low_degree::<_, _, 2>(orig_gate.recursive_gate()); + Ok(()) + } + + #[test] + fn test_double_recursion() -> Result<(), anyhow::Error> { + let config = CircuitConfig::standard_recursion_config(); + let orig_gate = + GateAdapter::>>::new_from_config( + &config, + ); + test_eval_fns::<_, PoseidonGoldilocksConfig, _, 2>(orig_gate.recursive_gate()) + } +} diff --git a/src/backends/plonky2/primitives/ec/gates/generic.rs b/src/backends/plonky2/primitives/ec/gates/generic.rs new file mode 100644 index 00000000..d81c965a --- /dev/null +++ b/src/backends/plonky2/primitives/ec/gates/generic.rs @@ -0,0 +1,600 @@ +#![allow(clippy::needless_range_loop)] + +use std::{fmt::Debug, marker::PhantomData}; + +use plonky2::{ + field::extension::{Extendable, FieldExtension}, + gates::gate::Gate, + hash::hash_types::RichField, + iop::{ + ext_target::ExtensionTarget, + generator::{SimpleGenerator, WitnessGeneratorRef}, + target::Target, + witness::{Witness, WitnessWrite}, + }, + plonk::{ + circuit_builder::CircuitBuilder, + circuit_data::{CircuitConfig, CommonCircuitData}, + vars::EvaluationVars, + }, + util::serialization::{Buffer, IoResult, Read, Write}, +}; + +pub trait SimpleGate: 'static + Send + Sync + Sized + Clone + Debug { + type F: RichField + Extendable<1>; + const INPUTS_PER_OP: usize; + const OUTPUTS_PER_OP: usize; + const WIRES_PER_OP: usize = Self::INPUTS_PER_OP + Self::OUTPUTS_PER_OP; + const DEGREE: usize; + const ID: &'static str; + fn eval( + wires: &[>::Extension], + ) -> Vec<>::Extension> + where + Self::F: Extendable; + fn apply( + builder: &mut CircuitBuilder, + targets: &[Target], + ) -> Vec + where + Self::F: Extendable, + { + assert!(targets.len() == Self::INPUTS_PER_OP); + let gate = GateAdapter::::new_from_config(&builder.config); + let (row, slot) = builder.find_slot(gate, &[], &[]); + let input_start = Self::WIRES_PER_OP * slot; + let output_start = input_start + Self::INPUTS_PER_OP; + for (i, &t) in targets.iter().enumerate() { + builder.connect(t, Target::wire(row, input_start + i)); + } + (0..Self::OUTPUTS_PER_OP) + .map(|i| Target::wire(row, output_start + i)) + .collect() + } +} + +#[derive(Debug, Clone)] +pub struct GateAdapter { + max_ops: usize, + recursive_max_wires: usize, + _gate: PhantomData, +} + +#[derive(Debug, Clone)] +pub struct RecursiveGateAdapter { + max_ops: usize, + _gate: PhantomData, +} + +#[derive(Debug)] +pub struct RecursiveGenerator { + row: usize, + index: usize, + _gate: PhantomData, +} + +impl GateAdapter { + const WIRES_PER_OP: usize = G::INPUTS_PER_OP + G::OUTPUTS_PER_OP; + + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + max_ops: config.num_routed_wires / Self::WIRES_PER_OP, + recursive_max_wires: config.num_routed_wires, + _gate: PhantomData, + } + } + + pub fn recursive_gate(&self) -> RecursiveGateAdapter { + RecursiveGateAdapter:: { + max_ops: self.recursive_max_wires / (D * G::WIRES_PER_OP), + _gate: PhantomData, + } + } +} + +impl RecursiveGateAdapter { + const INPUTS_PER_OP: usize = D * G::INPUTS_PER_OP; + const OUTPUTS_PER_OP: usize = D * G::OUTPUTS_PER_OP; + const WIRES_PER_OP: usize = Self::INPUTS_PER_OP + Self::OUTPUTS_PER_OP; + + fn apply( + &self, + builder: &mut CircuitBuilder, + vars: &[ExtensionTarget], + ) -> Vec> + where + G::F: Extendable, + { + let (row, slot) = builder.find_slot(self.clone(), &[], &[]); + for j in 0..G::INPUTS_PER_OP { + for (k, &v) in vars[j].0.iter().enumerate() { + builder.connect( + v, + Target::wire( + row, + slot * RecursiveGateAdapter::::WIRES_PER_OP + j * D + k, + ), + ); + } + } + (0..G::OUTPUTS_PER_OP) + .map(|j| { + ExtensionTarget(core::array::from_fn(|k| { + Target::wire( + row, + slot * RecursiveGateAdapter::::WIRES_PER_OP + + RecursiveGateAdapter::::INPUTS_PER_OP + + j * D + + k, + ) + })) + }) + .collect() + } +} + +impl RecursiveGenerator { + const WIRES_PER_OP: usize = RecursiveGateAdapter::::WIRES_PER_OP; + const INPUTS_PER_OP: usize = RecursiveGateAdapter::::INPUTS_PER_OP; +} + +#[derive(Debug, Clone)] +pub struct TargetList { + row: usize, + offset: usize, +} + +impl TargetList { + pub fn get(&self, index: usize) -> Target { + Target::wire(self.row, self.offset + index) + } +} + +pub trait WriteTargetList: Write { + fn write_target_list(&mut self, l: &TargetList) -> IoResult<()> { + self.write_usize(l.row)?; + self.write_usize(l.offset) + } +} + +impl WriteTargetList for W {} + +pub trait ReadTargetList: Read { + fn read_target_list(&mut self) -> IoResult { + Ok(TargetList { + row: self.read_usize()?, + offset: self.read_usize()?, + }) + } +} + +impl ReadTargetList for R {} + +impl Gate for GateAdapter +where + G::F: RichField + Extendable + Extendable<1>, +{ + fn id(&self) -> String { + G::ID.to_string() + } + + fn serialize( + &self, + dst: &mut Vec, + _common_data: &CommonCircuitData, + ) -> IoResult<()> { + dst.write_usize(self.max_ops)?; + dst.write_usize(self.recursive_max_wires) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult + where + Self: Sized, + { + Ok(Self { + max_ops: src.read_usize()?, + recursive_max_wires: src.read_usize()?, + _gate: PhantomData, + }) + } + + fn num_wires(&self) -> usize { + self.max_ops * Self::WIRES_PER_OP + } + + fn degree(&self) -> usize { + G::DEGREE + } + + fn num_ops(&self) -> usize { + self.max_ops + } + + fn num_constants(&self) -> usize { + 0 + } + + fn num_constraints(&self) -> usize { + self.max_ops * G::OUTPUTS_PER_OP + } + + fn generators( + &self, + row: usize, + _local_constants: &[G::F], + ) -> Vec> { + (0..self.max_ops) + .map(|index| { + WitnessGeneratorRef::new( + RecursiveGenerator::<1, G> { + row, + index, + _gate: PhantomData, + } + .adapter(), + ) + }) + .collect() + } + + fn eval_unfiltered_base_one( + &self, + vars_base: plonky2::plonk::vars::EvaluationVarsBase, + mut yield_constr: plonky2::gates::util::StridedConstraintConsumer, + ) { + for i in 0..self.max_ops { + let in_start = Self::WIRES_PER_OP * i; + let out_start = in_start + G::INPUTS_PER_OP; + let inputs: Vec<_> = (0..G::INPUTS_PER_OP) + .map(|j| { + >::Extension::from_basefield_array([ + vars_base.local_wires[in_start + j] + ]) + }) + .collect(); + let computed = G::eval::<1>(&inputs[..]); + yield_constr.many( + computed.iter().enumerate().map(|(j, &x)| { + x.to_basefield_array()[0] - vars_base.local_wires[out_start + j] + }), + ); + } + } + + fn eval_unfiltered( + &self, + vars: EvaluationVars, + ) -> Vec<>::Extension> { + let mut constraints = Vec::new(); + for i in 0..self.max_ops { + let in_start = Self::WIRES_PER_OP * i; + let out_start = in_start + G::INPUTS_PER_OP; + let computed = G::eval::(&vars.local_wires[in_start..out_start]); + constraints.extend( + computed + .iter() + .enumerate() + .map(|(j, &x)| x - vars.local_wires[out_start + j]), + ); + } + constraints + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + vars: plonky2::plonk::vars::EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(G::OUTPUTS_PER_OP * self.max_ops); + for i in 0..self.max_ops { + let input_start = i * G::WIRES_PER_OP; + let output_start = input_start + G::INPUTS_PER_OP; + let computed = self + .recursive_gate() + .apply(builder, &vars.local_wires[input_start..output_start]); + for j in 0..G::OUTPUTS_PER_OP { + constraints + .push(builder.sub_extension(computed[j], vars.local_wires[output_start + j])); + } + } + constraints + } +} + +impl RecursiveGenerator { + fn deps(&self) -> Vec { + let offset = self.index * Self::WIRES_PER_OP; + (0..Self::INPUTS_PER_OP) + .map(|i| Target::wire(self.row, offset + i)) + .collect() + } +} + +impl SimpleGenerator for RecursiveGenerator +where + G: SimpleGate, + G::F: RichField + Extendable + Extendable, +{ + fn serialize( + &self, + dst: &mut Vec, + _common_data: &CommonCircuitData, + ) -> IoResult<()> { + dst.write_usize(self.row)?; + dst.write_usize(self.index) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult + where + Self: Sized, + { + Ok(Self { + row: src.read_usize()?, + index: src.read_usize()?, + _gate: PhantomData, + }) + } + + fn id(&self) -> String { + format!("Generator<{},{}>", D, G::ID) + } + + fn dependencies(&self) -> Vec { + self.deps() + } + + fn run_once( + &self, + witness: &plonky2::iop::witness::PartitionWitness, + out_buffer: &mut plonky2::iop::generator::GeneratedValues, + ) -> anyhow::Result<()> { + let deps = self.deps(); + let inputs: Vec<>::Extension> = (0..G::INPUTS_PER_OP) + .map(|i| { + >::Extension::from_basefield_array(core::array::from_fn( + |j| witness.get_target(deps[i * D + j]), + )) + }) + .collect(); + let out: Vec<>::Extension> = G::eval::(&inputs[..]); + for (i, x) in out.into_iter().enumerate() { + for (j, y) in x.to_basefield_array().into_iter().enumerate() { + let offset = self.index * Self::WIRES_PER_OP; + let target = Target::wire(self.row, offset + Self::INPUTS_PER_OP + D * i + j); + out_buffer.set_target(target, y)?; + } + } + Ok(()) + } +} + +impl Gate for RecursiveGateAdapter +where + G: SimpleGate, + F: RichField + Extendable, +{ + fn id(&self) -> String { + format!("Recursive<{},{}>", D, G::ID) + } + + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_usize(self.max_ops) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult + where + Self: Sized, + { + let max_ops = src.read_usize()?; + Ok(Self { + max_ops, + _gate: PhantomData, + }) + } + + fn num_wires(&self) -> usize { + self.max_ops * Self::WIRES_PER_OP + } + + fn degree(&self) -> usize { + G::DEGREE + } + + fn num_ops(&self) -> usize { + self.max_ops + } + + fn num_constants(&self) -> usize { + 0 + } + + fn num_constraints(&self) -> usize { + self.max_ops * Self::OUTPUTS_PER_OP + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec> { + (0..self.max_ops) + .map(|index| { + WitnessGeneratorRef::new( + RecursiveGenerator:: { + row, + index, + _gate: PhantomData, + } + .adapter(), + ) + }) + .collect() + } + + fn eval_unfiltered_base_one( + &self, + vars_base: plonky2::plonk::vars::EvaluationVarsBase, + mut yield_constr: plonky2::gates::util::StridedConstraintConsumer, + ) { + for i in 0..self.max_ops { + let input_start = D * G::WIRES_PER_OP * i; + let output_start = input_start + D * G::INPUTS_PER_OP; + let input: Vec<_> = (0..G::INPUTS_PER_OP) + .map(|j| { + F::Extension::from_basefield_array(core::array::from_fn(|k| { + vars_base.local_wires[input_start + D * j + k] + })) + }) + .collect(); + let output = G::eval(&input); + for j in 0..G::OUTPUTS_PER_OP { + yield_constr.many((0..D).map(|k| { + output[j].to_basefield_array()[k] + - vars_base.local_wires[output_start + D * j + k] + })); + } + } + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + // We think of `vars.local_wires` as a list of G::WIRES_PER_OP + // elements of ExtensionAlgebra. + // We use the fact that ExtensionAlgebra is isomorphic to D copies of + // F::Extension, with the isomorphism given by (a ⊗ b) -> + // [a.repeated_frobenius(j) * b for j in 0..D]. + let mut constraints = Vec::with_capacity(D * G::OUTPUTS_PER_OP); + let mut evals: [Vec; D] = core::array::from_fn(|_| Vec::new()); + let mut inputs = Vec::with_capacity(G::INPUTS_PER_OP); + let dth_root_inv = F::DTH_ROOT.inverse(); + let d_inv = F::from_canonical_usize(D).inverse(); + let w_inv = F::W.inverse(); + for i in 0..self.max_ops { + let input_start = D * G::WIRES_PER_OP * i; + let output_start = input_start + D * G::INPUTS_PER_OP; + // Phase factor for Frobenius automorphism + // application, cf. definition of `repeated_frobenius` + // in plonky2/field/src/extension/mod.rs. + let mut phase = F::ONE; + for ev in evals.iter_mut() { + inputs.clear(); + // Collect input wires. + for j in 0..G::INPUTS_PER_OP { + let var_start = input_start + D * j; + let var: [[F; D]; D] = core::array::from_fn(|k| { + vars.local_wires[var_start + k].to_basefield_array() + }); + let mut input = [F::ZERO; D]; + let mut factor = F::ONE; + for k in 0..D { + for l in 0..D { + let prod = factor * var[k][l]; + if k + l < D { + input[k + l] += prod; + } else { + input[k + l - D] += F::W * prod; + } + } + factor *= phase; + } + inputs.push(F::Extension::from_basefield_array(input)); + } + // Evaluate SimpleGate. + *ev = G::eval(&inputs); + phase *= F::DTH_ROOT; + } + for j in 0..G::OUTPUTS_PER_OP { + let mut phase = F::ONE; + for k in 0..D { + let mut output = [F::ZERO; D]; + let ev: [[F; D]; D] = + core::array::from_fn(|l| evals[l][j].to_basefield_array()); + let mut factor = d_inv; + for l in 0..D { + for m in 0..D { + let prod = factor * ev[l][m]; + if m >= k { + output[m - k] += prod; + } else { + output[(m + D) - k] += prod * w_inv; + } + } + factor *= phase; + } + phase *= dth_root_inv; + let expected = F::Extension::from_basefield_array(output); + let actual = vars.local_wires[output_start + j * D + k]; + constraints.push(expected - actual); + } + } + } + constraints + } + + // Recursive constraint analogue to `eval_unfiltered`. + fn eval_unfiltered_circuit( + &self, + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + vars: plonky2::plonk::vars::EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(D * G::OUTPUTS_PER_OP); + let mut evals: [Vec>; D] = core::array::from_fn(|_| Vec::new()); + let mut inputs = Vec::with_capacity(G::INPUTS_PER_OP); + let dth_root_inv = F::DTH_ROOT.inverse(); + let d_inv = F::from_canonical_usize(D).inverse(); + let w_inv_t = builder.constant(F::W.inverse()); + let w_t = builder.constant(F::W); + for i in 0..self.max_ops { + let input_start = D * G::WIRES_PER_OP * i; + let output_start = input_start + D * G::INPUTS_PER_OP; + let mut phase = F::ONE; + for ev in evals.iter_mut() { + inputs.clear(); + for j in 0..G::INPUTS_PER_OP { + let var_start = input_start + D * j; + let var: [[Target; D]; D] = + core::array::from_fn(|k| vars.local_wires[var_start + k].0); + let mut input = [builder.zero(); D]; + let mut factor = F::ONE; + for k in 0..D { + for l in 0..D { + let factor_t = builder.constant(factor); + let prod = builder.mul(factor_t, var[k][l]); + if k + l < D { + input[k + l] = builder.add(input[k + l], prod); + } else { + let prod_w = builder.mul(w_t, prod); + input[k + l - D] = builder.add(input[k + l - D], prod_w); + } + } + factor *= phase; + } + inputs.push(ExtensionTarget(input)); + } + *ev = self.apply(builder, &inputs); + phase *= F::DTH_ROOT; + } + for j in 0..G::OUTPUTS_PER_OP { + let mut phase = F::ONE; + for k in 0..D { + let mut output = [builder.zero(); D]; + let ev: [[Target; D]; D] = core::array::from_fn(|l| evals[l][j].0); + let mut factor = d_inv; + for l in 0..D { + for m in 0..D { + let factor_t = builder.constant(factor); + let prod = builder.mul(factor_t, ev[l][m]); + if m >= k { + output[m - k] = builder.add(output[m - k], prod); + } else { + let prod_wi = builder.mul(w_inv_t, prod); + output[(m + D) - k] = builder.add(output[(m + D) - k], prod_wi); + } + } + factor *= phase; + } + phase *= dth_root_inv; + let expected = ExtensionTarget(output); + let actual = vars.local_wires[output_start + j * D + k]; + constraints.push(builder.sub_extension(expected, actual)); + } + } + } + constraints + } +} diff --git a/src/backends/plonky2/primitives/ec/gates/mod.rs b/src/backends/plonky2/primitives/ec/gates/mod.rs new file mode 100644 index 00000000..58418c40 --- /dev/null +++ b/src/backends/plonky2/primitives/ec/gates/mod.rs @@ -0,0 +1,3 @@ +pub mod curve; +pub mod field; +pub mod generic; diff --git a/src/backends/plonky2/primitives/ec/mod.rs b/src/backends/plonky2/primitives/ec/mod.rs new file mode 100644 index 00000000..c873d264 --- /dev/null +++ b/src/backends/plonky2/primitives/ec/mod.rs @@ -0,0 +1,5 @@ +pub mod bits; +pub mod curve; +pub mod field; +pub mod gates; +pub mod schnorr; diff --git a/src/backends/plonky2/primitives/ec/schnorr.rs b/src/backends/plonky2/primitives/ec/schnorr.rs new file mode 100644 index 00000000..1e768027 --- /dev/null +++ b/src/backends/plonky2/primitives/ec/schnorr.rs @@ -0,0 +1,326 @@ +use std::array; + +use num::BigUint; +use num_bigint::RandBigInt; +use plonky2::{ + field::{ + extension::FieldExtension, + goldilocks_field::GoldilocksField, + types::{Field, PrimeField}, + }, + hash::{ + hash_types::HashOutTarget, + hashing::hash_n_to_m_no_pad, + poseidon::{PoseidonHash, PoseidonPermutation}, + }, + iop::{ + target::{BoolTarget, Target}, + witness::WitnessWrite, + }, + plonk::circuit_builder::CircuitBuilder, +}; +use rand::rngs::OsRng; + +use super::curve::Point; +use crate::{ + backends::plonky2::{ + circuits::common::CircuitBuilderPod, + primitives::ec::{ + bits::{BigUInt320Target, CircuitBuilderBits}, + curve::{CircuitBuilderElliptic, PointTarget, WitnessWriteCurve, GROUP_ORDER}, + }, + Error, + }, + middleware::RawValue, +}; + +/// Schnorr signature over ecGFp5. +#[derive(Clone, Debug)] +pub struct Signature { + pub s: BigUint, + pub e: BigUint, +} + +impl Signature { + pub fn verify(&self, public_key: Point, msg: RawValue) -> bool { + let r = &self.s * Point::generator() + &self.e * public_key; + let e = convert_hash_to_biguint(&hash(msg, r)); + e == self.e + } + pub fn as_bytes(&self) -> Vec { + let s_bytes = self + .s + .to_bytes_le() + .into_iter() + .chain(std::iter::repeat(0u8)) + .take(40); + let e_bytes = self + .e + .to_bytes_le() + .into_iter() + .chain(std::iter::repeat(0u8)) + .take(40); + s_bytes.chain(e_bytes).collect() + } + pub fn from_bytes(sig_bytes: &[u8]) -> Result { + if sig_bytes.len() != 80 { + return Err(Error::custom( + "Invalid byte encoding of Schnorr signature.".to_string(), + )); + } + + let s = BigUint::from_bytes_le(&sig_bytes[..40]); + let e = BigUint::from_bytes_le(&sig_bytes[40..]); + + Ok(Self { s, e }) + } +} + +/// Targets for Schnorr signature over ecGFp5. +#[derive(Clone, Debug)] +pub struct SignatureTarget { + pub s: BigUInt320Target, + pub e: BigUInt320Target, +} + +pub trait CircuitBuilderSchnorr { + fn add_virtual_schnorr_signature_target(&mut self) -> SignatureTarget; +} + +impl CircuitBuilderSchnorr for CircuitBuilder { + fn add_virtual_schnorr_signature_target(&mut self) -> SignatureTarget { + SignatureTarget { + s: self.add_virtual_biguint320_target(), + e: self.add_virtual_biguint320_target(), + } + } +} + +pub trait WitnessWriteSchnorr: WitnessWrite + WitnessWriteCurve { + fn set_signature_target( + &mut self, + target: &SignatureTarget, + value: &Signature, + ) -> anyhow::Result<()> { + self.set_biguint320_target(&target.s, &value.s)?; + self.set_biguint320_target(&target.e, &value.e) + } +} + +impl> WitnessWriteSchnorr for W {} + +impl SignatureTarget { + pub fn verify( + &self, + builder: &mut CircuitBuilder, + msg: HashOutTarget, + public_key: &PointTarget, + ) -> BoolTarget { + let g = builder.constant_point(Point::generator()); + let sig1_bits = self.s.bits; + let sig2_bits = self.e.bits; + let r = builder.linear_combination_points(&sig1_bits, &sig2_bits, &g, public_key); + let u_arr = r.u.components; + let inputs = u_arr.into_iter().chain(msg.elements).collect::>(); + let e_hash = hash_array_circuit(builder, &inputs); + let e = builder.field_elements_to_biguint(&e_hash); + builder.is_equal_slice(&self.e.limbs, &e.limbs) + } +} + +pub struct SecretKey(pub BigUint); + +impl SecretKey { + pub fn new_rand() -> Self { + Self(OsRng.gen_biguint_below(&GROUP_ORDER)) + } + pub fn public_key(&self) -> Point { + &self.0 * Point::generator().inverse() + } + + pub fn sign(&self, msg: RawValue, nonce: &BigUint) -> Signature { + let r = nonce * Point::generator(); + let e = convert_hash_to_biguint(&hash(msg, r)); + let s = (nonce + &self.0 * &e) % &*GROUP_ORDER; + Signature { s, e } + } +} + +impl SignatureTarget { + pub fn add_virtual_target(builder: &mut CircuitBuilder) -> Self { + Self { + s: builder.add_virtual_biguint320_target(), + e: builder.add_virtual_biguint320_target(), + } + } +} + +fn hash_array(values: &[GoldilocksField]) -> [GoldilocksField; 5] { + let hash = hash_n_to_m_no_pad::<_, PoseidonPermutation<_>>(values, 5); + std::array::from_fn(|i| hash[i]) +} + +fn hash(msg: RawValue, point: Point) -> [GoldilocksField; 5] { + // The elements of the group have distinct u-coordinates; see the comment in + // CircuitBuilderEllptic::connect_point. So we don't need to hash the + // x-coordinate. + let u_arr: [GoldilocksField; 5] = point.u.to_basefield_array(); + let values: Vec<_> = u_arr.into_iter().chain(msg.0).collect(); + hash_array(&values) +} + +fn convert_hash_to_biguint(hash: &[GoldilocksField; 5]) -> BigUint { + let mut ans = BigUint::ZERO; + for val in hash.iter().rev() { + ans *= GoldilocksField::order(); + ans += val.to_canonical_biguint(); + } + ans +} + +fn hash_array_circuit( + builder: &mut CircuitBuilder, + inputs: &[Target], +) -> [Target; 5] { + let input_vec = inputs.to_owned(); + let hash = builder.hash_n_to_m_no_pad::(input_vec, 5); + array::from_fn(|i| hash[i]) +} + +#[cfg(test)] +mod test { + use num_bigint::RandBigInt; + use plonky2::{ + field::{goldilocks_field::GoldilocksField, types::Sample}, + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::{ + circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, + config::PoseidonGoldilocksConfig, + }, + }; + use rand::rngs::OsRng; + + use crate::{ + backends::plonky2::primitives::ec::{ + bits::CircuitBuilderBits, + curve::{CircuitBuilderElliptic, Point, WitnessWriteCurve, GROUP_ORDER}, + schnorr::{ + convert_hash_to_biguint, hash_array, hash_array_circuit, SecretKey, Signature, + SignatureTarget, + }, + }, + middleware::RawValue, + }; + + fn gen_signed_message() -> (Point, RawValue, Signature) { + let msg = RawValue(GoldilocksField::rand_array()); + let private_key = SecretKey(OsRng.gen_biguint_below(&GROUP_ORDER)); + let nonce = OsRng.gen_biguint_below(&GROUP_ORDER); + let public_key = private_key.public_key(); + let sig = private_key.sign(msg, &nonce); + (public_key, msg, sig) + } + + #[test] + fn test_verify_signature() { + let (public_key, msg, sig) = gen_signed_message(); + assert!(&sig.s < &GROUP_ORDER); + assert!(sig.verify(public_key, msg)); + } + + #[test] + fn test_reject_bogus_signature() { + let msg = RawValue(GoldilocksField::rand_array()); + let private_key = SecretKey(OsRng.gen_biguint_below(&GROUP_ORDER)); + let nonce = OsRng.gen_biguint_below(&GROUP_ORDER); + let public_key = private_key.public_key(); + let sig = private_key.sign(msg, &nonce); + let junk = OsRng.gen_biguint_below(&GROUP_ORDER); + assert!(!Signature { + s: sig.s.clone(), + e: junk.clone() + } + .verify(public_key, msg)); + assert!(!Signature { s: junk, e: sig.e }.verify(public_key, msg)); + } + + #[test] + fn test_verify_signature_circuit() -> Result<(), anyhow::Error> { + let (public_key, msg, sig) = gen_signed_message(); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let key_t = builder.add_virtual_point_target(); + let msg_t = builder.add_virtual_hash(); + let sig_t = SignatureTarget::add_virtual_target(&mut builder); + let verified = sig_t.verify(&mut builder, msg_t, &key_t); + builder.assert_one(verified.target); + let mut pw = PartialWitness::new(); + pw.set_point_target(&key_t, &public_key)?; + pw.set_hash_target(msg_t, msg.0.into())?; + pw.set_biguint320_target(&sig_t.s, &sig.s)?; + pw.set_biguint320_target(&sig_t.e, &sig.e)?; + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + Ok(()) + } + + #[test] + fn test_reject_bogus_signature_circuit() { + let (public_key, msg, sig) = gen_signed_message(); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let key_t = builder.constant_point(public_key); + let msg_t = builder.constant_hash(msg.0.into()); + // sig.s and sig.e are passed out of order + let sig_t = SignatureTarget { + s: builder.constant_biguint320(&sig.e), + e: builder.constant_biguint320(&sig.s), + }; + let verified = sig_t.verify(&mut builder, msg_t, &key_t); + builder.assert_one(verified.target); + let pw = PartialWitness::new(); + let data = builder.build::(); + assert!(data.prove(pw).is_err()); + } + + #[test] + fn test_hash_consistency() -> Result<(), anyhow::Error> { + let values = GoldilocksField::rand_array::<9>(); + let hash = hash_array(&values); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let values_const = values.map(|v| builder.constant(v)); + let hash_const = hash.map(|v| builder.constant(v)); + let hash_circuit = hash_array_circuit(&mut builder, &values_const); + for i in 0..5 { + builder.connect(hash_const[i], hash_circuit[i]); + } + let pw = PartialWitness::new(); + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + Ok(()) + } + + #[test] + fn test_hash_to_bigint_consistency() -> Result<(), anyhow::Error> { + let hash = GoldilocksField::rand_array(); + let hash_int = convert_hash_to_biguint(&hash); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let hash_const: [Target; 5] = std::array::from_fn(|i| builder.constant(hash[i])); + let int_const = builder.constant_biguint320(&hash_int); + let int_circuit = builder.field_elements_to_biguint(&hash_const); + builder.connect_biguint320(&int_const, &int_circuit); + println!("{}", builder.num_gates()); + let pw = PartialWitness::new(); + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + Ok(()) + } +} diff --git a/src/backends/plonky2/primitives/mod.rs b/src/backends/plonky2/primitives/mod.rs index 7dcd478e..224fc375 100644 --- a/src/backends/plonky2/primitives/mod.rs +++ b/src/backends/plonky2/primitives/mod.rs @@ -1,2 +1,3 @@ +pub mod ec; pub mod merkletree; pub mod signature; diff --git a/src/backends/plonky2/primitives/signature/circuit.rs b/src/backends/plonky2/primitives/signature/circuit.rs index 36b75ea9..91aa2651 100644 --- a/src/backends/plonky2/primitives/signature/circuit.rs +++ b/src/backends/plonky2/primitives/signature/circuit.rs @@ -26,41 +26,24 @@ use crate::{ basetypes::{C, D}, circuits::common::{CircuitBuilderPod, ValueTarget}, error::Result, - primitives::signature::{ - PublicKey, SecretKey, Signature, DUMMY_PUBLIC_INPUTS, DUMMY_SIGNATURE, VP, + primitives::ec::{ + curve::{CircuitBuilderElliptic, Point, PointTarget, WitnessWriteCurve}, + schnorr::{CircuitBuilderSchnorr, Signature, SignatureTarget, WitnessWriteSchnorr}, }, }, measure_gates_begin, measure_gates_end, middleware::{Hash, Proof, RawValue, EMPTY_HASH, EMPTY_VALUE, F, VALUE_SIZE}, }; -lazy_static! { - /// SignatureVerifyGadget VerifierCircuitData - pub static ref S_VD: VerifierCircuitData = SignatureVerifyGadget::verifier_data().unwrap(); -} +pub struct SignatureVerifyGadget; -pub struct SignatureVerifyGadget {} pub struct SignatureVerifyTarget { - // verifier_data of the SignatureInternalCircuit - verifier_data_targ: VerifierCircuitTarget, // `enabled` determines if the signature verification is enabled pub(crate) enabled: BoolTarget, - pub(crate) pk: ValueTarget, + pub(crate) pk: PointTarget, pub(crate) msg: ValueTarget, // proof of the SignatureInternalCircuit (=signature::Signature.0) - proof: ProofWithPublicInputsTarget, -} - -impl SignatureVerifyGadget { - pub fn verifier_data() -> Result> { - // notice that we use the 'zk' config - let config = CircuitConfig::standard_recursion_zk_config(); - let mut builder = CircuitBuilder::::new(config); - let circuit = SignatureVerifyGadget {}.eval(&mut builder)?; - - let circuit_data = builder.build::(); - Ok(circuit_data.verifier_data()) - } + pub(crate) sig: SignatureTarget, } impl SignatureVerifyGadget { @@ -68,62 +51,22 @@ impl SignatureVerifyGadget { pub fn eval(&self, builder: &mut CircuitBuilder) -> Result { let measure = measure_gates_begin!(builder, "SignatureVerify"); let enabled = builder.add_virtual_bool_target_safe(); + let pk = builder.add_virtual_point_target(); + let msg = builder.add_virtual_value(); + let sig = builder.add_virtual_schnorr_signature_target(); - let common_data = VP.0.common.clone(); - - // targets related to the 'public inputs' for the verification of the - // `SignatureInternalCircuit` proof. - let pk_targ = builder.add_virtual_value(); - let msg_targ = builder.add_virtual_value(); - let inp: Vec = [pk_targ.elements.to_vec(), msg_targ.elements.to_vec()].concat(); - let s_targ = builder.hash_n_to_hash_no_pad::(inp); - - let verifier_data_targ = - builder.add_virtual_verifier_data(common_data.config.fri_config.cap_height); + let verified = sig.verify(builder, HashOutTarget::from(msg.elements), &pk); - let proof_targ = builder.add_virtual_proof_with_pis(&common_data); + let result = builder.mul_sub(enabled.target, verified.target, enabled.target); - let dummy_pi = DUMMY_PUBLIC_INPUTS.clone(); - - let pk_targ_dummy = - builder.constant_value(RawValue(dummy_pi[..VALUE_SIZE].try_into().unwrap())); - let msg_targ_dummy = builder.constant_value(RawValue( - dummy_pi[VALUE_SIZE..VALUE_SIZE * 2].try_into().unwrap(), - )); - let s_targ_dummy = - builder.constant_value(RawValue(dummy_pi[VALUE_SIZE * 2..].try_into().unwrap())); - - // connect the {pk, msg, s} with the proof_targ.public_inputs conditionally - let pk_targ_connect = builder.select_value(enabled, pk_targ, pk_targ_dummy); - let msg_targ_connect = builder.select_value(enabled, msg_targ, msg_targ_dummy); - let s_targ_connect = builder.select_value( - enabled, - ValueTarget { - elements: s_targ.elements, - }, - s_targ_dummy, - ); - for i in 0..VALUE_SIZE { - builder.connect(pk_targ_connect.elements[i], proof_targ.public_inputs[i]); - builder.connect( - msg_targ_connect.elements[i], - proof_targ.public_inputs[VALUE_SIZE + i], - ); - builder.connect( - s_targ_connect.elements[i], - proof_targ.public_inputs[(2 * VALUE_SIZE) + i], - ); - } - - builder.verify_proof::(&proof_targ, &verifier_data_targ, &common_data); + builder.assert_zero(result); measure_gates_end!(builder, measure); Ok(SignatureVerifyTarget { - verifier_data_targ, enabled, - pk: pk_targ, - msg: msg_targ, - proof: proof_targ, + pk, + msg, + sig, }) } } @@ -134,37 +77,14 @@ impl SignatureVerifyTarget { &self, pw: &mut PartialWitness, enabled: bool, - pk: PublicKey, + pk: Point, msg: RawValue, signature: Signature, ) -> Result<()> { pw.set_bool_target(self.enabled, enabled)?; - pw.set_target_arr(&self.pk.elements, &pk.0 .0)?; + pw.set_point_target(&self.pk, &pk)?; pw.set_target_arr(&self.msg.elements, &msg.0)?; - - // note that this hash is checked again in-circuit at the `SignatureInternalCircuit` - let s = RawValue(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); - let public_inputs: Vec = [pk.0 .0, msg.0, s.0].concat(); - - if enabled { - pw.set_proof_with_pis_target( - &self.proof, - &ProofWithPublicInputs { - proof: signature.0, - public_inputs, - }, - )?; - } else { - pw.set_proof_with_pis_target( - &self.proof, - &ProofWithPublicInputs { - proof: DUMMY_SIGNATURE.0.clone(), - public_inputs: DUMMY_PUBLIC_INPUTS.clone(), - }, - )?; - } - - pw.set_verifier_data_target(&self.verifier_data_targ, &VP.0.verifier_only)?; + pw.set_signature_target(&self.sig, &signature)?; Ok(()) } @@ -172,8 +92,13 @@ impl SignatureVerifyTarget { #[cfg(test)] pub mod tests { + use num_bigint::RandBigInt; + use super::*; - use crate::{backends::plonky2::primitives::signature::SecretKey, middleware::Hash}; + use crate::{ + backends::plonky2::primitives::ec::{curve::GROUP_ORDER, schnorr::SecretKey}, + middleware::Hash, + }; #[test] fn test_signature_gadget_enabled() -> Result<()> { @@ -181,8 +106,9 @@ pub mod tests { let sk = SecretKey::new_rand(); let pk = sk.public_key(); let msg = RawValue::from(42); - let sig = sk.sign(msg)?; - sig.verify(&pk, msg)?; + let nonce = 1337u64.into(); + let sig = sk.sign(msg, &nonce); + assert!(sig.verify(pk, msg), "Should verify"); // circuit let config = CircuitConfig::standard_recursion_zk_config(); @@ -197,12 +123,6 @@ pub mod tests { let proof = data.prove(pw)?; data.verify(proof.clone())?; - // verify the proof with the lazy_static loaded verifier_data (S_VD) - S_VD.verify(ProofWithPublicInputs { - proof: proof.proof.clone(), - public_inputs: vec![], - })?; - Ok(()) } @@ -212,22 +132,24 @@ pub mod tests { let sk = SecretKey::new_rand(); let pk = sk.public_key(); let msg = RawValue::from(42); - let sig = sk.sign(msg)?; + let nonce = 600613u64.into(); + let sig = sk.sign(msg, &nonce); // verification should pass - sig.verify(&pk, msg)?; + let v = sig.verify(pk, msg); + assert!(v, "should verify"); // replace the message, so that verifications should fail let msg = RawValue::from(24); // expect signature native verification to fail - let v = sig.verify(&pk, RawValue::from(24)); - assert!(v.is_err(), "should fail to verify"); + let v = sig.verify(pk, RawValue::from(24)); + assert!(!v, "should fail to verify"); // circuit let config = CircuitConfig::standard_recursion_zk_config(); let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); let targets = SignatureVerifyGadget {}.eval(&mut builder)?; - targets.set_targets(&mut pw, true, pk.clone(), msg, sig.clone())?; // enabled=true + targets.set_targets(&mut pw, true, pk, msg, sig.clone())?; // enabled=true // generate proof, and expect it to fail let data = builder.build::(); @@ -248,12 +170,6 @@ pub mod tests { let proof = data.prove(pw)?; data.verify(proof.clone())?; - // verify the proof with the lazy_static loaded verifier_data (S_VD) - S_VD.verify(ProofWithPublicInputs { - proof: proof.proof.clone(), - public_inputs: vec![], - })?; - Ok(()) } } diff --git a/src/backends/plonky2/primitives/signature/mod.rs b/src/backends/plonky2/primitives/signature/mod.rs index ba1cc83b..2f57e01c 100644 --- a/src/backends/plonky2/primitives/signature/mod.rs +++ b/src/backends/plonky2/primitives/signature/mod.rs @@ -1,246 +1,5 @@ //! Proof-based signatures using Plonky2 proofs, following //! https://eprint.iacr.org/2024/1553 . -use lazy_static::lazy_static; -use plonky2::{ - field::types::Sample, - hash::{ - hash_types::{HashOut, HashOutTarget}, - poseidon::PoseidonHash, - }, - iop::{ - target::Target, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::{ - circuit_builder::CircuitBuilder, - circuit_data::{CircuitConfig, ProverCircuitData, VerifierCircuitData}, - config::Hasher, - proof::ProofWithPublicInputs, - }, -}; pub mod circuit; pub use circuit::*; -use serde::{Deserialize, Serialize}; - -use crate::{ - backends::plonky2::{ - basetypes::{Proof, C, D}, - error::{Error, Result}, - }, - middleware::{RawValue, F, VALUE_SIZE}, -}; - -lazy_static! { - /// Signature prover parameters - pub static ref PP: ProverParams = Signature::prover_params().unwrap(); - /// Signature verifier parameters - pub static ref VP: VerifierParams = Signature::verifier_params().unwrap(); - - /// DUMMY_SIGNATURE is used for conditionals where we want to use a `selector` to enable or - /// disable signature verification. - pub static ref DUMMY_SIGNATURE: Signature = dummy_signature().unwrap(); - /// DUMMY_PUBLIC_INPUTS accompanies the DUMMY_SIGNATURE. - pub static ref DUMMY_PUBLIC_INPUTS: Vec = dummy_public_inputs().unwrap(); -} - -pub struct ProverParams { - prover: ProverCircuitData, - circuit: SignatureInternalCircuit, -} - -#[derive(Clone, Debug)] -pub struct VerifierParams(pub(crate) VerifierCircuitData); - -#[derive(Clone, Debug)] -pub struct SecretKey(pub(crate) RawValue); - -#[derive(Clone, Debug)] -pub struct PublicKey(pub RawValue); - -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(transparent)] -pub struct Signature(pub(crate) Proof); - -/// Implements the key generation and the computation of proof-based signatures. -impl SecretKey { - pub fn new_rand() -> Self { - // note: the `F::rand()` internally uses `rand::rngs::OsRng` - Self(RawValue(std::array::from_fn(|_| F::rand()))) - } - - pub fn public_key(&self) -> PublicKey { - PublicKey(RawValue(PoseidonHash::hash_no_pad(&self.0 .0).elements)) - } - - pub fn sign(&self, msg: RawValue) -> Result { - let pk = self.public_key(); - let s = RawValue(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); - - let mut pw = PartialWitness::::new(); - PP.circuit.set_targets(&mut pw, self.clone(), pk, msg, s)?; - - let proof = PP.prover.prove(pw)?; - - Ok(Signature(proof.proof)) - } -} - -/// Implements the parameters generation and the verification of proof-based -/// signatures. -impl Signature { - pub fn prover_params() -> Result { - let (builder, circuit) = Self::builder()?; - let prover = builder.build_prover::(); - Ok(ProverParams { prover, circuit }) - } - pub fn verifier_params() -> Result { - let (builder, _) = Self::builder()?; - let circuit_data = builder.build::(); - let vp = circuit_data.verifier_data(); - - Ok(VerifierParams(vp)) - } - pub fn params() -> Result<(ProverParams, VerifierParams)> { - let pp = Self::prover_params()?; - let vp = Self::verifier_params()?; - Ok((pp, vp)) - } - - fn builder() -> Result<(CircuitBuilder, SignatureInternalCircuit)> { - // notice that we use the 'zk' config - let config = CircuitConfig::standard_recursion_zk_config(); - - let mut builder = CircuitBuilder::::new(config); - let circuit = SignatureInternalCircuit::add_targets(&mut builder)?; - - Ok((builder, circuit)) - } - - pub fn verify(&self, pk: &PublicKey, msg: RawValue) -> Result<()> { - // prepare public inputs as [pk, msg, s] - let s = RawValue(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); - let public_inputs: Vec = [pk.0 .0, msg.0, s.0].concat(); - - // verify plonky2 proof - VP.0.verify(ProofWithPublicInputs { - proof: self.0.clone(), - public_inputs, - }) - .map_err(Error::plonky2_proof_fail) - } -} - -fn dummy_public_inputs() -> Result> { - let sk = SecretKey(RawValue::from(0)); - let pk = sk.public_key(); - let msg = RawValue::from(0); - let s = RawValue(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); - Ok([pk.0 .0, msg.0, s.0].concat()) -} - -fn dummy_signature() -> Result { - let sk = SecretKey(RawValue::from(0)); - let msg = RawValue::from(0); - sk.sign(msg) -} - -/// The SignatureInternalCircuit implements the circuit used for the proof of -/// the argument described at https://eprint.iacr.org/2024/1553. -/// -/// The circuit proves that for the given public inputs (pk, msg, s), the Prover -/// knows the secret (sk) such that: -/// i) pk == H(sk) -/// ii) s == H(pk, msg) -struct SignatureInternalCircuit { - sk_targ: Vec, - pk_targ: HashOutTarget, - msg_targ: Vec, - s_targ: HashOutTarget, -} - -impl SignatureInternalCircuit { - /// creates the targets and defines the logic of the circuit - fn add_targets(builder: &mut CircuitBuilder) -> Result { - // create the targets - let sk_targ = builder.add_virtual_targets(VALUE_SIZE); - let pk_targ = builder.add_virtual_hash(); - let msg_targ = builder.add_virtual_targets(VALUE_SIZE); - let s_targ = builder.add_virtual_hash(); - - // define the public inputs - builder.register_public_inputs(&pk_targ.elements); - builder.register_public_inputs(&msg_targ); - builder.register_public_inputs(&s_targ.elements); - - // define the logic - let computed_pk_targ = builder.hash_n_to_hash_no_pad::(sk_targ.clone()); - builder.connect_array::(computed_pk_targ.elements, pk_targ.elements); - - let inp: Vec = [pk_targ.elements.to_vec(), msg_targ.clone()].concat(); - let computed_s_targ = builder.hash_n_to_hash_no_pad::(inp); - builder.connect_array::(computed_s_targ.elements, s_targ.elements); - - // return the targets - Ok(Self { - sk_targ, - pk_targ, - msg_targ, - s_targ, - }) - } - - /// assigns the given values to the targets - fn set_targets( - &self, - pw: &mut PartialWitness, - sk: SecretKey, - pk: PublicKey, - msg: RawValue, - s: RawValue, - ) -> Result<()> { - pw.set_target_arr(&self.sk_targ, sk.0 .0.as_ref())?; - pw.set_hash_target(self.pk_targ, HashOut::::from_vec(pk.0 .0.to_vec()))?; - pw.set_target_arr(&self.msg_targ, msg.0.as_ref())?; - pw.set_hash_target(self.s_targ, HashOut::::from_vec(s.0.to_vec()))?; - - Ok(()) - } -} - -#[cfg(test)] -pub mod tests { - use super::*; - use crate::middleware::hash_str; - - #[test] - fn test_signature() -> Result<()> { - let sk = SecretKey::new_rand(); - let pk = sk.public_key(); - - let msg = RawValue::from(42); - let sig = sk.sign(msg)?; - sig.verify(&pk, msg)?; - - // expect the signature verification to fail when using a different msg - let v = sig.verify(&pk, RawValue::from(24)); - assert!(v.is_err(), "should fail to verify"); - - // perform a 2nd signature over another msg and verify it - let msg_2 = RawValue::from(hash_str("message")); - let sig2 = sk.sign(msg_2)?; - sig2.verify(&pk, msg_2)?; - - Ok(()) - } - - #[test] - fn test_dummy_signature() -> Result<()> { - let sk = SecretKey(RawValue::from(0)); - let pk = sk.public_key(); - let msg = RawValue::from(0); - DUMMY_SIGNATURE.clone().verify(&pk, msg)?; - - Ok(()) - } -} diff --git a/src/backends/plonky2/recursion/circuit.rs b/src/backends/plonky2/recursion/circuit.rs index efccda10..33d80972 100644 --- a/src/backends/plonky2/recursion/circuit.rs +++ b/src/backends/plonky2/recursion/circuit.rs @@ -11,7 +11,7 @@ use itertools::Itertools; use plonky2::{ self, - field::types::Field, + field::{extension::quintic::QuinticExtension, types::Field}, gates::noop::NoopGate, hash::hash_types::HashOutTarget, iop::{ @@ -33,6 +33,9 @@ use crate::{ backends::plonky2::{ basetypes::{C, D}, error::Result, + primitives::ec::gates::{ + curve::ECAddHomogOffset, field::NNFMulSimple, generic::GateAdapter, + }, }, middleware::F, timed, @@ -342,6 +345,13 @@ pub fn common_data_for_recursion( GateRef::new(plonky2::gates::random_access::RandomAccessGate::new_from_config(&config, 4)), GateRef::new(plonky2::gates::random_access::RandomAccessGate::new_from_config(&config, 5)), GateRef::new(plonky2::gates::random_access::RandomAccessGate::new_from_config(&config, 6)), + GateRef::new(GateAdapter::>>::new_from_config(&config)), + GateRef::new( + GateAdapter::>>::new_from_config(&config) + .recursive_gate(), + ), + GateRef::new(GateAdapter::::new_from_config(&config)), + GateRef::new(GateAdapter::::new_from_config(&config).recursive_gate()), GateRef::new(plonky2::gates::exponentiation::ExponentiationGate::new_from_config(&config)), // It would be better do `CosetInterpolationGate::with_max_degree(4, 6)` but unfortunately // that plonk2 method is `pub(crate)`, so we need to get around that somehow. @@ -468,7 +478,7 @@ mod tests { let mut aux: F = inp.elements[0]; let two = F::from_canonical_u64(2u64); for _ in 0..5_000 { - aux = aux + two; + aux += two; } HashOut::::from_vec(vec![aux, F::ZERO, F::ZERO, F::ZERO]) } @@ -511,7 +521,7 @@ mod tests { let zero = builder.zero(); let output_targ = HashOutTarget::from_vec(vec![aux, zero, zero, zero]); - builder.register_public_inputs(&output_targ.elements.to_vec()); + builder.register_public_inputs(output_targ.elements.as_ref()); Ok(Self { input: input_targ, @@ -539,13 +549,13 @@ mod tests { ) -> Result { let input_targ = builder.add_virtual_hash(); - let mut output_targ: HashOutTarget = input_targ.clone(); + let mut output_targ: HashOutTarget = input_targ; for _ in 0..100 { output_targ = builder .hash_n_to_hash_no_pad::(output_targ.elements.clone().to_vec()); } - builder.register_public_inputs(&output_targ.elements.to_vec()); + builder.register_public_inputs(output_targ.elements.as_ref()); Ok(Self { input: input_targ, @@ -573,13 +583,13 @@ mod tests { ) -> Result { let input_targ = builder.add_virtual_hash(); - let mut output_targ: HashOutTarget = input_targ.clone(); + let mut output_targ: HashOutTarget = input_targ; for _ in 0..2000 { output_targ = builder .hash_n_to_hash_no_pad::(output_targ.elements.clone().to_vec()); } - builder.register_public_inputs(&output_targ.elements.to_vec()); + builder.register_public_inputs(output_targ.elements.as_ref()); Ok(Self { input: input_targ, @@ -709,7 +719,7 @@ mod tests { start.elapsed() ); - let (dummy_verifier_data, dummy_proof) = dummy(&common_data, num_public_inputs)?; + let (dummy_verifier_data, dummy_proof) = dummy(common_data, num_public_inputs)?; let circuit1 = RC::::build(¶ms_1, &())?; let circuit2 = RC::::build(¶ms_2, &())?; diff --git a/src/backends/plonky2/signedpod.rs b/src/backends/plonky2/signedpod.rs index 0254c2a3..a4928964 100644 --- a/src/backends/plonky2/signedpod.rs +++ b/src/backends/plonky2/signedpod.rs @@ -2,14 +2,18 @@ use std::collections::HashMap; use base64::{prelude::BASE64_STANDARD, Engine}; use itertools::Itertools; -use plonky2::util::serialization::Buffer; +use num_bigint::RandBigInt; +use rand::rngs::OsRng; use crate::{ backends::plonky2::{ error::{Error, Result}, primitives::{ + ec::{ + curve::{Point, GROUP_ORDER}, + schnorr::{SecretKey, Signature}, + }, merkletree::MerkleTree, - signature::{PublicKey, SecretKey, Signature, VP}, }, }, constants::MAX_DEPTH, @@ -25,21 +29,23 @@ impl Signer { fn _sign(&mut self, _params: &Params, kvs: &HashMap) -> Result { let mut kvs = kvs.clone(); let pubkey = self.0.public_key(); - kvs.insert(Key::from(KEY_SIGNER), Value::from(pubkey.0)); + kvs.insert(Key::from(KEY_SIGNER), Value::from(pubkey)); kvs.insert(Key::from(KEY_TYPE), Value::from(PodType::Signed)); let dict = Dictionary::new(kvs)?; let id = RawValue::from(dict.commitment()); // PodId as Value - let signature: Signature = self.0.sign(id)?; + let nonce = OsRng.gen_biguint_below(&GROUP_ORDER); + let signature: Signature = self.0.sign(id, &nonce); Ok(SignedPod { id: PodId(Hash::from(id)), signature, + signer: pubkey, dict, }) } - pub fn public_key(&self) -> PublicKey { + pub fn public_key(&self) -> Point { self.0.public_key() } } @@ -58,6 +64,7 @@ impl PodSigner for Signer { pub struct SignedPod { pub id: PodId, pub signature: Signature, + pub signer: Point, pub dict: Dictionary, } @@ -88,34 +95,35 @@ impl SignedPod { } // 3. Verify signature - let pk_value = self.dict.get(&Key::from(KEY_SIGNER))?; - let pk = PublicKey(pk_value.raw()); - self.signature.verify(&pk, RawValue::from(id.0))?; - - Ok(()) + let embedded_pk_value = self.dict.get(&Key::from(KEY_SIGNER))?; + let pk = self.signer; + let pk_value = Value::from(pk); + if &pk_value != embedded_pk_value { + return Err(Error::signer_not_equal(embedded_pk_value.clone(), pk_value)); + } + self.signature + .verify(pk, RawValue::from(id.0)) + .then_some(()) + .ok_or(Error::custom("Invalid signature!".into())) } - pub fn decode_signature(signature: &str) -> Result { - use plonky2::util::serialization::Read; - - let decoded = BASE64_STANDARD.decode(signature).map_err(|e| { + pub fn decode_proof(signature: &str) -> Result<(Point, Signature), Error> { + let proof_bytes = BASE64_STANDARD.decode(signature).map_err(|e| { Error::custom(format!( - "Failed to decode signature from base64: {}. Value: {}", + "Failed to decode proof from base64: {}. Value: {}", e, signature )) })?; - let mut buf = Buffer::new(&decoded); - let proof = buf.read_proof(&VP.0.common).map_err(|e| { - Error::custom(format!( - "Failed to read signature from buffer: {}. Value: {}", - e, signature - )) - })?; - - let sig = Signature(proof); + if proof_bytes.len() != 160 { + return Err(Error::custom( + "Invalid byte encoding of signed POD proof.".to_string(), + )); + } - Ok(sig) + let signer = Point::from_bytes(&proof_bytes[..80])?; + let signature = Signature::from_bytes(&proof_bytes[80..])?; + Ok((signer, signature)) } } @@ -146,10 +154,9 @@ impl Pod for SignedPod { } fn serialized_proof(&self) -> String { - let mut buffer = Vec::new(); - use plonky2::util::serialization::Write; - buffer.write_proof(&self.signature.0).unwrap(); - BASE64_STANDARD.encode(buffer) + // Serialise signer + signature. + let proof_bytes = [self.signer.as_bytes(), self.signature.as_bytes()].concat(); + BASE64_STANDARD.encode(&proof_bytes) } } @@ -173,8 +180,7 @@ pub mod tests { pod.insert("dateOfBirth", 1169909384); pod.insert("socialSecurityNumber", "G2121210"); - // TODO: Use a deterministic secret key to get deterministic tests - let sk = SecretKey::new_rand(); + let sk = SecretKey(123u64.into()); let mut signer = Signer(sk); let pod = pod.sign(&mut signer).unwrap(); let pod = (pod.pod as Box).downcast::().unwrap(); @@ -184,7 +190,8 @@ pub mod tests { println!("kvs: {:?}", pod.kvs()); let mut bad_pod = pod.clone(); - bad_pod.signature = signer.0.sign(RawValue::from(42_i64))?; + let nonce = 456u64.into(); + bad_pod.signature = signer.0.sign(RawValue::from(42_i64), &nonce); assert!(bad_pod.verify().is_err()); let mut bad_pod = pod.clone(); diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 62ad475a..e7943fec 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -844,13 +844,9 @@ pub mod tests { // Check that frontend public statements agree with those // embedded in a MainPod. fn check_public_statements(pod: &MainPod) -> Result<()> { - Ok( - std::iter::zip(pod.public_statements.clone(), pod.pod.pub_statements()).try_for_each( - |(fes, s)| { - crate::middleware::Statement::try_from(fes).map(|fes| assert_eq!(fes, s)) - }, - )?, - ) + std::iter::zip(pod.public_statements.clone(), pod.pod.pub_statements()) + .for_each(|(fes, s)| assert_eq!(fes, s)); + Ok(()) } // Check that frontend key-values agree with those embedded in a diff --git a/src/frontend/serialization.rs b/src/frontend/serialization.rs index 29abbe93..381ccf52 100644 --- a/src/frontend/serialization.rs +++ b/src/frontend/serialization.rs @@ -78,14 +78,19 @@ impl From for SerializedSignedPod { impl From for SignedPod { fn from(serialized: SerializedSignedPod) -> Self { match serialized.pod_type { - SignedPodType::Signed => SignedPod { - pod: Box::new(Plonky2SignedPod { - id: serialized.id, - signature: Plonky2SignedPod::decode_signature(&serialized.proof).unwrap(), - dict: Dictionary::new(serialized.entries.clone()).unwrap(), - }), - kvs: serialized.entries, - }, + SignedPodType::Signed => { + let (signer, signature) = + Plonky2SignedPod::decode_proof(&serialized.proof).unwrap(); + SignedPod { + pod: Box::new(Plonky2SignedPod { + id: serialized.id, + signer, + signature, + dict: Dictionary::new(serialized.entries.clone()).unwrap(), + }), + kvs: serialized.entries, + } + } SignedPodType::MockSigned => SignedPod { pod: Box::new(MockSignedPod::new( serialized.id, @@ -210,7 +215,7 @@ mod tests { backends::plonky2::{ mainpod::Prover, mock::{mainpod::MockProver, signedpod::MockSigner}, - primitives::signature::SecretKey, + primitives::ec::schnorr::SecretKey, signedpod::Signer, }, examples::{ @@ -221,7 +226,7 @@ mod tests { middleware::{ self, containers::{Array, Set}, - Params, RawValue, TypedValue, + Params, TypedValue, }, }; @@ -309,7 +314,7 @@ mod tests { #[test] fn test_signed_pod_serialization() { let builder = signed_pod_builder(); - let mut signer = Signer(SecretKey(RawValue::from(1))); + let mut signer = Signer(SecretKey(1u32.into())); let pod = builder.sign(&mut signer).unwrap(); let serialized = serde_json::to_string_pretty(&pod).unwrap(); @@ -377,11 +382,11 @@ mod tests { let (gov_id_builder, pay_stub_builder, sanction_list_builder) = zu_kyc_sign_pod_builders(¶ms); - let mut signer = Signer(SecretKey(RawValue::from(1))); + let mut signer = Signer(SecretKey(1u32.into())); let gov_id_pod = gov_id_builder.sign(&mut signer)?; - let mut signer = Signer(SecretKey(RawValue::from(2))); + let mut signer = Signer(SecretKey(2u32.into())); let pay_stub_pod = pay_stub_builder.sign(&mut signer)?; - let mut signer = Signer(SecretKey(RawValue::from(3))); + let mut signer = Signer(SecretKey(3u32.into())); let sanction_list_pod = sanction_list_builder.sign(&mut signer)?; let kyc_builder = zu_kyc_pod_builder(¶ms, &gov_id_pod, &pay_stub_pod, &sanction_list_pod)?; diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index b6b9239b..4308579d 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -27,7 +27,9 @@ pub use operation::*; use serialization::*; pub use statement::*; -use crate::backends::plonky2::primitives::merkletree::MerkleProof; +use crate::backends::plonky2::primitives::{ + ec::curve::Point as PublicKey, merkletree::MerkleProof, +}; pub const SELF: PodId = PodId(SELF_ID_HASH); @@ -56,6 +58,8 @@ pub enum TypedValue { ), // Uses the serialization for middleware::Value: Raw(RawValue), + // Public key variant + PublicKey(PublicKey), // UNTAGGED TYPES: #[serde(untagged)] Array(Array), @@ -95,6 +99,12 @@ impl From for TypedValue { } } +impl From for TypedValue { + fn from(p: PublicKey) -> Self { + TypedValue::PublicKey(p) + } +} + impl From for TypedValue { fn from(s: Set) -> Self { TypedValue::Set(s) @@ -159,6 +169,7 @@ impl fmt::Display for TypedValue { TypedValue::Set(s) => write!(f, "set:{}", s.commitment()), TypedValue::Array(a) => write!(f, "arr:{}", a.commitment()), TypedValue::Raw(v) => write!(f, "{}", v), + TypedValue::PublicKey(p) => write!(f, "ecGFp5_pt:({},{})", p.x, p.u), } } } @@ -173,6 +184,7 @@ impl From<&TypedValue> for RawValue { TypedValue::Set(s) => RawValue::from(s.commitment()), TypedValue::Array(a) => RawValue::from(a.commitment()), TypedValue::Raw(v) => *v, + TypedValue::PublicKey(p) => RawValue::from(hash_fields(&p.as_fields())), } } }