Skip to content

Commit b1abf32

Browse files
committed
fix(security): add critical challenge verification to OrProtocol
- fix: implement essential security condition in OrProtocol verifier to ensure sum of prover challenges equals verifier challenge - refactor: improve code quality and structure in proof_composition module
1 parent 5f31575 commit b1abf32

File tree

1 file changed

+100
-111
lines changed

1 file changed

+100
-111
lines changed

src/proof_composition.rs

Lines changed: 100 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
1-
use ff::PrimeField;
1+
use ff::{Field, PrimeField};
22
use group::{Group, GroupEncoding};
33

44
use crate::{
55
deserialize_scalar, serialize_scalar, ProofError, SchnorrProtocol, SigmaProtocol,
66
SigmaProtocolSimulator,
77
};
88

9+
#[derive(Default)]
910
pub struct AndProtocol<G: Group + GroupEncoding>(pub Vec<SchnorrProtocol<G>>);
1011

11-
impl<G: Group + GroupEncoding> Default for AndProtocol<G> {
12-
fn default() -> Self {
13-
Self::new()
14-
}
15-
}
16-
1712
impl<G: Group + GroupEncoding> AndProtocol<G> {
1813
pub fn new() -> Self {
1914
AndProtocol(Vec::<SchnorrProtocol<G>>::new())
@@ -44,18 +39,17 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
4439
witness: &Self::Witness,
4540
rng: &mut (impl rand::Rng + rand::CryptoRng),
4641
) -> Result<(Self::Commitment, Self::ProverState), ProofError> {
47-
let mut commitment = Vec::new();
48-
let mut state = Vec::new();
4942
let mut cursor = 0;
50-
51-
for protocol in &self.0 {
52-
let witness_len = protocol.scalars_nb();
53-
let p_witness = &witness[cursor..(cursor + witness_len)];
54-
let (commit, pr_state) = protocol.prover_commit(&p_witness.to_vec(), rng)?;
55-
commitment.extend(commit);
56-
state.push(pr_state);
57-
58-
cursor += witness_len;
43+
let mut commitment = Vec::with_capacity(self.0.iter().map(|p| p.statements_nb()).sum());
44+
let mut state = Vec::with_capacity(self.len());
45+
46+
for proto in &self.0 {
47+
let n = proto.scalars_nb();
48+
let proto_witness = witness[cursor..(cursor + n)].to_vec();
49+
let (proto_commit, proto_state) = proto.prover_commit(&proto_witness, rng)?;
50+
commitment.extend(proto_commit);
51+
state.push(proto_state);
52+
cursor += n;
5953
}
6054
Ok((commitment, state))
6155
}
@@ -65,10 +59,10 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
6559
state: Self::ProverState,
6660
challenge: &Self::Challenge,
6761
) -> Result<Self::Response, ProofError> {
68-
let mut response = Vec::new();
69-
for (i, protocol) in self.0.iter().enumerate() {
70-
let resp = protocol.prover_response(state[i].clone(), challenge)?;
71-
response.extend(resp);
62+
let mut response = Vec::with_capacity(self.0.iter().map(|p| p.scalars_nb()).sum());
63+
for (proto, proto_state) in self.0.iter().zip(state) {
64+
let proto_response = proto.prover_response(proto_state, challenge)?;
65+
response.extend(proto_response);
7266
}
7367
Ok(response)
7468
}
@@ -79,20 +73,19 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
7973
challenge: &Self::Challenge,
8074
response: &Self::Response,
8175
) -> Result<(), ProofError> {
82-
let mut commit_cursor = 0;
83-
let mut resp_cursor = 0;
84-
85-
for protocol in &self.0 {
86-
let commit_len = protocol.statements_nb();
87-
let resp_len = protocol.scalars_nb();
76+
let mut c_cursor = 0;
77+
let mut r_cursor = 0;
78+
for proto in &self.0 {
79+
let c_len = proto.statements_nb();
80+
let r_len = proto.scalars_nb();
8881

89-
let commit = &commitment[commit_cursor..(commit_cursor + commit_len)];
90-
let resp = &response[resp_cursor..(resp_cursor + resp_len)];
82+
let proto_commit = commitment[c_cursor..(c_cursor + c_len)].to_vec();
83+
let proto_resp = response[r_cursor..(r_cursor + r_len)].to_vec();
9184

92-
protocol.verifier(&commit.to_vec(), challenge, &resp.to_vec())?;
85+
proto.verifier(&proto_commit, challenge, &proto_resp)?;
9386

94-
commit_cursor += commit_len;
95-
resp_cursor += resp_len
87+
c_cursor += c_len;
88+
r_cursor += r_len
9689
}
9790
Ok(())
9891
}
@@ -104,22 +97,19 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
10497
response: &Self::Response,
10598
) -> Result<Vec<u8>, ProofError> {
10699
let mut bytes = Vec::new();
107-
let mut commit_cursor = 0;
108-
let mut resp_cursor = 0;
109-
110-
for protocol in &self.0 {
111-
let commit_len = protocol.statements_nb();
112-
let resp_len = protocol.scalars_nb();
113-
114-
let commit = &commitment[commit_cursor..(commit_cursor + commit_len)];
115-
let resp = &response[resp_cursor..(resp_cursor + resp_len)];
116-
bytes.extend_from_slice(&protocol.serialize_batchable(
117-
&commit.to_vec(),
118-
challenge,
119-
&resp.to_vec(),
120-
)?);
121-
commit_cursor += commit_len;
122-
resp_cursor += resp_len;
100+
let mut c_cursor = 0;
101+
let mut r_cursor = 0;
102+
for proto in &self.0 {
103+
let c_len = proto.statements_nb();
104+
let r_len = proto.scalars_nb();
105+
106+
let proto_commit = commitment[c_cursor..(c_cursor + c_len)].to_vec();
107+
let proto_resp = response[r_cursor..(r_cursor + r_len)].to_vec();
108+
109+
bytes.extend(proto.serialize_batchable(&proto_commit, challenge, &proto_resp)?);
110+
111+
c_cursor += c_len;
112+
r_cursor += r_len;
123113
}
124114
Ok(bytes)
125115
}
@@ -128,42 +118,37 @@ impl<G: Group + GroupEncoding> SigmaProtocol for AndProtocol<G> {
128118
&self,
129119
data: &[u8],
130120
) -> Result<(Self::Commitment, Self::Response), ProofError> {
131-
let mut commitment = Vec::new();
132-
let mut response = Vec::new();
133121
let mut cursor = 0;
122+
let mut commitment = Vec::with_capacity(self.0.iter().map(|p| p.statements_nb()).sum());
123+
let mut response = Vec::with_capacity(self.0.iter().map(|p| p.scalars_nb()).sum());
134124

135125
let point_size = G::generator().to_bytes().as_ref().len();
136126
let scalar_size = <<G as Group>::Scalar as PrimeField>::Repr::default()
137127
.as_ref()
138128
.len();
139129

140-
for protocol in &self.0 {
141-
let commit_nb = protocol.statements_nb();
142-
let response_nb = protocol.scalars_nb();
143-
let proof_len = response_nb * scalar_size + commit_nb * point_size;
144-
let (commit, resp) =
145-
protocol.deserialize_batchable(&data[cursor..(cursor + proof_len)])?;
146-
commitment.extend(commit);
147-
response.extend(resp);
130+
for proto in &self.0 {
131+
let c_nb = proto.statements_nb();
132+
let r_nb = proto.scalars_nb();
133+
let proof_len = r_nb * scalar_size + c_nb * point_size;
134+
let (proto_commit, proto_resp) =
135+
proto.deserialize_batchable(&data[cursor..(cursor + proof_len)])?;
136+
commitment.extend(proto_commit);
137+
response.extend(proto_resp);
148138
cursor += proof_len;
149139
}
150140
Ok((commitment, response))
151141
}
152142
}
153143

144+
#[derive(Default)]
154145
pub struct OrProtocol<G: Group + GroupEncoding>(pub Vec<SchnorrProtocol<G>>);
155146

156147
pub struct Transcript<G: Group> {
157148
challenge: <G as Group>::Scalar,
158149
response: Vec<<G as Group>::Scalar>,
159150
}
160151

161-
impl<G: Group + GroupEncoding> Default for OrProtocol<G> {
162-
fn default() -> Self {
163-
Self::new()
164-
}
165-
}
166-
167152
impl<G: Group + GroupEncoding> OrProtocol<G> {
168153
pub fn new() -> Self {
169154
OrProtocol(Vec::<SchnorrProtocol<G>>::new())
@@ -207,9 +192,9 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
207192
let mut fake_transcripts = Vec::new();
208193
let mut commitment = Vec::new();
209194
let (real_commit, real_state) = self.0[real_index].prover_commit(&witness.1, rng)?;
210-
for (i, protocol) in self.0.iter().enumerate() {
195+
for (i, proto) in self.0.iter().enumerate() {
211196
if i != real_index {
212-
let (commit, challenge, resp) = protocol.simulate_transcript(rng);
197+
let (commit, challenge, resp) = proto.simulate_transcript(rng);
213198
fake_transcripts.push(Transcript {
214199
challenge,
215200
response: resp,
@@ -228,7 +213,10 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
228213
challenge: &Self::Challenge,
229214
) -> Result<Self::Response, ProofError> {
230215
let (real_index, real_state, fake_transcripts) = state;
231-
let mut response = (Vec::new(), Vec::new());
216+
let mut response = (
217+
Vec::with_capacity(self.len()),
218+
Vec::with_capacity(self.0.iter().map(|p| p.scalars_nb()).sum()),
219+
);
232220

233221
let mut real_challenge = *challenge;
234222
for transcript in &fake_transcripts {
@@ -257,21 +245,23 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
257245
response: &Self::Response,
258246
) -> Result<(), ProofError> {
259247
let mut expected_difference = *challenge;
260-
261-
let mut commit_cursor = 0;
262-
let mut resp_cursor = 0;
263-
for (i, protocol) in self.0.iter().enumerate() {
264-
let commit_len = protocol.statements_nb();
265-
let resp_len = protocol.scalars_nb();
266-
let commit = &commitment[commit_cursor..(commit_cursor + commit_len)];
267-
let resp = &response.1[resp_cursor..(resp_cursor + resp_len)];
268-
protocol.verifier(&commit.to_vec(), &response.0[i], &resp.to_vec())?;
269-
commit_cursor += commit_len;
270-
resp_cursor += resp_len;
271-
272-
expected_difference += response.0[i];
248+
let mut c_cursor = 0;
249+
let mut r_cursor = 0;
250+
for (i, proto) in self.0.iter().enumerate() {
251+
let c_len = proto.statements_nb();
252+
let r_len = proto.scalars_nb();
253+
let proto_commit = commitment[c_cursor..(c_cursor + c_len)].to_vec();
254+
let proto_resp = response.1[r_cursor..(r_cursor + r_len)].to_vec();
255+
proto.verifier(&proto_commit, &response.0[i], &proto_resp)?;
256+
c_cursor += c_len;
257+
r_cursor += r_len;
258+
259+
expected_difference -= response.0[i];
260+
}
261+
match expected_difference.is_zero_vartime() {
262+
true => Ok(()),
263+
false => Err(ProofError::VerificationFailure),
273264
}
274-
Ok(())
275265
}
276266

277267
fn serialize_batchable(
@@ -281,23 +271,19 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
281271
response: &Self::Response,
282272
) -> Result<Vec<u8>, ProofError> {
283273
let mut bytes = Vec::new();
284-
let mut commit_cursor = 0;
285-
let mut resp_cursor = 0;
286-
287-
for (i, protocol) in self.0.iter().enumerate() {
288-
let commit_len = protocol.statements_nb();
289-
let resp_len = protocol.scalars_nb();
290-
291-
let commit = &commitment[commit_cursor..(commit_cursor + commit_len)];
292-
let resp = &response.1[resp_cursor..(resp_cursor + resp_len)];
293-
bytes.extend_from_slice(&protocol.serialize_batchable(
294-
&commit.to_vec(),
295-
&response.0[i],
296-
&resp.to_vec(),
297-
)?);
298-
bytes.extend_from_slice(&serialize_scalar::<G>(&response.0[i]));
299-
commit_cursor += commit_len;
300-
resp_cursor += resp_len;
274+
let mut c_cursor = 0;
275+
let mut r_cursor = 0;
276+
277+
for (i, proto) in self.0.iter().enumerate() {
278+
let c_len = proto.statements_nb();
279+
let r_len = proto.scalars_nb();
280+
281+
let proto_commit = commitment[c_cursor..(c_cursor + c_len)].to_vec();
282+
let proto_resp = response.1[r_cursor..(r_cursor + r_len)].to_vec();
283+
bytes.extend(proto.serialize_batchable(&proto_commit, &response.0[i], &proto_resp)?);
284+
bytes.extend(&serialize_scalar::<G>(&response.0[i]));
285+
c_cursor += c_len;
286+
r_cursor += r_len;
301287
}
302288
Ok(bytes)
303289
}
@@ -306,28 +292,31 @@ impl<G: Group + GroupEncoding> SigmaProtocol for OrProtocol<G> {
306292
&self,
307293
data: &[u8],
308294
) -> Result<(Self::Commitment, Self::Response), ProofError> {
309-
let mut commitment = Vec::new();
310-
let mut response = (Vec::new(), Vec::new());
311295
let mut cursor = 0;
296+
let mut commitment = Vec::with_capacity(self.0.iter().map(|p| p.statements_nb()).sum());
297+
let mut response = (
298+
Vec::with_capacity(self.len()),
299+
Vec::with_capacity(self.0.iter().map(|p| p.scalars_nb()).sum()),
300+
);
312301

313302
let point_size = G::generator().to_bytes().as_ref().len();
314303
let scalar_size = <<G as Group>::Scalar as PrimeField>::Repr::default()
315304
.as_ref()
316305
.len();
317306

318-
for protocol in &self.0 {
319-
let commit_nb = protocol.statements_nb();
320-
let response_nb = protocol.scalars_nb();
321-
let proof_len = response_nb * scalar_size + commit_nb * point_size;
322-
let (commit, resp) =
323-
protocol.deserialize_batchable(&data[cursor..(cursor + proof_len)])?;
324-
let challenge = deserialize_scalar::<G>(
307+
for proto in &self.0 {
308+
let c_nb = proto.statements_nb();
309+
let r_nb = proto.scalars_nb();
310+
let proof_len = r_nb * scalar_size + c_nb * point_size;
311+
let (proto_commit, proto_resp) =
312+
proto.deserialize_batchable(&data[cursor..(cursor + proof_len)])?;
313+
let proto_challenge = deserialize_scalar::<G>(
325314
&data[(cursor + proof_len)..(cursor + proof_len + scalar_size)],
326315
)
327316
.unwrap();
328-
commitment.extend(commit);
329-
response.1.extend(resp);
330-
response.0.push(challenge);
317+
commitment.extend(proto_commit);
318+
response.1.extend(proto_resp);
319+
response.0.push(proto_challenge);
331320

332321
cursor += proof_len + scalar_size;
333322
}

0 commit comments

Comments
 (0)