Skip to content

Commit 9b08388

Browse files
committed
chore: wrap ProverState and make it non-clonable.
1 parent 11162ca commit 9b08388

File tree

6 files changed

+42
-52
lines changed

6 files changed

+42
-52
lines changed

src/composition.rs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ pub enum ComposedCommitment<G: PrimeGroup> {
7676
}
7777

7878
// Structure representing the ProverState type of Protocol as SigmaProtocol
79-
#[derive(Clone)]
8079
pub enum ComposedProverState<G: PrimeGroup> {
8180
Simple(<SchnorrProof<G> as SigmaProtocol>::ProverState),
8281
And(Vec<ComposedProverState<G>>),
@@ -175,7 +174,7 @@ impl<G: PrimeGroup> SigmaProtocol for ComposedRelation<G> {
175174
fn prover_response(
176175
&self,
177176
state: Self::ProverState,
178-
challenge: &Self::Challenge,
177+
mut challenge: &Self::Challenge,
179178
) -> Result<Self::Response, Error> {
180179
match (self, state) {
181180
(ComposedRelation::Simple(p), ComposedProverState::Simple(state)) => p
@@ -201,17 +200,12 @@ impl<G: PrimeGroup> SigmaProtocol for ComposedRelation<G> {
201200
(simulated_challenges, simulated_responses),
202201
),
203202
) => {
204-
let mut challenges = Vec::with_capacity(ps.len());
205-
let mut responses = Vec::with_capacity(ps.len());
206-
207-
let mut real_challenge = *challenge;
208-
for ch in &simulated_challenges {
209-
real_challenge -= ch;
210-
}
211-
let real_response =
212-
ps[w_index].prover_response(real_state[0].clone(), &real_challenge)?;
203+
let n = ps.len();
204+
let mut challenges = Vec::with_capacity(n);
205+
let mut responses = Vec::with_capacity(n);
213206

214207
for (i, _) in ps.iter().enumerate() {
208+
ps
215209
if i == w_index {
216210
challenges.push(real_challenge);
217211
responses.push(real_response.clone());

src/linear_relation/canonical.rs

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1-
21
use std::collections::HashMap;
32
use std::iter;
43
use std::marker::PhantomData;
54

65
use ff::Field;
76
use group::prime::PrimeGroup;
87

8+
use super::{GroupMap, GroupVar, LinearCombination, LinearRelation, ScalarTerm, ScalarVar};
99
use crate::errors::{Error, InvalidInstance};
10-
use super::{ScalarVar, GroupVar, GroupMap, LinearRelation, LinearCombination, ScalarTerm};
11-
1210

1311
/// A normalized form of the [`LinearRelation`], which is used for serialization into the transcript.
1412
///
@@ -28,6 +26,10 @@ pub struct CanonicalLinearRelation<G: PrimeGroup> {
2826
pub num_scalars: usize,
2927
}
3028

29+
30+
type GroupExpr<G> = Vec<(<G as group::Group>::Scalar, GroupVar<G>)>;
31+
32+
3133
impl<G: PrimeGroup> CanonicalLinearRelation<G> {
3234
/// Create a new empty canonical linear relation
3335
pub fn new() -> Self {
@@ -45,7 +47,7 @@ impl<G: PrimeGroup> CanonicalLinearRelation<G> {
4547
group_var: GroupVar<G>,
4648
weight: &G::Scalar,
4749
original_group_elements: &GroupMap<G>,
48-
weighted_group_cache: &mut HashMap<GroupVar<G>, Vec<(G::Scalar, GroupVar<G>)>>,
50+
weighted_group_cache: &mut HashMap<GroupVar<G>, GroupExpr<G>>,
4951
) -> Result<GroupVar<G>, InvalidInstance> {
5052
// Check if we already have this (weight, group_var) combination
5153
let entry = weighted_group_cache.entry(group_var).or_default();
@@ -74,7 +76,7 @@ impl<G: PrimeGroup> CanonicalLinearRelation<G> {
7476
&image_var: &GroupVar<G>,
7577
equation: &LinearCombination<G>,
7678
original_relation: &LinearRelation<G>,
77-
weighted_group_cache: &mut HashMap<GroupVar<G>, Vec<(G::Scalar, GroupVar<G>)>>,
79+
weighted_group_cache: &mut HashMap<GroupVar<G>, GroupExpr<G>>,
7880
) -> Result<(), InvalidInstance> {
7981
let mut rhs_terms = Vec::new();
8082

@@ -302,10 +304,7 @@ impl<G: PrimeGroup> CanonicalLinearRelation<G> {
302304
repr.as_mut().copy_from_slice(elem_bytes);
303305

304306
let elem = Option::<G>::from(G::from_bytes(&repr)).ok_or_else(|| {
305-
Error::from(InvalidInstance::new(format!(
306-
"Invalid group element at index {}",
307-
i
308-
)))
307+
Error::from(InvalidInstance::new(format!("Invalid group element at index {i}")))
309308
})?;
310309

311310
group_elements_ordered.push(elem);
@@ -367,11 +366,13 @@ impl<G: PrimeGroup> TryFrom<&LinearRelation<G>> for CanonicalLinearRelation<G> {
367366
let mut weighted_group_cache = HashMap::new();
368367

369368
// Process each constraint using the modular helper method
370-
for (lhs, rhs) in
371-
iter::zip(&relation.image, &relation.linear_map.linear_combinations)
372-
{
369+
for (lhs, rhs) in iter::zip(&relation.image, &relation.linear_map.linear_combinations) {
373370
// If the linear combination is trivial, check it directly and skip processing.
374-
if rhs.0.iter().all(|weighted| matches!(weighted.term.scalar, ScalarTerm::Unit)) {
371+
if rhs
372+
.0
373+
.iter()
374+
.all(|weighted| matches!(weighted.term.scalar, ScalarTerm::Unit))
375+
{
375376
let lhs_value = relation
376377
.linear_map
377378
.group_elements
@@ -383,22 +384,21 @@ impl<G: PrimeGroup> TryFrom<&LinearRelation<G>> for CanonicalLinearRelation<G> {
383384
.linear_map
384385
.group_elements
385386
.get(weighted.term.elem)
386-
.unwrap_or_else(|_| panic!("Unassigned group variable in linear combination"))
387+
.unwrap_or_else(|_| {
388+
panic!("Unassigned group variable in linear combination")
389+
})
387390
* weighted.weight
388391
});
389392
if lhs_value != rhs_value {
390-
return Err(InvalidInstance::new("Trivial linear combination does not match image"));
393+
return Err(InvalidInstance::new(
394+
"Trivial linear combination does not match image",
395+
));
391396
} else {
392397
continue; // Skip processing trivial constraints
393398
}
394399
}
395400

396-
canonical.process_constraint(
397-
lhs,
398-
rhs,
399-
relation,
400-
&mut weighted_group_cache,
401-
)?;
401+
canonical.process_constraint(lhs, rhs, relation, &mut weighted_group_cache)?;
402402
}
403403

404404
Ok(canonical)

src/linear_relation/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
//! - [`LinearMap`]: a collection of linear combinations acting on group elements.
99
//! - [`LinearRelation`]: a higher-level structure managing linear maps and their associated images.
1010
11-
use std::marker::PhantomData;
1211
use std::iter;
12+
use std::marker::PhantomData;
1313

1414
use ff::Field;
1515
use group::prime::PrimeGroup;
@@ -24,7 +24,6 @@ mod convert;
2424
/// Implementations of core ops for the linear combination types.
2525
mod ops;
2626

27-
2827
/// Implementation of canonical linear relation.
2928
mod canonical;
3029
pub use canonical::CanonicalLinearRelation;

src/schnorr_protocol.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ use rand::{CryptoRng, Rng, RngCore};
2727
#[derive(Clone, Default, Debug)]
2828
pub struct SchnorrProof<G: PrimeGroup>(pub CanonicalLinearRelation<G>);
2929

30+
struct ProverState<G: PrimeGroup>(Vec<G::Scalar>, Vec<G::Scalar>);
31+
3032
impl<G: PrimeGroup> SchnorrProof<G> {
3133
pub fn witness_length(&self) -> usize {
3234
self.0.num_scalars
@@ -58,7 +60,7 @@ impl<G: PrimeGroup> SchnorrProof<G> {
5860
&self,
5961
witness: &[G::Scalar],
6062
nonces: &[G::Scalar],
61-
) -> Result<(Vec<G>, (Vec<G::Scalar>, Vec<G::Scalar>)), Error> {
63+
) -> Result<(Vec<G>, ProverState<G>), Error> {
6264
if witness.len() != self.witness_length() {
6365
return Err(Error::InvalidInstanceWitnessPair);
6466
}
@@ -79,7 +81,7 @@ impl<G: PrimeGroup> SchnorrProof<G> {
7981
}
8082

8183
let commitment = self.evaluate(nonces)?;
82-
let prover_state = (nonces.to_vec(), witness.to_vec());
84+
let prover_state = ProverState(nonces.to_vec(), witness.to_vec());
8385
Ok((commitment, prover_state))
8486
}
8587
}
@@ -98,7 +100,7 @@ where
98100
G: PrimeGroup,
99101
{
100102
type Commitment = Vec<G>;
101-
type ProverState = (Vec<G::Scalar>, Vec<G::Scalar>);
103+
type ProverState = ProverState<G>;
102104
type Response = Vec<G::Scalar>;
103105
type Witness = Vec<G::Scalar>;
104106
type Challenge = G::Scalar;
@@ -159,7 +161,7 @@ where
159161
prover_state: Self::ProverState,
160162
challenge: &Self::Challenge,
161163
) -> Result<Self::Response, Error> {
162-
let (nonces, witness) = prover_state;
164+
let ProverState(nonces, witness) = prover_state;
163165

164166
if nonces.len() != self.witness_length() || witness.len() != self.witness_length() {
165167
return Err(Error::InvalidInstanceWitnessPair);

src/tests/spec/custom_schnorr_protocol.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,12 @@ impl<G: PrimeGroup> From<CanonicalLinearRelation<G>> for DeterministicSchnorrPro
2424
}
2525
}
2626

27-
impl<G: PrimeGroup> DeterministicSchnorrProof<G> {}
28-
2927
impl<G: SRandom + PrimeGroup> SigmaProtocol for DeterministicSchnorrProof<G> {
30-
type Commitment = Vec<G>;
31-
type ProverState = (Vec<G::Scalar>, Vec<G::Scalar>);
32-
type Response = Vec<G::Scalar>;
33-
type Witness = Vec<G::Scalar>;
34-
type Challenge = G::Scalar;
28+
type Commitment = <SchnorrProof<G> as SigmaProtocol>::Commitment;
29+
type ProverState = <SchnorrProof<G> as SigmaProtocol>::ProverState;
30+
type Response = <SchnorrProof<G> as SigmaProtocol>::Response;
31+
type Witness = <SchnorrProof<G> as SigmaProtocol>::Witness;
32+
type Challenge = <SchnorrProof<G> as SigmaProtocol>::Challenge;
3533

3634
fn prover_commit(
3735
&self,

src/tests/test_relations.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use ff::{Field, PrimeField};
1+
use ff::Field;
22
use group::prime::PrimeGroup;
33
use rand::rngs::OsRng;
44
use rand::RngCore;
@@ -290,12 +290,12 @@ fn simple_subtractions<G: PrimeGroup, R: RngCore>(
290290
) -> (CanonicalLinearRelation<G>, Vec<G::Scalar>) {
291291
let x = G::Scalar::random(&mut rng);
292292
let B = G::random(&mut rng);
293-
let X = B * (x - G::Scalar::from_u128(1u128));
293+
let X = B * (x - G::Scalar::from(1));
294294

295295
let mut linear_relation = LinearRelation::<G>::new();
296296
let var_x = linear_relation.allocate_scalar();
297297
let var_B = linear_relation.allocate_element();
298-
let var_X = linear_relation.allocate_eq((var_x + (-G::Scalar::from_u128(1u128))) * var_B);
298+
let var_X = linear_relation.allocate_eq((var_x + (-G::Scalar::from(1))) * var_B);
299299
linear_relation.set_element(var_B, B);
300300
linear_relation.set_element(var_X, X);
301301

@@ -314,8 +314,7 @@ fn subtractions_with_shift<G: PrimeGroup, R: RngCore>(
314314
let mut linear_relation = LinearRelation::<G>::new();
315315
let var_x = linear_relation.allocate_scalar();
316316
let var_B = linear_relation.allocate_element();
317-
let var_X =
318-
linear_relation.allocate_eq((var_x + (-G::Scalar::from_u128(1u128))) * var_B + (-var_B));
317+
let var_X = linear_relation.allocate_eq((var_x + (-G::Scalar::from(1))) * var_B + (-var_B));
319318

320319
linear_relation.set_element(var_B, B);
321320
linear_relation.set_element(var_X, X);
@@ -370,7 +369,6 @@ fn cmz_wallet_spend_relation<G: PrimeGroup, R: RngCore>(
370369
(instance, witness)
371370
}
372371

373-
374372
fn nested_affine_relation<G: PrimeGroup, R: RngCore>(
375373
mut rng: &mut R,
376374
) -> (CanonicalLinearRelation<G>, Vec<G::Scalar>) {
@@ -395,7 +393,6 @@ fn nested_affine_relation<G: PrimeGroup, R: RngCore>(
395393
(instance, witness)
396394
}
397395

398-
399396
#[test]
400397
fn test_cmz_wallet_with_fee() {
401398
use group::Group;

0 commit comments

Comments
 (0)