diff --git a/halo2_backend/src/poly/kzg/multiopen/gwc.rs b/halo2_backend/src/poly/kzg/multiopen/gwc.rs index 8c8e056e83..fd3e5ef1f9 100644 --- a/halo2_backend/src/poly/kzg/multiopen/gwc.rs +++ b/halo2_backend/src/poly/kzg/multiopen/gwc.rs @@ -22,10 +22,28 @@ struct CommitmentData> { _marker: PhantomData, } -fn construct_intermediate_sets>(queries: I) -> Vec> +fn construct_intermediate_sets>( + queries: I, +) -> Option>> where I: IntoIterator + Clone, { + let queries = queries.into_iter().collect::>(); + + // Caller tried to provide two different evaluations for the same + // commitment. Permitting this would be unsound. + { + let mut query_set: Vec<(Q::Commitment, F)> = vec![]; + for query in queries.iter() { + let commitment = query.get_commitment(); + let rotation = query.get_point(); + if query_set.contains(&(commitment, rotation)) { + return None; + } + query_set.push((commitment, rotation)); + } + } + let mut point_query_map: Vec<(F, Vec)> = Vec::new(); for query in queries { if let Some(pos) = point_query_map @@ -39,12 +57,14 @@ where } } - point_query_map - .into_iter() - .map(|(point, queries)| CommitmentData { - queries, - point, - _marker: PhantomData, - }) - .collect() + Some( + point_query_map + .into_iter() + .map(|(point, queries)| CommitmentData { + queries, + point, + _marker: PhantomData, + }) + .collect(), + ) } diff --git a/halo2_backend/src/poly/kzg/multiopen/gwc/prover.rs b/halo2_backend/src/poly/kzg/multiopen/gwc/prover.rs index add80bbfeb..877b4afb35 100644 --- a/halo2_backend/src/poly/kzg/multiopen/gwc/prover.rs +++ b/halo2_backend/src/poly/kzg/multiopen/gwc/prover.rs @@ -53,7 +53,12 @@ where R: RngCore, { let v: ChallengeV<_> = transcript.squeeze_challenge_scalar(); - let commitment_data = construct_intermediate_sets(queries); + let commitment_data = construct_intermediate_sets(queries).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "queries iterator contains mismatching evaluations", + ) + })?; for commitment_at_a_point in commitment_data.iter() { let z = commitment_at_a_point.point; diff --git a/halo2_backend/src/poly/kzg/multiopen/gwc/verifier.rs b/halo2_backend/src/poly/kzg/multiopen/gwc/verifier.rs index 196f330819..f9b3da18bc 100644 --- a/halo2_backend/src/poly/kzg/multiopen/gwc/verifier.rs +++ b/halo2_backend/src/poly/kzg/multiopen/gwc/verifier.rs @@ -57,7 +57,7 @@ where { let v: ChallengeV<_> = transcript.squeeze_challenge_scalar(); - let commitment_data = construct_intermediate_sets(queries); + let commitment_data = construct_intermediate_sets(queries).ok_or(Error::OpeningError)?; let w: Vec = (0..commitment_data.len()) .map(|_| transcript.read_point().map_err(|_| Error::SamplingError)) diff --git a/halo2_backend/src/poly/kzg/multiopen/shplonk.rs b/halo2_backend/src/poly/kzg/multiopen/shplonk.rs index 5f963f4049..0a6183c784 100644 --- a/halo2_backend/src/poly/kzg/multiopen/shplonk.rs +++ b/halo2_backend/src/poly/kzg/multiopen/shplonk.rs @@ -47,12 +47,26 @@ struct IntermediateSets> { fn construct_intermediate_sets>( queries: I, -) -> IntermediateSets +) -> Option> where I: IntoIterator + Clone, { let queries = queries.into_iter().collect::>(); + // Caller tried to provide two different evaluations for the same + // commitment. Permitting this would be unsound. + { + let mut query_set: Vec<(Q::Commitment, F)> = vec![]; + for query in queries.iter() { + let commitment = query.get_commitment(); + let rotation = query.get_point(); + if query_set.contains(&(commitment, rotation)) { + return None; + } + query_set.push((commitment, rotation)); + } + } + // Find evaluation of a commitment at a rotation let get_eval = |commitment: Q::Commitment, rotation: F| -> F { queries @@ -133,10 +147,10 @@ where }) .collect::>>(); - IntermediateSets { + Some(IntermediateSets { rotation_sets, super_point_set, - } + }) } #[cfg(test)] @@ -144,7 +158,11 @@ mod proptests { use super::{construct_intermediate_sets, Commitment, IntermediateSets}; use halo2_middleware::ff::FromUniformBytes; use halo2curves::pasta::Fp; - use proptest::{collection::vec, prelude::*, sample::select}; + use proptest::{ + collection::{hash_set, vec}, + prelude::*, + sample::select, + }; use std::convert::TryFrom; #[derive(Debug, Clone)] @@ -194,10 +212,16 @@ mod proptests { prop_compose! { // Mapping from column index to point index. fn arb_queries_inner(num_points: usize, num_cols: usize, num_queries: usize)( - col_indices in vec(select((0..num_cols).collect::>()), num_queries), - point_indices in vec(select((0..num_points).collect::>()), num_queries) + // Use a HashSet to ensure we sample distinct (column, point) queries. + queries in hash_set( + ( + select((0..num_cols).collect::>()), + select((0..num_points).collect::>()), + ), + num_queries, + ) ) -> Vec<(usize, usize)> { - col_indices.into_iter().zip(point_indices.into_iter()).collect() + queries.into_iter().collect() } } @@ -229,14 +253,14 @@ mod proptests { fn test_intermediate_sets( (queries_1, queries_2) in compare_queries(8, 8, 16) ) { - let IntermediateSets { rotation_sets, .. } = construct_intermediate_sets(queries_1); + let IntermediateSets { rotation_sets, .. } = construct_intermediate_sets(queries_1).ok_or_else(|| TestCaseError::Fail("mismatched evals".into()))?; let commitment_sets = rotation_sets.iter().map(|data| data.commitments.iter().map(Commitment::get).collect::>() ).collect::>(); // It shouldn't matter what the point or eval values are; we should get // the same exact point set indices and point indices again. - let IntermediateSets { rotation_sets: new_rotation_sets, .. } = construct_intermediate_sets(queries_2); + let IntermediateSets { rotation_sets: new_rotation_sets, .. } = construct_intermediate_sets(queries_2).ok_or_else(|| TestCaseError::Fail("mismatched evals".into()))?; let new_commitment_sets = new_rotation_sets.iter().map(|data| data.commitments.iter().map(Commitment::get).collect::>() ).collect::>(); diff --git a/halo2_backend/src/poly/kzg/multiopen/shplonk/prover.rs b/halo2_backend/src/poly/kzg/multiopen/shplonk/prover.rs index 3bdfc68a5d..0964646bbe 100644 --- a/halo2_backend/src/poly/kzg/multiopen/shplonk/prover.rs +++ b/halo2_backend/src/poly/kzg/multiopen/shplonk/prover.rs @@ -173,7 +173,12 @@ where } }; - let intermediate_sets = construct_intermediate_sets(queries); + let intermediate_sets = construct_intermediate_sets(queries).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "queries iterator contains mismatching evaluations", + ) + })?; let (rotation_sets, super_point_set) = ( intermediate_sets.rotation_sets, intermediate_sets.super_point_set, diff --git a/halo2_backend/src/poly/kzg/multiopen/shplonk/verifier.rs b/halo2_backend/src/poly/kzg/multiopen/shplonk/verifier.rs index 4fad8dece3..18a70336b6 100644 --- a/halo2_backend/src/poly/kzg/multiopen/shplonk/verifier.rs +++ b/halo2_backend/src/poly/kzg/multiopen/shplonk/verifier.rs @@ -60,7 +60,7 @@ where where I: IntoIterator>> + Clone, { - let intermediate_sets = construct_intermediate_sets(queries); + let intermediate_sets = construct_intermediate_sets(queries).ok_or(Error::OpeningError)?; let (rotation_sets, super_point_set) = ( intermediate_sets.rotation_sets, intermediate_sets.super_point_set, diff --git a/halo2_backend/src/poly/multiopen_test.rs b/halo2_backend/src/poly/multiopen_test.rs index a2246ac6f9..bcfccf34e0 100644 --- a/halo2_backend/src/poly/multiopen_test.rs +++ b/halo2_backend/src/poly/multiopen_test.rs @@ -90,6 +90,64 @@ mod test { >(&verifier_params, &proof[..], true); } + #[test] + fn test_identical_queries_gwc() { + use crate::poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG}; + use crate::poly::kzg::multiopen::{ProverGWC, VerifierGWC}; + use crate::poly::kzg::strategy::AccumulatorStrategy; + use halo2curves::bn256::Bn256; + + const K: u32 = 4; + + let engine = H2cEngine::new(); + let params = ParamsKZG::::new(K); + + let proof = create_proof::< + KZGCommitmentScheme, + ProverGWC<_>, + _, + Blake2bWrite<_, _, Challenge255<_>>, + >(&engine, ¶ms); + + let verifier_params = params.verifier_params(); + verify_identical_queries::< + KZGCommitmentScheme, + VerifierGWC<_>, + _, + Blake2bRead<_, _, Challenge255<_>>, + AccumulatorStrategy<_>, + >(&verifier_params, &proof[..]); + } + + #[test] + fn test_identical_queries_shplonk() { + use crate::poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG}; + use crate::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK}; + use crate::poly::kzg::strategy::AccumulatorStrategy; + use halo2curves::bn256::Bn256; + + const K: u32 = 4; + + let engine = H2cEngine::new(); + let params = ParamsKZG::::new(K); + + let proof = create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<_>, + _, + Blake2bWrite<_, _, Challenge255<_>>, + >(&engine, ¶ms); + + let verifier_params = params.verifier_params(); + verify_identical_queries::< + KZGCommitmentScheme, + VerifierSHPLONK<_>, + _, + Blake2bRead<_, _, Challenge255<_>>, + AccumulatorStrategy<_>, + >(&verifier_params, &proof[..]); + } + fn verify< 'a, 'params, @@ -223,4 +281,54 @@ mod test { transcript.finalize() } + + fn verify_identical_queries< + 'a, + 'params, + Scheme: CommitmentScheme, + V: Verifier<'params, Scheme>, + E: EncodedChallenge, + T: TranscriptReadBuffer<&'a [u8], Scheme::Curve, E>, + Strategy: VerificationStrategy<'params, Scheme, V> + std::fmt::Debug, + >( + params: &'params Scheme::ParamsVerifier, + proof: &'a [u8], + ) { + use assert_matches::assert_matches; + use group::ff::Field; + + let verifier = V::new(); + + let mut transcript = T::init(proof); + + let a = transcript.read_point().unwrap(); + let b = transcript.read_point().unwrap(); + let c = transcript.read_point().unwrap(); + + let x = transcript.squeeze_challenge(); + let y = transcript.squeeze_challenge(); + + let avx = transcript.read_scalar().unwrap(); + let bvx = transcript.read_scalar().unwrap(); + let cvy = transcript.read_scalar().unwrap(); + + let bvx_bad = ::Scalar::random(OsRng); + + #[rustfmt::skip] + let invalid_queries = std::iter::empty() + .chain(Some(VerifierQuery::new_commitment(&a, x.get_scalar(), avx))) + .chain(Some(VerifierQuery::new_commitment(&b, x.get_scalar(), bvx))) + .chain(Some(VerifierQuery::new_commitment(&b, x.get_scalar(), bvx_bad))) // This is wrong. + .chain(Some(VerifierQuery::new_commitment(&c, y.get_scalar(), cvy))); + + let strategy = Strategy::new(params); + assert_matches!( + strategy.process(|msm_accumulator| { + verifier + .verify_proof(&mut transcript, invalid_queries.clone(), msm_accumulator) + .map_err(|_| Error::Opening) + }), + Err(Error::Opening) + ); + } }