Skip to content

Commit 78c6d97

Browse files
committed
wip: some fixes for OR relation, still one test failing.
1 parent 47b7e6c commit 78c6d97

File tree

2 files changed

+66
-31
lines changed

2 files changed

+66
-31
lines changed

src/composition.rs

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,13 @@ pub enum ComposedProverState<G: PrimeGroup + ConstantTimeEq> {
7979
Or(ComposedOrProverState<G>),
8080
}
8181

82-
struct ComposedOrProverState<G: PrimeGroup + ConstantTimeEq> {
83-
prover_states: Vec<(Choice, ComposedProverState<G>, ComposedChallenge<G>, ComposedResponse<G>)>,
84-
}
82+
pub type ComposedOrProverState<G> = Vec<ComposedOrProverStateEntry<G>>;
83+
pub struct ComposedOrProverStateEntry<G: PrimeGroup + ConstantTimeEq>(
84+
Choice,
85+
ComposedProverState<G>,
86+
ComposedChallenge<G>,
87+
ComposedResponse<G>,
88+
);
8589

8690
// Structure representing the Response type of Protocol as SigmaProtocol
8791
#[derive(Clone)]
@@ -105,18 +109,16 @@ const fn composed_challenge_size<G: PrimeGroup>() -> usize {
105109
(G::Scalar::NUM_BITS as usize + 7) / 8
106110
}
107111

108-
109112
impl<G: PrimeGroup + ConstantTimeEq> ComposedRelation<G> {
110113
fn is_witness_valid(&self, witness: &ComposedWitness<G>) -> Choice {
111-
let validity_bit = Choice::from(0);
112114
match (self, witness) {
113115
(ComposedRelation::Simple(instance), ComposedWitness::Simple(witness)) => {
114116
instance.0.is_witness_valid(witness)
115117
}
116118
(ComposedRelation::And(instances), ComposedWitness::And(witnesses)) => instances
117119
.iter()
118120
.zip(witnesses)
119-
.fold(Choice::from(0), |bit, (instance, witness)| {
121+
.fold(Choice::from(1), |bit, (instance, witness)| {
120122
bit & instance.is_witness_valid(witness)
121123
}),
122124
(ComposedRelation::Or(instances), ComposedWitness::Or(witnesses)) => instances
@@ -125,10 +127,10 @@ impl<G: PrimeGroup + ConstantTimeEq> ComposedRelation<G> {
125127
.fold(Choice::from(0), |bit, (instance, witness)| {
126128
bit | instance.is_witness_valid(witness)
127129
}),
128-
_ => unreachable!(),
129-
};
130-
validity_bit
130+
_ => Choice::from(0),
131+
}
131132
}
133+
132134
fn prover_commit_simple(
133135
protocol: &SchnorrProof<G>,
134136
witness: &<SchnorrProof<G> as SigmaProtocol>::Witness,
@@ -212,20 +214,26 @@ impl<G: PrimeGroup + ConstantTimeEq> ComposedRelation<G> {
212214
instances[i].simulate_transcript(rng)?;
213215

214216
let valid_witness = instances[i].is_witness_valid(&w);
215-
commitments.push(if valid_witness.unwrap_u8() == 1 {commitment } else { simulated_commitment.clone()} );
216-
prover_states.push((valid_witness, prover_state, simulated_challenge, simulated_response));
217+
commitments.push(if valid_witness.unwrap_u8() == 1 {
218+
commitment
219+
} else {
220+
simulated_commitment.clone()
221+
});
222+
prover_states.push(ComposedOrProverStateEntry(
223+
valid_witness,
224+
prover_state,
225+
simulated_challenge,
226+
simulated_response,
227+
));
217228
}
218229
// check that we have only one witness set
219230
let witnesses_found = prover_states
220231
.iter()
221232
.map(|x| x.0.unwrap_u8() as usize)
222233
.sum::<usize>();
223-
let prover_state =
224-
ComposedOrProverState {
225-
prover_states,
226-
};
234+
let prover_state = prover_states;
227235

228-
if witnesses_found > 1 {
236+
if witnesses_found != 1 {
229237
return Err(Error::InvalidInstanceWitnessPair);
230238
} else {
231239
Ok((
@@ -238,25 +246,52 @@ impl<G: PrimeGroup + ConstantTimeEq> ComposedRelation<G> {
238246
fn prover_response_or(
239247
instances: &[ComposedRelation<G>],
240248
prover_state: ComposedOrProverState<G>,
241-
&challenge: &ComposedChallenge<G>,
249+
challenge: &ComposedChallenge<G>,
242250
) -> Result<ComposedResponse<G>, Error> {
243251
let mut result_challenges = Vec::with_capacity(instances.len());
244252
let mut result_responses = Vec::with_capacity(instances.len());
245253

246-
let ComposedOrProverState { prover_states } = prover_state;
247-
248-
let mut witness_challenge = challenge;
249-
for (valid_witness, _prover_state, simulated_challenge, _simulated_response) in &prover_states {
250-
let c = G::Scalar::conditional_select(&G::Scalar::ZERO, &simulated_challenge, *valid_witness);
251-
witness_challenge -= c;
254+
let prover_states = prover_state;
255+
256+
let mut witness_challenge = *challenge;
257+
for ComposedOrProverStateEntry(
258+
valid_witness,
259+
_prover_state,
260+
simulated_challenge,
261+
_simulated_response,
262+
) in &prover_states
263+
{
264+
let c = G::Scalar::conditional_select(
265+
&simulated_challenge,
266+
&G::Scalar::ZERO,
267+
*valid_witness,
268+
);
269+
witness_challenge = witness_challenge - c;
252270
}
253-
for (instance, (valid_witness, prover_state, simulated_challenge, simulated_response)) in instances.iter().zip(prover_states) {
254-
let challenge_i = G::Scalar::conditional_select(&witness_challenge, &simulated_challenge, valid_witness);
271+
for (
272+
instance,
273+
ComposedOrProverStateEntry(
274+
valid_witness,
275+
prover_state,
276+
simulated_challenge,
277+
simulated_response,
278+
),
279+
) in instances.iter().zip(prover_states)
280+
{
281+
let challenge_i = G::Scalar::conditional_select(
282+
&simulated_challenge,
283+
&witness_challenge,
284+
valid_witness,
285+
);
255286

256287
let real_response = instance.prover_response(prover_state, &challenge_i)?;
257288

258289
// let response_i = ComposedResponse::conditional_select(&real_response, &simulated_response, *witness_location);
259-
let response_i = if valid_witness.unwrap_u8() == 1 { real_response } else { simulated_response };
290+
let response_i = if valid_witness.unwrap_u8() == 1 {
291+
real_response
292+
} else {
293+
simulated_response
294+
};
260295
result_challenges.push(challenge_i);
261296
result_responses.push(response_i);
262297
}

src/tests/test_validation_criteria.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -393,15 +393,15 @@ mod proof_validation {
393393
let C = B * y;
394394

395395
// Create the first branch: C = x*A
396-
let mut lr1 = LinearRelation::<G>::new();
396+
let mut lr1 = LinearRelation::new();
397397
let x_var = lr1.allocate_scalar();
398398
let A_var = lr1.allocate_element();
399399
let eq1 = lr1.allocate_eq(x_var * A_var);
400400
lr1.set_element(A_var, A);
401401
lr1.set_element(eq1, C);
402402

403403
// Create the second branch: C = y*B
404-
let mut lr2 = LinearRelation::<G>::new();
404+
let mut lr2 = LinearRelation::new();
405405
let y_var = lr2.allocate_scalar();
406406
let B_var = lr2.allocate_element();
407407
let eq2 = lr2.allocate_eq(y_var * B_var);
@@ -414,7 +414,7 @@ mod proof_validation {
414414
ComposedRelation::from(lr2),
415415
]);
416416

417-
let nizk = Nizk::<_, KeccakByteSchnorrCodec<G>>::new(b"test_or_bug", or_relation);
417+
let nizk =or_relation.into_nizk(b"test_or_bug");
418418

419419
// Create a correct witness for branch 1 (C = y*B)
420420
let witness_correct = ComposedWitness::Or(vec![
@@ -441,12 +441,12 @@ mod proof_validation {
441441
Ok(proof) => {
442442
let verify_result = nizk.verify_batchable(&proof);
443443
println!(
444-
"Bug reproduced: Proof with wrong branch verified: {:?}",
444+
"Proof with wrong branch verified: {:?}",
445445
verify_result.is_ok()
446446
);
447447
assert!(
448448
verify_result.is_err(),
449-
"BUG: Proof should fail when using wrong branch in OR relation, but it passed!"
449+
"Proof should fail when using wrong branch in OR relation, but it passed!"
450450
);
451451
}
452452
Err(e) => {

0 commit comments

Comments
 (0)