Skip to content

Commit e4ed5fc

Browse files
authored
fix: support OR composition when all supplied branches are true (#77)
1 parent 2e2b3aa commit e4ed5fc

File tree

5 files changed

+128
-55
lines changed

5 files changed

+128
-55
lines changed

examples/simple_composition.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ fn create_relation(P1: G, P2: G, Q: G, H: G) -> ComposedRelation<G> {
4040
rel2.set_element(Q_var, Q);
4141

4242
// Compose into OR protocol
43-
let proto1 = ComposedRelation::from(rel1);
44-
let proto2 = ComposedRelation::from(rel2);
45-
ComposedRelation::Or(vec![proto1, proto2])
43+
ComposedRelation::or([rel1.canonical().unwrap(), rel2.canonical().unwrap()])
4644
}
4745

4846
/// Prove knowledge of one of the witnesses (we know x2 for the DLEQ)

src/composition.rs

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,10 @@
2020
2121
use ff::{Field, PrimeField};
2222
use group::prime::PrimeGroup;
23-
use sha3::Digest;
24-
use sha3::Sha3_256;
25-
use subtle::Choice;
26-
use subtle::ConditionallySelectable;
27-
use subtle::ConstantTimeEq;
23+
use sha3::{Digest, Sha3_256};
24+
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
2825

26+
use crate::errors::InvalidInstance;
2927
use crate::{
3028
codec::Shake128DuplexSponge,
3129
errors::Error,
@@ -48,18 +46,29 @@ pub enum ComposedRelation<G: PrimeGroup> {
4846
Or(Vec<ComposedRelation<G>>),
4947
}
5048

49+
impl<G: PrimeGroup> ComposedRelation<G> {
50+
/// Create a [ComposedRelation] for an AND relation from the given list of relations.
51+
pub fn and<T: Into<ComposedRelation<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
52+
Self::And(witness.into_iter().map(|x| x.into()).collect())
53+
}
54+
55+
/// Create a [ComposedRelation] for an OR relation from the given list of relations.
56+
pub fn or<T: Into<ComposedRelation<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
57+
Self::Or(witness.into_iter().map(|x| x.into()).collect())
58+
}
59+
}
60+
5161
impl<G: PrimeGroup> From<CanonicalLinearRelation<G>> for ComposedRelation<G> {
5262
fn from(value: CanonicalLinearRelation<G>) -> Self {
5363
ComposedRelation::Simple(value)
5464
}
5565
}
5666

57-
impl<G: PrimeGroup> From<LinearRelation<G>> for ComposedRelation<G> {
58-
fn from(value: LinearRelation<G>) -> Self {
59-
Self::Simple(
60-
CanonicalLinearRelation::try_from(value)
61-
.expect("Failed to convert LinearRelation to CanonicalLinearRelation"),
62-
)
67+
impl<G: PrimeGroup> TryFrom<LinearRelation<G>> for ComposedRelation<G> {
68+
type Error = InvalidInstance;
69+
70+
fn try_from(value: LinearRelation<G>) -> Result<Self, Self::Error> {
71+
Ok(Self::Simple(CanonicalLinearRelation::try_from(value)?))
6372
}
6473
}
6574

@@ -102,6 +111,26 @@ pub enum ComposedWitness<G: PrimeGroup> {
102111
Or(Vec<ComposedWitness<G>>),
103112
}
104113

114+
impl<G: PrimeGroup> ComposedWitness<G> {
115+
/// Create a [ComposedWitness] for an AND relation from the given list of witnesses.
116+
pub fn and<T: Into<ComposedWitness<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
117+
Self::And(witness.into_iter().map(|x| x.into()).collect())
118+
}
119+
120+
/// Create a [ComposedWitness] for an OR relation from the given list of witnesses.
121+
pub fn or<T: Into<ComposedWitness<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
122+
Self::Or(witness.into_iter().map(|x| x.into()).collect())
123+
}
124+
}
125+
126+
impl<G: PrimeGroup> From<<CanonicalLinearRelation<G> as SigmaProtocol>::Witness>
127+
for ComposedWitness<G>
128+
{
129+
fn from(value: <CanonicalLinearRelation<G> as SigmaProtocol>::Witness) -> Self {
130+
Self::Simple(value)
131+
}
132+
}
133+
105134
type ComposedChallenge<G> = <CanonicalLinearRelation<G> as SigmaProtocol>::Challenge;
106135

107136
const fn composed_challenge_size<G: PrimeGroup>() -> usize {
@@ -207,37 +236,37 @@ impl<G: PrimeGroup + ConstantTimeEq> ComposedRelation<G> {
207236
let mut commitments = Vec::new();
208237
let mut prover_states = Vec::new();
209238

239+
// Selector value set when the first valid witness is found.
240+
let mut valid_witness_found = Choice::from(0);
210241
for (i, w) in witnesses.iter().enumerate() {
211242
let (commitment, prover_state) = instances[i].prover_commit(w, rng)?;
212243
let (simulated_commitment, simulated_challenge, simulated_response) =
213244
instances[i].simulate_transcript(rng)?;
214245

246+
// TODO: Implement and use ConditionallySelectable here
215247
let valid_witness = instances[i].is_witness_valid(w);
216-
commitments.push(if valid_witness.unwrap_u8() == 1 {
248+
let select_witness = valid_witness & !valid_witness_found;
249+
commitments.push(if select_witness.unwrap_u8() == 1 {
217250
commitment
218251
} else {
219252
simulated_commitment.clone()
220253
});
221254
prover_states.push(ComposedOrProverStateEntry(
222-
valid_witness,
255+
select_witness,
223256
prover_state,
224257
simulated_challenge,
225258
simulated_response,
226259
));
260+
261+
valid_witness_found |= valid_witness;
227262
}
228-
// check that we have only one witness set
229-
let witnesses_found = prover_states
230-
.iter()
231-
.map(|x| x.0.unwrap_u8() as usize)
232-
.sum::<usize>();
233-
let prover_state = prover_states;
234263

235-
if witnesses_found != 1 {
264+
if valid_witness_found.unwrap_u8() == 0 {
236265
Err(Error::InvalidInstanceWitnessPair)
237266
} else {
238267
Ok((
239268
ComposedCommitment::Or(commitments),
240-
ComposedProverState::Or(prover_state),
269+
ComposedProverState::Or(prover_states),
241270
))
242271
}
243272
}
@@ -250,15 +279,13 @@ impl<G: PrimeGroup + ConstantTimeEq> ComposedRelation<G> {
250279
let mut result_challenges = Vec::with_capacity(instances.len());
251280
let mut result_responses = Vec::with_capacity(instances.len());
252281

253-
let prover_states = prover_state;
254-
255282
let mut witness_challenge = *challenge;
256283
for ComposedOrProverStateEntry(
257284
valid_witness,
258285
_prover_state,
259286
simulated_challenge,
260287
_simulated_response,
261-
) in &prover_states
288+
) in &prover_state
262289
{
263290
let c = G::Scalar::conditional_select(
264291
simulated_challenge,
@@ -275,7 +302,7 @@ impl<G: PrimeGroup + ConstantTimeEq> ComposedRelation<G> {
275302
simulated_challenge,
276303
simulated_response,
277304
),
278-
) in instances.iter().zip(prover_states)
305+
) in instances.iter().zip(prover_state)
279306
{
280307
let challenge_i = G::Scalar::conditional_select(
281308
&simulated_challenge,

src/linear_relation/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,4 +541,11 @@ impl<G: PrimeGroup> LinearRelation<G> {
541541
) -> Result<Nizk<CanonicalLinearRelation<G>, Shake128DuplexSponge<G>>, InvalidInstance> {
542542
Ok(Nizk::new(session_identifier, self.try_into()?))
543543
}
544+
545+
/// Construct a [CanonicalLinearRelation] from this generalized linear relation.
546+
///
547+
/// The construction may fail if the linear relation is malformed, unsatisfiable, or trivial.
548+
pub fn canonical(&self) -> Result<CanonicalLinearRelation<G>, InvalidInstance> {
549+
self.try_into()
550+
}
544551
}

src/tests/test_composition.rs

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ type G = RistrettoPoint;
99

1010
#[allow(non_snake_case)]
1111
#[test]
12-
fn test_composition_correctness() {
12+
fn test_composition_example() {
1313
// Composition and verification of proof for the following protocol :
1414
//
1515
// And(
@@ -31,30 +31,15 @@ fn test_composition_correctness() {
3131
.map(|_| <G as Group>::Scalar::random(&mut rng))
3232
.collect::<Vec<_>>();
3333
// second layer protocol definitions
34-
let or_protocol1 = ComposedRelation::<G>::Or(vec![
35-
ComposedRelation::Simple(relation1),
36-
ComposedRelation::Simple(relation2),
37-
]);
38-
let or_witness1 = ComposedWitness::Or(vec![
39-
ComposedWitness::Simple(witness1),
40-
ComposedWitness::Simple(wrong_witness2),
41-
]);
42-
43-
let simple_protocol1 = ComposedRelation::Simple(relation3);
44-
let simple_witness1 = ComposedWitness::Simple(witness3);
45-
46-
let and_protocol1 = ComposedRelation::And(vec![
47-
ComposedRelation::Simple(relation4),
48-
ComposedRelation::Simple(relation5),
49-
]);
50-
let and_witness1 = ComposedWitness::And(vec![
51-
ComposedWitness::Simple(witness4),
52-
ComposedWitness::Simple(witness5),
53-
]);
34+
let or_protocol1 = ComposedRelation::<G>::or([relation1, relation2]);
35+
let or_witness1 = ComposedWitness::or([witness1, wrong_witness2]);
36+
37+
let and_protocol1 = ComposedRelation::and([relation4, relation5]);
38+
let and_witness1 = ComposedWitness::and([witness4, witness5]);
5439

5540
// definition of the final protocol
56-
let instance = ComposedRelation::And(vec![or_protocol1, simple_protocol1, and_protocol1]);
57-
let witness = ComposedWitness::And(vec![or_witness1, simple_witness1, and_witness1]);
41+
let instance = ComposedRelation::and([or_protocol1, relation3.into(), and_protocol1]);
42+
let witness = ComposedWitness::and([or_witness1, witness3.into(), and_witness1]);
5843

5944
let nizk = instance.into_nizk(domain_sep);
6045

@@ -65,3 +50,61 @@ fn test_composition_correctness() {
6550
assert!(nizk.verify_batchable(&proof_batchable_bytes).is_ok());
6651
assert!(nizk.verify_compact(&proof_compact_bytes).is_ok());
6752
}
53+
54+
#[allow(non_snake_case)]
55+
#[test]
56+
fn test_or_one_true() {
57+
// Test composition of a basic OR protocol, with one of the two witnesses being valid.
58+
59+
// definitions of the underlying protocols
60+
let mut rng = OsRng;
61+
let (relation1, witness1) = dleq::<G, _>(&mut rng);
62+
let (relation2, witness2) = dleq::<G, _>(&mut rng);
63+
64+
let wrong_witness1 = (0..witness1.len())
65+
.map(|_| <G as Group>::Scalar::random(&mut rng))
66+
.collect::<Vec<_>>();
67+
let wrong_witness2 = (0..witness2.len())
68+
.map(|_| <G as Group>::Scalar::random(&mut rng))
69+
.collect::<Vec<_>>();
70+
71+
let or_protocol = ComposedRelation::or([relation1, relation2]);
72+
73+
// Construct two witnesses to the protocol, the first and then the second as the true branch.
74+
let witness_or_1 = ComposedWitness::or([witness1, wrong_witness2]);
75+
let witness_or_2 = ComposedWitness::or([wrong_witness1, witness2]);
76+
77+
let nizk = or_protocol.into_nizk(b"test_or_one_true");
78+
79+
for witness in [witness_or_1, witness_or_2] {
80+
// Batchable and compact proofs
81+
let proof_batchable_bytes = nizk.prove_batchable(&witness, &mut OsRng).unwrap();
82+
let proof_compact_bytes = nizk.prove_compact(&witness, &mut OsRng).unwrap();
83+
// Verify proofs
84+
assert!(nizk.verify_batchable(&proof_batchable_bytes).is_ok());
85+
assert!(nizk.verify_compact(&proof_compact_bytes).is_ok());
86+
}
87+
}
88+
89+
#[allow(non_snake_case)]
90+
#[test]
91+
fn test_or_both_true() {
92+
// Test composition of a basic OR protocol, with both of the two witnesses being valid.
93+
94+
// definitions of the underlying protocols
95+
let mut rng = OsRng;
96+
let (relation1, witness1) = dleq::<G, _>(&mut rng);
97+
let (relation2, witness2) = dleq::<G, _>(&mut rng);
98+
99+
let or_protocol = ComposedRelation::or([relation1, relation2]);
100+
101+
let witness = ComposedWitness::or([witness1, witness2]);
102+
let nizk = or_protocol.into_nizk(b"test_or_one_true");
103+
104+
// Batchable and compact proofs
105+
let proof_batchable_bytes = nizk.prove_batchable(&witness, &mut OsRng).unwrap();
106+
let proof_compact_bytes = nizk.prove_compact(&witness, &mut OsRng).unwrap();
107+
// Verify proofs
108+
assert!(nizk.verify_batchable(&proof_batchable_bytes).is_ok());
109+
assert!(nizk.verify_compact(&proof_compact_bytes).is_ok());
110+
}

src/tests/test_validation_criteria.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -406,10 +406,8 @@ mod proof_validation {
406406
lr2.set_element(eq2, C);
407407

408408
// Create OR composition
409-
let or_relation = ComposedRelation::Or(vec![
410-
ComposedRelation::from(lr1),
411-
ComposedRelation::from(lr2),
412-
]);
409+
let or_relation =
410+
ComposedRelation::or([lr1.canonical().unwrap(), lr2.canonical().unwrap()]);
413411
let nizk = or_relation.into_nizk(b"test_or_relation");
414412

415413
// Create a correct witness for branch 1 (C = y*B)

0 commit comments

Comments
 (0)