Skip to content

Commit 5f31575

Browse files
committed
feat(error-handling): improve group deserialization and refactor SchnorrProtocol
- feat: add error handling to group deserialization functions - refactor: simplify internal expressions in SchnorrProtocol's SigmaProtocol trait implementation - fix: apply various minor code improvements and corrections
1 parent 1a109e9 commit 5f31575

File tree

5 files changed

+51
-49
lines changed

5 files changed

+51
-49
lines changed

src/errors.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ pub enum ProofError {
1717
VerificationFailure,
1818
/// Occurs during batch verification if the batch parameters do not have the right size.
1919
#[error("Mismatched parameter sizes for batch verification.")]
20-
BatchSizeMismatch,
20+
ProofSizeMismatch,
2121
/// Occurs when a feature is not implemented yet.
2222
#[error("The method is not yet implemented for this struct")]
2323
NotImplemented(&'static str),
2424
/// Serialization of a group element/scalar failed
2525
#[error("Serialization of a group element/scalar failed")]
26-
/// Other error
2726
GroupSerializationFailure,
27+
/// Other error
2828
#[error("Other")]
2929
Other,
3030
}

src/group_serialization.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
11
use ff::PrimeField;
22
use group::{Group, GroupEncoding};
33

4+
use crate::ProofError;
5+
46
pub fn serialize_element<G: Group + GroupEncoding>(element: &G) -> Vec<u8> {
57
element.to_bytes().as_ref().to_vec()
68
}
79

8-
pub fn deserialize_element<G: Group + GroupEncoding>(data: &[u8]) -> Option<G> {
10+
pub fn deserialize_element<G: Group + GroupEncoding>(data: &[u8]) -> Result<G, ProofError> {
911
let element_len = G::Repr::default().as_ref().len();
1012
if data.len() != element_len {
11-
return None;
13+
return Err(ProofError::GroupSerializationFailure);
1214
}
1315

1416
let mut repr = G::Repr::default();
1517
repr.as_mut().copy_from_slice(data);
1618
let ct_point = G::from_bytes(&repr);
17-
1819
if ct_point.is_some().into() {
1920
let point = ct_point.unwrap();
20-
Some(point)
21+
Ok(point)
2122
} else {
22-
None
23+
Err(ProofError::GroupSerializationFailure)
2324
}
2425
}
2526

@@ -29,12 +30,12 @@ pub fn serialize_scalar<G: Group>(scalar: &G::Scalar) -> Vec<u8> {
2930
scalar_bytes
3031
}
3132

32-
pub fn deserialize_scalar<G: Group>(data: &[u8]) -> Option<G::Scalar> {
33+
pub fn deserialize_scalar<G: Group>(data: &[u8]) -> Result<G::Scalar, ProofError> {
3334
let scalar_len = <<G as Group>::Scalar as PrimeField>::Repr::default()
3435
.as_ref()
3536
.len();
3637
if data.len() != scalar_len {
37-
return None;
38+
return Err(ProofError::GroupSerializationFailure);
3839
}
3940

4041
let mut repr = <<G as Group>::Scalar as PrimeField>::Repr::default();
@@ -43,6 +44,11 @@ pub fn deserialize_scalar<G: Group>(data: &[u8]) -> Option<G::Scalar> {
4344
tmp.reverse();
4445
tmp
4546
});
46-
47-
G::Scalar::from_repr(repr).into()
47+
let ct_scalar = G::Scalar::from_repr(repr);
48+
if ct_scalar.is_some().into() {
49+
let scalar = ct_scalar.unwrap();
50+
Ok(scalar)
51+
} else {
52+
Err(ProofError::GroupSerializationFailure)
53+
}
4854
}

src/proof_composition.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
8383
let mut resp_cursor = 0;
8484

8585
for protocol in &self.0 {
86-
let commit_len = protocol.points_nb();
86+
let commit_len = protocol.statements_nb();
8787
let resp_len = protocol.scalars_nb();
8888

8989
let commit = &commitment[commit_cursor..(commit_cursor + commit_len)];
@@ -108,7 +108,7 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
108108
let mut resp_cursor = 0;
109109

110110
for protocol in &self.0 {
111-
let commit_len = protocol.points_nb();
111+
let commit_len = protocol.statements_nb();
112112
let resp_len = protocol.scalars_nb();
113113

114114
let commit = &commitment[commit_cursor..(commit_cursor + commit_len)];
@@ -138,7 +138,7 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
138138
.len();
139139

140140
for protocol in &self.0 {
141-
let commit_nb = protocol.points_nb();
141+
let commit_nb = protocol.statements_nb();
142142
let response_nb = protocol.scalars_nb();
143143
let proof_len = response_nb * scalar_size + commit_nb * point_size;
144144
let (commit, resp) =
@@ -261,7 +261,7 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
261261
let mut commit_cursor = 0;
262262
let mut resp_cursor = 0;
263263
for (i, protocol) in self.0.iter().enumerate() {
264-
let commit_len = protocol.points_nb();
264+
let commit_len = protocol.statements_nb();
265265
let resp_len = protocol.scalars_nb();
266266
let commit = &commitment[commit_cursor..(commit_cursor + commit_len)];
267267
let resp = &response.1[resp_cursor..(resp_cursor + resp_len)];
@@ -285,7 +285,7 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
285285
let mut resp_cursor = 0;
286286

287287
for (i, protocol) in self.0.iter().enumerate() {
288-
let commit_len = protocol.points_nb();
288+
let commit_len = protocol.statements_nb();
289289
let resp_len = protocol.scalars_nb();
290290

291291
let commit = &commitment[commit_cursor..(commit_cursor + commit_len)];
@@ -316,7 +316,7 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
316316
.len();
317317

318318
for protocol in &self.0 {
319-
let commit_nb = protocol.points_nb();
319+
let commit_nb = protocol.statements_nb();
320320
let response_nb = protocol.scalars_nb();
321321
let proof_len = response_nb * scalar_size + commit_nb * point_size;
322322
let (commit, resp) =

src/schnorr_protocol.rs

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ impl<G: Group + GroupEncoding> SchnorrProtocol<G> {
3838
self.0.morphism.num_scalars
3939
}
4040

41-
pub fn points_nb(&self) -> usize {
41+
pub fn statements_nb(&self) -> usize {
4242
self.0.morphism.num_statements()
4343
}
4444

@@ -58,6 +58,10 @@ impl<G: Group + GroupEncoding> SchnorrProtocol<G> {
5858
self.0.set_elements(elements);
5959
}
6060

61+
pub fn evaluate(&self, scalars: &[<G as Group>::Scalar]) -> Vec<G> {
62+
self.0.morphism.evaluate(scalars)
63+
}
64+
6165
pub fn image(&self) -> Vec<G> {
6266
self.0.image()
6367
}
@@ -83,11 +87,11 @@ where
8387
return Err(ProofError::Other);
8488
}
8589

86-
let nonces: Vec<G::Scalar> = (0..self.0.morphism.num_scalars)
90+
let nonces: Vec<G::Scalar> = (0..self.scalars_nb())
8791
.map(|_| G::Scalar::random(&mut rng))
8892
.collect();
8993
let prover_state = (nonces.clone(), witness.clone());
90-
let commitment = self.0.morphism.evaluate(&nonces);
94+
let commitment = self.evaluate(&nonces);
9195
Ok((commitment, prover_state))
9296
}
9397

@@ -102,7 +106,7 @@ where
102106
}
103107

104108
let mut responses = Vec::new();
105-
for i in 0..self.0.morphism.num_scalars {
109+
for i in 0..self.scalars_nb() {
106110
responses.push(state.0[i] + state.1[i] * challenge);
107111
}
108112
Ok(responses)
@@ -115,14 +119,10 @@ where
115119
challenge: &Self::Challenge,
116120
response: &Self::Response,
117121
) -> Result<(), ProofError> {
118-
let lhs = self.0.morphism.evaluate(response);
122+
let lhs = self.evaluate(response);
119123

120124
let mut rhs = Vec::new();
121-
for (i, g) in commitment
122-
.iter()
123-
.enumerate()
124-
.take(self.0.morphism.num_statements())
125-
{
125+
for (i, g) in commitment.iter().enumerate().take(self.statements_nb()) {
126126
rhs.push(self.0.morphism.group_elements[self.0.image[i].index()] * challenge + g);
127127
}
128128

@@ -140,8 +140,8 @@ where
140140
response: &Self::Response,
141141
) -> Result<Vec<u8>, ProofError> {
142142
let mut bytes = Vec::new();
143-
let commit_nb = self.0.morphism.num_statements();
144-
let response_nb = self.0.morphism.num_scalars;
143+
let commit_nb = self.statements_nb();
144+
let response_nb = self.scalars_nb();
145145

146146
// Serialize commitments
147147
for commit in commitment.iter().take(commit_nb) {
@@ -160,8 +160,8 @@ where
160160
&self,
161161
data: &[u8],
162162
) -> Result<(Self::Commitment, Self::Response), ProofError> {
163-
let commit_nb = self.0.morphism.num_statements();
164-
let response_nb = self.0.morphism.num_scalars;
163+
let commit_nb = self.statements_nb();
164+
let response_nb = self.scalars_nb();
165165

166166
let commit_size = G::generator().to_bytes().as_ref().len();
167167
let response_size = <<G as Group>::Scalar as PrimeField>::Repr::default()
@@ -170,7 +170,7 @@ where
170170

171171
let expected_len = response_nb * response_size + commit_nb * commit_size;
172172
if data.len() != expected_len {
173-
return Err(ProofError::BatchSizeMismatch);
173+
return Err(ProofError::ProofSizeMismatch);
174174
}
175175

176176
let mut commitments: Self::Commitment = Vec::new();
@@ -181,7 +181,7 @@ where
181181
let end = start + commit_size;
182182

183183
let slice = &data[start..end];
184-
let elem = deserialize_element(slice).ok_or(ProofError::GroupSerializationFailure)?;
184+
let elem = deserialize_element(slice)?;
185185
commitments.push(elem);
186186
}
187187

@@ -190,8 +190,7 @@ where
190190
let end = start + response_size;
191191

192192
let slice = &data[start..end];
193-
let scalar =
194-
deserialize_scalar::<G>(slice).ok_or(ProofError::GroupSerializationFailure)?;
193+
let scalar = deserialize_scalar::<G>(slice)?;
195194
responses.push(scalar);
196195
}
197196

@@ -212,8 +211,8 @@ where
212211
return Err(ProofError::Other);
213212
}
214213

215-
let response_image = self.0.morphism.evaluate(response);
216-
let image = self.0.image();
214+
let response_image = self.evaluate(response);
215+
let image = self.image();
217216

218217
let mut commitment = Vec::new();
219218
for i in 0..image.len() {
@@ -230,7 +229,7 @@ where
230229
response: &Self::Response,
231230
) -> Result<Vec<u8>, ProofError> {
232231
let mut bytes = Vec::new();
233-
let response_nb = self.0.morphism.num_scalars;
232+
let response_nb = self.scalars_nb();
234233

235234
// Serialize challenge
236235
bytes.extend_from_slice(&serialize_scalar::<G>(challenge));
@@ -247,30 +246,28 @@ where
247246
&self,
248247
data: &[u8],
249248
) -> Result<(Self::Challenge, Self::Response), ProofError> {
250-
let response_nb = self.0.morphism.num_scalars;
249+
let response_nb = self.scalars_nb();
251250
let response_size = <<G as Group>::Scalar as PrimeField>::Repr::default()
252251
.as_ref()
253252
.len();
254253

255254
let expected_len = (response_nb + 1) * response_size;
256255

257256
if data.len() != expected_len {
258-
return Err(ProofError::BatchSizeMismatch);
257+
return Err(ProofError::ProofSizeMismatch);
259258
}
260259

261260
let mut responses: Self::Response = Vec::new();
262261

263262
let slice = &data[0..response_size];
264-
let challenge =
265-
deserialize_scalar::<G>(slice).ok_or(ProofError::GroupSerializationFailure)?;
263+
let challenge = deserialize_scalar::<G>(slice)?;
266264

267265
for i in 0..response_nb {
268266
let start = (i + 1) * response_size;
269267
let end = start + response_size;
270268

271269
let slice = &data[start..end];
272-
let scalar =
273-
deserialize_scalar::<G>(slice).ok_or(ProofError::GroupSerializationFailure)?;
270+
let scalar = deserialize_scalar::<G>(slice)?;
274271
responses.push(scalar);
275272
}
276273

@@ -288,7 +285,7 @@ where
288285
rng: &mut (impl RngCore + CryptoRng),
289286
) -> (Self::Commitment, Self::Response) {
290287
let mut response = Vec::new();
291-
response.extend(iter::repeat(G::Scalar::random(rng)).take(self.0.morphism.num_scalars));
288+
response.extend(iter::repeat(G::Scalar::random(rng)).take(self.scalars_nb()));
292289
let commitment = self.get_commitment(challenge, &response).unwrap();
293290
(commitment, response)
294291
}

tests/spec/custom_schnorr_protocol.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ where
116116

117117
let expected_len = scalar_nb * scalar_size + point_nb * point_size;
118118
if data.len() != expected_len {
119-
return Err(ProofError::BatchSizeMismatch);
119+
return Err(ProofError::ProofSizeMismatch);
120120
}
121121

122122
let mut commitments: Self::Commitment = Vec::new();
@@ -127,7 +127,7 @@ where
127127
let end = start + point_size;
128128

129129
let slice = &data[start..end];
130-
let elem = deserialize_element(slice).ok_or(ProofError::GroupSerializationFailure)?;
130+
let elem = deserialize_element(slice)?;
131131
commitments.push(elem);
132132
}
133133

@@ -136,8 +136,7 @@ where
136136
let end = start + scalar_size;
137137

138138
let slice = data[start..end].to_vec();
139-
let scalar =
140-
deserialize_scalar::<G>(&slice).ok_or(ProofError::GroupSerializationFailure)?;
139+
let scalar = deserialize_scalar::<G>(&slice)?;
141140
responses.push(scalar);
142141
}
143142

0 commit comments

Comments
 (0)