Skip to content

Commit 6524e85

Browse files
committed
feat(composition): improve error handling and optimize memory usage
- feat: implement comprehensive error handling in proof_composition functions - perf: optimize memory allocations throughout proof_composition module - refactor: remove unnecessary manual Default trait implementations - fix: resolve clippy warnings across codebase
1 parent b1abf32 commit 6524e85

File tree

4 files changed

+80
-33
lines changed

4 files changed

+80
-33
lines changed

src/fiat_shamir.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ use crate::{codec::Codec, CompactProtocol, ProofError, SigmaProtocol};
1818
use group::{Group, GroupEncoding};
1919
use rand::{CryptoRng, RngCore};
2020

21+
type Transcript<P> = (
22+
<P as SigmaProtocol>::Commitment,
23+
<P as SigmaProtocol>::Challenge,
24+
<P as SigmaProtocol>::Response,
25+
);
26+
2127
/// A Fiat-Shamir transformation of a Sigma protocol into a non-interactive proof.
2228
///
2329
/// `NISigmaProtocol` wraps an interactive Sigma protocol `P`
@@ -62,7 +68,7 @@ where
6268
&mut self,
6369
witness: &P::Witness,
6470
rng: &mut (impl RngCore + CryptoRng),
65-
) -> Result<(P::Commitment, P::Challenge, P::Response), ProofError> {
71+
) -> Result<Transcript<P>, ProofError> {
6672
let mut codec = self.hash_state.clone();
6773

6874
let (commitment, prover_state) = self.sigmap.prover_commit(witness, rng)?;

src/group_morphism.rs

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub struct LinearCombination {
4242
///
4343
/// It supports dynamic allocation of scalars and elements,
4444
/// and evaluates by performing multi-scalar multiplications.
45+
#[derive(Default)]
4546
pub struct Morphism<G: Group> {
4647
pub linear_combination: Vec<LinearCombination>,
4748
pub group_elements: Vec<G>,
@@ -58,12 +59,6 @@ pub fn msm_pr<G: Group>(scalars: &[G::Scalar], bases: &[G]) -> G {
5859
acc
5960
}
6061

61-
impl<G: Group> Default for Morphism<G> {
62-
fn default() -> Self {
63-
Self::new()
64-
}
65-
}
66-
6762
impl<G: Group> Morphism<G> {
6863
/// Creates a new empty Morphism.
6964
pub fn new() -> Self {
@@ -110,6 +105,7 @@ impl<G: Group> Morphism<G> {
110105
/// Provides a higher-level API to build proof instances from sparse constraints. The equations are manipulated solely through 2 lists:
111106
/// - the index of a set of Group elements (maintained in Morphism)
112107
/// - the index of a set of scalars (provided as input for the execution)
108+
#[derive(Default)]
113109
pub struct GroupMorphismPreimage<G>
114110
where
115111
G: Group + GroupEncoding,
@@ -120,15 +116,6 @@ where
120116
pub image: Vec<PointVar>,
121117
}
122118

123-
impl<G> Default for GroupMorphismPreimage<G>
124-
where
125-
G: Group + GroupEncoding,
126-
{
127-
fn default() -> Self {
128-
Self::new()
129-
}
130-
}
131-
132119
impl<G> GroupMorphismPreimage<G>
133120
where
134121
G: Group + GroupEncoding,

src/proof_composition.rs

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
3939
witness: &Self::Witness,
4040
rng: &mut (impl rand::Rng + rand::CryptoRng),
4141
) -> Result<(Self::Commitment, Self::ProverState), ProofError> {
42+
let expected_w_len: usize = self.0.iter().map(|p| p.scalars_nb()).sum();
43+
if expected_w_len != witness.len() || self.is_empty() {
44+
return Err(ProofError::Other);
45+
}
46+
4247
let mut cursor = 0;
4348
let mut commitment = Vec::with_capacity(self.0.iter().map(|p| p.statements_nb()).sum());
4449
let mut state = Vec::with_capacity(self.len());
@@ -59,6 +64,10 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
5964
state: Self::ProverState,
6065
challenge: &Self::Challenge,
6166
) -> Result<Self::Response, ProofError> {
67+
if state.len() != self.len() {
68+
return Err(ProofError::Other);
69+
}
70+
6271
let mut response = Vec::with_capacity(self.0.iter().map(|p| p.scalars_nb()).sum());
6372
for (proto, proto_state) in self.0.iter().zip(state) {
6473
let proto_response = proto.prover_response(proto_state, challenge)?;
@@ -73,6 +82,12 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
7382
challenge: &Self::Challenge,
7483
response: &Self::Response,
7584
) -> Result<(), ProofError> {
85+
let expected_c_len: usize = self.0.iter().map(|p| p.statements_nb()).sum();
86+
let expected_r_len: usize = self.0.iter().map(|p| p.scalars_nb()).sum();
87+
if commitment.len() != expected_c_len || response.len() != expected_r_len {
88+
return Err(ProofError::Other);
89+
}
90+
7691
let mut c_cursor = 0;
7792
let mut r_cursor = 0;
7893
for proto in &self.0 {
@@ -96,6 +111,12 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
96111
challenge: &Self::Challenge,
97112
response: &Self::Response,
98113
) -> Result<Vec<u8>, ProofError> {
114+
let expected_c_len: usize = self.0.iter().map(|p| p.statements_nb()).sum();
115+
let expected_r_len: usize = self.0.iter().map(|p| p.scalars_nb()).sum();
116+
if commitment.len() != expected_c_len || response.len() != expected_r_len {
117+
return Err(ProofError::Other);
118+
}
119+
99120
let mut bytes = Vec::new();
100121
let mut c_cursor = 0;
101122
let mut r_cursor = 0;
@@ -185,12 +206,13 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
185206
rng: &mut (impl rand::Rng + rand::CryptoRng),
186207
) -> Result<(Self::Commitment, Self::ProverState), ProofError> {
187208
let real_index = witness.0;
188-
if real_index >= self.len() {
209+
let expected_w_len = self.0[real_index].scalars_nb();
210+
if real_index >= self.len() || witness.1.len() != expected_w_len {
189211
return Err(ProofError::Other);
190212
}
191213

192-
let mut fake_transcripts = Vec::new();
193-
let mut commitment = Vec::new();
214+
let mut fake_transcripts = Vec::with_capacity(self.len() - 1);
215+
let mut commitment = Vec::with_capacity(self.0.iter().map(|p| p.statements_nb()).sum());
194216
let (real_commit, real_state) = self.0[real_index].prover_commit(&witness.1, rng)?;
195217
for (i, proto) in self.0.iter().enumerate() {
196218
if i != real_index {
@@ -244,6 +266,16 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
244266
challenge: &Self::Challenge,
245267
response: &Self::Response,
246268
) -> Result<(), ProofError> {
269+
let expected_c_len: usize = self.0.iter().map(|p| p.statements_nb()).sum();
270+
let expected_ch_nb = self.len();
271+
let expected_r_len: usize = self.0.iter().map(|p| p.scalars_nb()).sum();
272+
if commitment.len() != expected_c_len
273+
|| response.0.len() != expected_ch_nb
274+
|| response.1.len() != expected_r_len
275+
{
276+
return Err(ProofError::Other);
277+
}
278+
247279
let mut expected_difference = *challenge;
248280
let mut c_cursor = 0;
249281
let mut r_cursor = 0;
@@ -270,6 +302,16 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
270302
_challenge: &Self::Challenge,
271303
response: &Self::Response,
272304
) -> Result<Vec<u8>, ProofError> {
305+
let expected_c_len: usize = self.0.iter().map(|p| p.statements_nb()).sum();
306+
let expected_ch_nb = self.len();
307+
let expected_r_len: usize = self.0.iter().map(|p| p.scalars_nb()).sum();
308+
if commitment.len() != expected_c_len
309+
|| response.0.len() != expected_ch_nb
310+
|| response.1.len() != expected_r_len
311+
{
312+
return Err(ProofError::Other);
313+
}
314+
273315
let mut bytes = Vec::new();
274316
let mut c_cursor = 0;
275317
let mut r_cursor = 0;
@@ -292,18 +334,27 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
292334
&self,
293335
data: &[u8],
294336
) -> Result<(Self::Commitment, Self::Response), ProofError> {
337+
let point_size = G::generator().to_bytes().as_ref().len();
338+
let scalar_size = <<G as Group>::Scalar as PrimeField>::Repr::default()
339+
.as_ref()
340+
.len();
341+
342+
let expected_d_len: usize = self
343+
.0
344+
.iter()
345+
.map(|p| (p.scalars_nb() + 1) * scalar_size + p.statements_nb() * point_size)
346+
.sum();
347+
if data.len() != expected_d_len {
348+
return Err(ProofError::ProofSizeMismatch);
349+
}
350+
295351
let mut cursor = 0;
296352
let mut commitment = Vec::with_capacity(self.0.iter().map(|p| p.statements_nb()).sum());
297353
let mut response = (
298354
Vec::with_capacity(self.len()),
299355
Vec::with_capacity(self.0.iter().map(|p| p.scalars_nb()).sum()),
300356
);
301357

302-
let point_size = G::generator().to_bytes().as_ref().len();
303-
let scalar_size = <<G as Group>::Scalar as PrimeField>::Repr::default()
304-
.as_ref()
305-
.len();
306-
307358
for proto in &self.0 {
308359
let c_nb = proto.statements_nb();
309360
let r_nb = proto.scalars_nb();

src/schnorr_protocol.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,9 @@ use rand::{CryptoRng, RngCore};
1717
/// A Schnorr protocol proving knowledge some discrete logarithm relation.
1818
///
1919
/// The specific proof instance is defined by a [`GroupMorphismPreimage`] over a group `G`.
20+
#[derive(Default)]
2021
pub struct SchnorrProtocol<G: Group + GroupEncoding>(GroupMorphismPreimage<G>);
2122

22-
impl<G: Group + GroupEncoding> Default for SchnorrProtocol<G> {
23-
fn default() -> Self {
24-
Self::new()
25-
}
26-
}
27-
2823
impl<G: Group + GroupEncoding> SchnorrProtocol<G> {
2924
pub fn new() -> Self {
3025
SchnorrProtocol(GroupMorphismPreimage::<G>::new())
@@ -119,13 +114,15 @@ where
119114
challenge: &Self::Challenge,
120115
response: &Self::Response,
121116
) -> Result<(), ProofError> {
122-
let lhs = self.evaluate(response);
117+
if commitment.len() != self.statements_nb() || response.len() != self.scalars_nb() {
118+
return Err(ProofError::Other);
119+
}
123120

121+
let lhs = self.evaluate(response);
124122
let mut rhs = Vec::new();
125123
for (i, g) in commitment.iter().enumerate().take(self.statements_nb()) {
126124
rhs.push(self.0.morphism.group_elements[self.0.image[i].index()] * challenge + g);
127125
}
128-
129126
match lhs == rhs {
130127
true => Ok(()),
131128
false => Err(ProofError::VerificationFailure),
@@ -139,10 +136,13 @@ where
139136
_challenge: &Self::Challenge,
140137
response: &Self::Response,
141138
) -> Result<Vec<u8>, ProofError> {
142-
let mut bytes = Vec::new();
143139
let commit_nb = self.statements_nb();
144140
let response_nb = self.scalars_nb();
141+
if commitment.len() != commit_nb || response.len() != response_nb {
142+
return Err(ProofError::Other);
143+
}
145144

145+
let mut bytes = Vec::new();
146146
// Serialize commitments
147147
for commit in commitment.iter().take(commit_nb) {
148148
bytes.extend_from_slice(&serialize_element(commit));
@@ -230,6 +230,9 @@ where
230230
) -> Result<Vec<u8>, ProofError> {
231231
let mut bytes = Vec::new();
232232
let response_nb = self.scalars_nb();
233+
if response.len() != response_nb {
234+
return Err(ProofError::Other);
235+
}
233236

234237
// Serialize challenge
235238
bytes.extend_from_slice(&serialize_scalar::<G>(challenge));

0 commit comments

Comments
 (0)