Skip to content

Commit 46c2817

Browse files
authored
refactor: ProtocolProverState and ProtocolWitness in OR composition (#65)
1 parent 83f20f7 commit 46c2817

File tree

11 files changed

+100
-102
lines changed

11 files changed

+100
-102
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@ let or_protocol = Protocol::Or(vec![
4848
]);
4949

5050
// If we know the second option, create witness for index 1
51-
let witness = ProtocolWitness::Or(1, vec![
52-
ProtocolWitness::And(vec![
51+
let witness = ProtocolWitness::from((1,
52+
ProtocolWitness::from(vec![
5353
ProtocolWitness::Simple(vec![y]),
5454
ProtocolWitness::Simple(vec![z]),
5555
])
56-
]);
56+
));
5757
```
5858

5959
## Examples

examples/simple_composition.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ fn prove(P1: G, x2: Scalar, H: G) -> ProofResult<Vec<u8>> {
5353
let Q = H * x2;
5454

5555
let instance = create_relation(P1, P2, Q, H);
56-
let witness = ComposedWitness::Or(1, vec![ComposedWitness::Simple(vec![x2])]);
56+
let witness = ComposedWitness::from((1, ComposedWitness::Simple(vec![x2])));
5757
let nizk = Nizk::<_, Shake128DuplexSponge<G>>::new(b"or_proof_example", instance);
5858

5959
nizk.prove_batchable(&witness, &mut OsRng)

src/composition.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ pub enum ComposedProverState<G: PrimeGroup> {
8282
And(Vec<ComposedProverState<G>>),
8383
Or(
8484
usize, // real index
85-
Vec<ComposedProverState<G>>, // real ProverState
85+
Box<ComposedProverState<G>>, // real ProverState
8686
(Vec<ComposedChallenge<G>>, Vec<ComposedResponse<G>>), // simulated transcripts
8787
),
8888
}
@@ -99,7 +99,19 @@ pub enum ComposedResponse<G: PrimeGroup> {
9999
pub enum ComposedWitness<G: PrimeGroup> {
100100
Simple(<SchnorrProof<G> as SigmaProtocol>::Witness),
101101
And(Vec<ComposedWitness<G>>),
102-
Or(usize, Vec<ComposedWitness<G>>),
102+
Or(usize, Box<ComposedWitness<G>>),
103+
}
104+
105+
impl<G: PrimeGroup> From<Vec<ComposedWitness<G>>> for ComposedWitness<G> {
106+
fn from(value: Vec<ComposedWitness<G>>) -> Self {
107+
Self::And(value)
108+
}
109+
}
110+
111+
impl<G: PrimeGroup> From<(usize, ComposedWitness<G>)> for ComposedWitness<G> {
112+
fn from((i, witness): (usize, ComposedWitness<G>)) -> Self {
113+
Self::Or(i, Box::new(witness))
114+
}
103115
}
104116

105117
// Structure representing the Challenge type of Protocol as SigmaProtocol
@@ -149,7 +161,7 @@ impl<G: PrimeGroup> SigmaProtocol for ComposedRelation<G> {
149161
let mut simulated_challenges = Vec::new();
150162
let mut simulated_responses = Vec::new();
151163

152-
let (real_commitment, real_state) = ps[*w_index].prover_commit(&w[0], rng)?;
164+
let (real_commitment, real_state) = ps[*w_index].prover_commit(w, rng)?;
153165

154166
for i in (0..ps.len()).filter(|i| i != w_index) {
155167
let (commitment, challenge, response) = ps[i].simulate_transcript(rng)?;
@@ -163,7 +175,7 @@ impl<G: PrimeGroup> SigmaProtocol for ComposedRelation<G> {
163175
ComposedCommitment::Or(commitments),
164176
ComposedProverState::Or(
165177
*w_index,
166-
vec![real_state],
178+
Box::new(real_state),
167179
(simulated_challenges, simulated_responses),
168180
),
169181
))
@@ -208,8 +220,7 @@ impl<G: PrimeGroup> SigmaProtocol for ComposedRelation<G> {
208220
for ch in &simulated_challenges {
209221
real_challenge -= ch;
210222
}
211-
let real_response =
212-
ps[w_index].prover_response(real_state[0].clone(), &real_challenge)?;
223+
let real_response = ps[w_index].prover_response(*real_state, &real_challenge)?;
213224

214225
for (i, _) in ps.iter().enumerate() {
215226
if i == w_index {

src/linear_relation/convert.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,5 +153,4 @@ impl<G: Group> From<Sum<Weighted<GroupVar<G>, G::Scalar>>> for Sum<Weighted<Term
153153
let sum = sum.0.into_iter().map(|x| x.into()).collect::<Vec<_>>();
154154
Self(sum)
155155
}
156-
157-
}
156+
}

src/linear_relation/mod.rs

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ pub struct CanonicalLinearRelation<G: PrimeGroup> {
351351
pub num_scalars: usize,
352352
}
353353

354+
type WeightedCache<A, B> = HashMap<B, Vec<(A, B)>>;
355+
354356
impl<G: PrimeGroup> CanonicalLinearRelation<G> {
355357
/// Create a new empty canonical linear relation
356358
pub fn new() -> Self {
@@ -368,7 +370,7 @@ impl<G: PrimeGroup> CanonicalLinearRelation<G> {
368370
group_var: GroupVar<G>,
369371
weight: &G::Scalar,
370372
original_group_elements: &GroupMap<G>,
371-
weighted_group_cache: &mut HashMap<GroupVar<G>, Vec<(G::Scalar, GroupVar<G>)>>,
373+
weighted_group_cache: &mut WeightedCache<G::Scalar, GroupVar<G>>,
372374
) -> Result<GroupVar<G>, Error> {
373375
// Check if we already have this (weight, group_var) combination
374376
let entry = weighted_group_cache.entry(group_var).or_default();
@@ -397,7 +399,7 @@ impl<G: PrimeGroup> CanonicalLinearRelation<G> {
397399
image_var: GroupVar<G>,
398400
equation: &LinearCombination<G>,
399401
original_relation: &LinearRelation<G>,
400-
weighted_group_cache: &mut HashMap<GroupVar<G>, Vec<(G::Scalar, GroupVar<G>)>>,
402+
weighted_group_cache: &mut WeightedCache<G::Scalar, GroupVar<G>>,
401403
) -> Result<(), Error> {
402404
let mut rhs_terms = Vec::new();
403405

@@ -626,8 +628,7 @@ impl<G: PrimeGroup> CanonicalLinearRelation<G> {
626628

627629
let elem = Option::<G>::from(G::from_bytes(&repr)).ok_or_else(|| {
628630
Error::from(InvalidInstance::new(format!(
629-
"Invalid group element at index {}",
630-
i
631+
"Invalid group element at index {i}"
631632
)))
632633
})?;
633634

@@ -698,17 +699,13 @@ impl<G: PrimeGroup> TryFrom<&LinearRelation<G>> for CanonicalLinearRelation<G> {
698699
}
699700

700701
// If any linear combination has no witness variables, the relation is invalid
701-
if relation
702-
.linear_map
703-
.linear_combinations
704-
.iter()
705-
.any(|lc| lc.0.iter().all(|weighted| matches!(weighted.term.scalar, ScalarTerm::Unit)))
706-
{
702+
if relation.linear_map.linear_combinations.iter().any(|lc| {
703+
lc.0.iter()
704+
.all(|weighted| matches!(weighted.term.scalar, ScalarTerm::Unit))
705+
}) {
707706
return Err(Error::InvalidInstanceWitnessPair);
708707
}
709708

710-
711-
712709
let mut canonical = CanonicalLinearRelation::new();
713710
canonical.num_scalars = relation.linear_map.num_scalars;
714711

src/linear_relation/ops.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -745,9 +745,9 @@ mod tests {
745745

746746
let diff = x - y;
747747
assert_eq!(diff.terms().len(), 2);
748-
assert_eq!(diff.terms()[0].term, y.into());
748+
assert_eq!(diff.terms()[0].term, y);
749749
assert_eq!(diff.terms()[0].weight, -Scalar::ONE);
750-
assert_eq!(diff.terms()[1].term, x.into());
750+
assert_eq!(diff.terms()[1].term, x);
751751
assert_eq!(diff.terms()[1].weight, Scalar::ONE);
752752
}
753753

src/schnorr_protocol.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ use rand::{CryptoRng, Rng, RngCore};
2727
#[derive(Clone, Default, Debug)]
2828
pub struct SchnorrProof<G: PrimeGroup>(pub CanonicalLinearRelation<G>);
2929

30+
type CommitResult<Commit, Scalar> = (Commit, (Scalar, Scalar));
31+
3032
impl<G: PrimeGroup> SchnorrProof<G> {
3133
pub fn witness_length(&self) -> usize {
3234
self.0.num_scalars
@@ -58,7 +60,7 @@ impl<G: PrimeGroup> SchnorrProof<G> {
5860
&self,
5961
witness: &[G::Scalar],
6062
nonces: &[G::Scalar],
61-
) -> Result<(Vec<G>, (Vec<G::Scalar>, Vec<G::Scalar>)), Error> {
63+
) -> Result<CommitResult<Vec<G>, Vec<G::Scalar>>, Error> {
6264
if witness.len() != self.witness_length() {
6365
return Err(Error::InvalidInstanceWitnessPair);
6466
}

src/tests/spec/test_vectors.rs

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@ fn test_spec_testvectors() {
7171
assert_eq!(
7272
parsed_instance.label(),
7373
vector.statement,
74-
"parsed statement doesn't match original for {}",
75-
test_name
74+
"parsed statement doesn't match original for {test_name}"
7675
);
7776

7877
// Create NIZK with the session_id from the test vector
@@ -90,8 +89,7 @@ fn test_spec_testvectors() {
9089
assert_eq!(
9190
computed_iv,
9291
vector.iv.as_slice(),
93-
"Computed IV doesn't match test vector IV for {}",
94-
test_name
92+
"Computed IV doesn't match test vector IV for {test_name}"
9593
);
9694

9795
// Generate proof with the proof generation RNG
@@ -101,70 +99,66 @@ fn test_spec_testvectors() {
10199
// Verify the proof matches
102100
assert_eq!(
103101
proof_bytes, vector.proof,
104-
"proof bytes for test vector {} do not match",
105-
test_name
102+
"proof bytes for test vector {test_name} do not match"
106103
);
107104

108105
// Verify the proof is valid
109106
let verified = nizk.verify_batchable(&proof_bytes).is_ok();
110107
assert!(
111108
verified,
112-
"Fiat-Shamir Schnorr proof verification failed for {}",
113-
test_name
109+
"Fiat-Shamir Schnorr proof verification failed for {test_name}"
114110
);
115111
}
116112
}
117113

118114
fn extract_vectors_new(path: &str) -> Result<HashMap<String, TestVector>, String> {
119115
use std::collections::HashMap;
120116

121-
let content =
122-
fs::read_to_string(path).map_err(|e| format!("Unable to read JSON file: {}", e))?;
123-
let root: JsonValue =
124-
json::parse(&content).map_err(|e| format!("JSON parsing error: {}", e))?;
117+
let content = fs::read_to_string(path).map_err(|e| format!("Unable to read JSON file: {e}"))?;
118+
let root: JsonValue = json::parse(&content).map_err(|e| format!("JSON parsing error: {e}"))?;
125119

126120
let mut vectors = HashMap::new();
127121

128122
for (name, obj) in root.entries() {
129123
let ciphersuite = obj["Ciphersuite"]
130124
.as_str()
131-
.ok_or_else(|| format!("Ciphersuite field not found for {}", name))?
125+
.ok_or_else(|| format!("Ciphersuite field not found for {name}"))?
132126
.to_string();
133127

134128
let session_id = Vec::from_hex(
135129
obj["SessionId"]
136130
.as_str()
137-
.ok_or_else(|| format!("SessionId field not found for {}", name))?,
131+
.ok_or_else(|| format!("SessionId field not found for {name}"))?,
138132
)
139-
.map_err(|e| format!("Invalid hex in SessionId for {}: {}", name, e))?;
133+
.map_err(|e| format!("Invalid hex in SessionId for {name}: {e}"))?;
140134

141135
let statement = Vec::from_hex(
142136
obj["Statement"]
143137
.as_str()
144-
.ok_or_else(|| format!("Statement field not found for {}", name))?,
138+
.ok_or_else(|| format!("Statement field not found for {name}"))?,
145139
)
146-
.map_err(|e| format!("Invalid hex in Statement for {}: {}", name, e))?;
140+
.map_err(|e| format!("Invalid hex in Statement for {name}: {e}"))?;
147141

148142
let witness = Vec::from_hex(
149143
obj["Witness"]
150144
.as_str()
151-
.ok_or_else(|| format!("Witness field not found for {}", name))?,
145+
.ok_or_else(|| format!("Witness field not found for {name}"))?,
152146
)
153-
.map_err(|e| format!("Invalid hex in Witness for {}: {}", name, e))?;
147+
.map_err(|e| format!("Invalid hex in Witness for {name}: {e}"))?;
154148

155149
let iv = Vec::from_hex(
156150
obj["IV"]
157151
.as_str()
158-
.ok_or_else(|| format!("IV field not found for {}", name))?,
152+
.ok_or_else(|| format!("IV field not found for {name}"))?,
159153
)
160-
.map_err(|e| format!("Invalid hex in IV for {}: {}", name, e))?;
154+
.map_err(|e| format!("Invalid hex in IV for {name}: {e}"))?;
161155

162156
let proof = Vec::from_hex(
163157
obj["Proof"]
164158
.as_str()
165-
.ok_or_else(|| format!("Proof field not found for {}", name))?,
159+
.ok_or_else(|| format!("Proof field not found for {name}"))?,
166160
)
167-
.map_err(|e| format!("Invalid hex in Proof for {}: {}", name, e))?;
161+
.map_err(|e| format!("Invalid hex in Proof for {name}: {e}"))?;
168162

169163
vectors.insert(
170164
name.to_string(),

src/tests/test_composition.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ fn composition_proof_correct() {
3434
ComposedRelation::Simple(SchnorrProof(relation1)),
3535
ComposedRelation::Simple(SchnorrProof(relation2)),
3636
]);
37-
let or_witness1 = ComposedWitness::Or(0, vec![ComposedWitness::Simple(witness1)]);
37+
let or_witness1 = ComposedWitness::from((0, ComposedWitness::Simple(witness1)));
3838

3939
let simple_protocol1 = ComposedRelation::Simple(SchnorrProof(relation3));
4040
let simple_witness1 = ComposedWitness::Simple(witness3);
@@ -43,14 +43,14 @@ fn composition_proof_correct() {
4343
ComposedRelation::Simple(SchnorrProof(relation4)),
4444
ComposedRelation::Simple(SchnorrProof(relation5)),
4545
]);
46-
let and_witness1 = ComposedWitness::And(vec![
46+
let and_witness1 = ComposedWitness::from(vec![
4747
ComposedWitness::Simple(witness4),
4848
ComposedWitness::Simple(witness5),
4949
]);
5050

5151
// definition of the final protocol
5252
let protocol = ComposedRelation::And(vec![or_protocol1, simple_protocol1, and_protocol1]);
53-
let witness = ComposedWitness::And(vec![or_witness1, simple_witness1, and_witness1]);
53+
let witness = ComposedWitness::from(vec![or_witness1, simple_witness1, and_witness1]);
5454

5555
let nizk = Nizk::<ComposedRelation<RistrettoPoint>, Shake128DuplexSponge<G>>::new(
5656
domain_sep, protocol,

0 commit comments

Comments
 (0)