Skip to content

Commit 526dda8

Browse files
committed
wip on using the allocator in CanonicalLinearRelation
1 parent 0d93078 commit 526dda8

File tree

3 files changed

+152
-125
lines changed

3 files changed

+152
-125
lines changed

src/linear_relation/canonical.rs

Lines changed: 136 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ use alloc::vec::Vec;
55
use core::iter;
66
use core::marker::PhantomData;
77
#[cfg(not(feature = "std"))]
8-
use hashbrown::HashMap;
8+
use hashbrown::{HashMap, HashSet};
99
#[cfg(feature = "std")]
10-
use std::collections::HashMap;
10+
use std::collections::{HashMap, HashSet};
1111

1212
use ff::Field;
1313
use group::prime::PrimeGroup;
@@ -29,6 +29,7 @@ use crate::serialization::serialize_elements;
2929
/// constraint is of the form: image_i = Σ (scalar_j * group_element_k)
3030
/// without weights or extra scalars.
3131
#[derive(Clone, Debug, Default)]
32+
#[non_exhaustive]
3233
pub struct CanonicalLinearRelation<G: PrimeGroup> {
3334
/// The image group elements (left-hand side of equations)
3435
pub image: Vec<GroupVar<G>>,
@@ -37,21 +38,10 @@ pub struct CanonicalLinearRelation<G: PrimeGroup> {
3738
pub linear_combinations: Vec<Vec<(ScalarVar<G>, GroupVar<G>)>>,
3839
/// The group elements map
3940
pub group_elements: GroupMap<G>,
40-
/// Number of scalar variables
41-
pub num_scalars: usize,
41+
/// Set of scalar variables used in this relation.
42+
pub scalar_vars: HashSet<ScalarVar<G>>,
4243
}
4344

44-
/// Private type alias used to simplify function signatures below.
45-
///
46-
/// The cache is essentially a mapping (GroupVar, Scalar) => GroupVar, which maps the original
47-
/// weighted group vars to a new assignment, such that if a pair appears more than once, it will
48-
/// map to the same group variable in the canonical linear relation.
49-
#[cfg(feature = "std")]
50-
type WeightedGroupCache<G> = HashMap<GroupVar<G>, Vec<(<G as group::Group>::Scalar, GroupVar<G>)>>;
51-
#[cfg(not(feature = "std"))]
52-
type WeightedGroupCache<G> =
53-
HashMap<GroupVar<G>, Vec<(<G as group::Group>::Scalar, GroupVar<G>)>, RandomState>;
54-
5545
impl<G: PrimeGroup> CanonicalLinearRelation<G> {
5646
/// Create a new empty canonical linear relation.
5747
///
@@ -62,7 +52,7 @@ impl<G: PrimeGroup> CanonicalLinearRelation<G> {
6252
image: Vec::new(),
6353
linear_combinations: Vec::new(),
6454
group_elements: GroupMap::default(),
65-
num_scalars: 0,
55+
scalar_vars: HashSet::default(),
6656
}
6757
}
6858

@@ -95,103 +85,6 @@ impl<G: PrimeGroup> CanonicalLinearRelation<G> {
9585
.collect()
9686
}
9787

98-
pub fn scalar_vars(&self) -> impl Iterator<Item = ScalarVar<G>> {
99-
(0..self.num_scalars).map(|i| ScalarVar(i, PhantomData))
100-
}
101-
102-
/// Get or create a GroupVar for a weighted group element, with deduplication
103-
fn get_or_create_weighted_group_var(
104-
&mut self,
105-
group_var: GroupVar<G>,
106-
weight: &G::Scalar,
107-
original_group_elements: &GroupMap<G>,
108-
weighted_group_cache: &mut WeightedGroupCache<G>,
109-
) -> Result<GroupVar<G>, InvalidInstance> {
110-
// Check if we already have this (weight, group_var) combination
111-
let entry = weighted_group_cache.entry(group_var).or_default();
112-
113-
// Find if we already have this weight for this group_var
114-
if let Some((_, existing_var)) = entry.iter().find(|(w, _)| w == weight) {
115-
return Ok(*existing_var);
116-
}
117-
118-
// Create new weighted group element
119-
// Use a special case for one, as this is the most common weight.
120-
let original_group_val = original_group_elements.get(group_var)?;
121-
let weighted_group = match *weight == G::Scalar::ONE {
122-
true => original_group_val,
123-
false => original_group_val * weight,
124-
};
125-
126-
// Add to our group elements with new index (length)
127-
let new_var = self.group_elements.allocate_element_with(weighted_group);
128-
129-
// Cache the mapping for this group_var and weight
130-
entry.push((*weight, new_var));
131-
132-
Ok(new_var)
133-
}
134-
135-
/// Process a single constraint equation and add it to the canonical relation.
136-
fn process_constraint<A: Allocator<G = G>>(
137-
&mut self,
138-
&image_var: &GroupVar<G>,
139-
equation: &LinearCombination<G>,
140-
original_relation: &LinearRelation<G, A>,
141-
weighted_group_cache: &mut WeightedGroupCache<G>,
142-
) -> Result<(), InvalidInstance> {
143-
let mut rhs_terms = Vec::new();
144-
145-
// Collect RHS terms that have scalar variables and apply weights
146-
for weighted_term in equation.terms() {
147-
if let ScalarTerm::Var(scalar_var) = weighted_term.term.scalar {
148-
let group_var = weighted_term.term.elem;
149-
let weight = &weighted_term.weight;
150-
151-
if weight.is_zero_vartime() {
152-
continue; // Skip zero weights
153-
}
154-
155-
let canonical_group_var = self.get_or_create_weighted_group_var(
156-
group_var,
157-
weight,
158-
&original_relation.heap.elements,
159-
weighted_group_cache,
160-
)?;
161-
162-
rhs_terms.push((scalar_var, canonical_group_var));
163-
}
164-
}
165-
166-
// Compute the canonical image by subtracting constant terms from the original image
167-
let mut canonical_image = original_relation.heap.elements.get(image_var)?;
168-
for weighted_term in equation.terms() {
169-
if let ScalarTerm::Unit = weighted_term.term.scalar {
170-
let group_val = original_relation
171-
.heap
172-
.elements
173-
.get(weighted_term.term.elem)?;
174-
canonical_image -= group_val * weighted_term.weight;
175-
}
176-
}
177-
178-
// Only include constraints that are non-trivial (not zero constraints).
179-
if rhs_terms.is_empty() {
180-
if canonical_image.is_identity().into() {
181-
return Ok(());
182-
}
183-
return Err(InvalidInstance::new(
184-
"trivially false constraint: constraint has empty right-hand side and non-identity left-hand side",
185-
));
186-
}
187-
188-
let canonical_image_group_var = self.group_elements.allocate_element_with(canonical_image);
189-
self.image.push(canonical_image_group_var);
190-
self.linear_combinations.push(rhs_terms);
191-
192-
Ok(())
193-
}
194-
19588
/// Serialize the linear relation to bytes.
19689
///
19790
/// The output format is:
@@ -441,9 +334,12 @@ impl<G: PrimeGroup, A: Allocator<G = G>> TryFrom<&LinearRelation<G, A>>
441334
));
442335
}
443336

444-
let mut canonical = CanonicalLinearRelation::new();
445-
canonical.num_scalars = relation.heap.num_scalars;
337+
let mut builder = CanonicalLinearRelationBuilder::default();
446338

339+
#[cfg(feature = "std")]
340+
let mut scalar_vars = HashSet::<ScalarVar<G>>::new();
341+
#[cfg(not(feature = "std"))]
342+
let mut scalar_vars = HashSet::<ScalarVar<G>>::with_hasher(RandomState::new());
447343
// Cache for deduplicating weighted group elements
448344
#[cfg(feature = "std")]
449345
let mut weighted_group_cache = HashMap::new();
@@ -490,10 +386,10 @@ impl<G: PrimeGroup, A: Allocator<G = G>> TryFrom<&LinearRelation<G, A>>
490386
return Err(InvalidInstance::new("Trivial kernel in this relation"));
491387
}
492388

493-
canonical.process_constraint(lhs, rhs, relation, &mut weighted_group_cache)?;
389+
builder.process_constraint(lhs, rhs, relation)?;
494390
}
495391

496-
Ok(canonical)
392+
Ok(builder.build())
497393
}
498394
}
499395

@@ -513,3 +409,126 @@ impl<G: PrimeGroup + ConstantTimeEq> CanonicalLinearRelation<G> {
513409
.fold(Choice::from(1), |acc, (lhs, rhs)| acc & lhs.ct_eq(&rhs))
514410
}
515411
}
412+
413+
/// Private type alias used to simplify function signatures below.
414+
///
415+
/// The cache is essentially a mapping (GroupVar, Scalar) => GroupVar, which maps the original
416+
/// weighted group vars to a new assignment, such that if a pair appears more than once, it will
417+
/// map to the same group variable in the canonical linear relation.
418+
#[cfg(feature = "std")]
419+
type WeightedGroupCache<G> = HashMap<GroupVar<G>, Vec<(<G as group::Group>::Scalar, GroupVar<G>)>>;
420+
#[cfg(not(feature = "std"))]
421+
type WeightedGroupCache<G> =
422+
HashMap<GroupVar<G>, Vec<(<G as group::Group>::Scalar, GroupVar<G>)>, RandomState>;
423+
424+
#[derive(Debug)]
425+
struct CanonicalLinearRelationBuilder<G: PrimeGroup> {
426+
relation: CanonicalLinearRelation<G>,
427+
weighted_group_cache: WeightedGroupCache<G>,
428+
}
429+
430+
impl<G: PrimeGroup> CanonicalLinearRelationBuilder<G> {
431+
/// Get or create a GroupVar for a weighted group element, with deduplication
432+
fn get_or_create_weighted_group_var<A: Allocator<G = G>>(
433+
&mut self,
434+
group_var: GroupVar<G>,
435+
weight: &G::Scalar,
436+
original_alloc: &A,
437+
) -> Result<GroupVar<G>, InvalidInstance> {
438+
// Check if we already have this (weight, group_var) combination
439+
let entry = self.weighted_group_cache.entry(group_var).or_default();
440+
441+
// Find if we already have this weight for this group_var
442+
if let Some((_, existing_var)) = entry.iter().find(|(w, _)| w == weight) {
443+
return Ok(*existing_var);
444+
}
445+
446+
// Create new weighted group element
447+
// Use a special case for one, as this is the most common weight.
448+
let original_group_val = original_alloc.get_element(group_var)?;
449+
let weighted_group = match *weight == G::Scalar::ONE {
450+
true => original_group_val,
451+
false => original_group_val * weight,
452+
};
453+
454+
// Add to our group elements with new index (length)
455+
let new_var = self
456+
.relation
457+
.group_elements
458+
.allocate_element_with(weighted_group);
459+
460+
// Cache the mapping for this group_var and weight
461+
entry.push((*weight, new_var));
462+
463+
Ok(new_var)
464+
}
465+
466+
/// Process a single constraint equation and add it to the canonical relation.
467+
fn process_constraint<A: Allocator<G = G>>(
468+
&mut self,
469+
&image_var: &GroupVar<G>,
470+
equation: &LinearCombination<G>,
471+
allocator: &A,
472+
) -> Result<(), InvalidInstance> {
473+
let mut rhs_terms = Vec::new();
474+
475+
// Collect RHS terms that have scalar variables and apply weights
476+
for weighted_term in equation.terms() {
477+
if let ScalarTerm::Var(scalar_var) = weighted_term.term.scalar {
478+
let group_var = weighted_term.term.elem;
479+
let weight = &weighted_term.weight;
480+
481+
if weight.is_zero_vartime() {
482+
continue; // Skip zero weights
483+
}
484+
485+
let canonical_group_var =
486+
self.get_or_create_weighted_group_var(group_var, weight, allocator)?;
487+
488+
rhs_terms.push((scalar_var, canonical_group_var));
489+
self.relation.scalar_vars.insert(scalar_var);
490+
}
491+
}
492+
493+
// Compute the canonical image by subtracting constant terms from the original image
494+
let mut canonical_image = allocator.get_element(image_var)?;
495+
for weighted_term in equation.terms() {
496+
if let ScalarTerm::Unit = weighted_term.term.scalar {
497+
let group_val = allocator.get_element(weighted_term.term.elem)?;
498+
canonical_image -= group_val * weighted_term.weight;
499+
}
500+
}
501+
502+
// Only include constraints that are non-trivial (not zero constraints).
503+
if rhs_terms.is_empty() {
504+
if canonical_image.is_identity().into() {
505+
return Ok(());
506+
}
507+
return Err(InvalidInstance::new(
508+
"trivially false constraint: constraint has empty right-hand side and non-identity left-hand side",
509+
));
510+
}
511+
512+
let canonical_image_group_var = self
513+
.relation
514+
.group_elements
515+
.allocate_element_with(canonical_image);
516+
self.relation.image.push(canonical_image_group_var);
517+
self.relation.linear_combinations.push(rhs_terms);
518+
519+
Ok(())
520+
}
521+
522+
fn build(self) -> CanonicalLinearRelation<G> {
523+
self.relation
524+
}
525+
}
526+
527+
impl<G: PrimeGroup> Default for CanonicalLinearRelationBuilder<G> {
528+
fn default() -> Self {
529+
Self {
530+
relation: CanonicalLinearRelation::new(),
531+
weighted_group_cache: Default::default(),
532+
}
533+
}
534+
}

src/schnorr_protocol.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ impl<G: PrimeGroup> SigmaProtocol for CanonicalLinearRelation<G> {
5050
witness: Self::Witness,
5151
rng: &mut (impl RngCore + CryptoRng),
5252
) -> Result<(Self::Commitment, Self::ProverState), Error> {
53-
if witness.len() < self.num_scalars {
53+
if witness.len() < self.scalar_vars.len() {
5454
return Err(Error::InvalidInstanceWitnessPair);
5555
}
5656

@@ -124,12 +124,14 @@ impl<G: PrimeGroup> SigmaProtocol for CanonicalLinearRelation<G> {
124124
challenge: &Self::Challenge,
125125
response: &Self::Response,
126126
) -> Result<(), Error> {
127-
if commitment.len() != self.image.len() || response.len() != self.num_scalars {
127+
if commitment.len() != self.image.len() || response.len() != self.scalar_vars.len() {
128128
return Err(Error::InvalidInstanceWitnessPair);
129129
}
130130

131131
let response_map = self
132-
.scalar_vars()
132+
.scalar_vars
133+
.iter()
134+
.copied()
133135
.zip(response.iter().copied())
134136
.collect::<ScalarMap<G>>();
135137

@@ -263,7 +265,9 @@ where
263265
/// # Returns
264266
/// - A commitment and response forming a valid proof for the given challenge.
265267
fn simulate_response<R: Rng + CryptoRng>(&self, rng: &mut R) -> Self::Response {
266-
let response: Vec<G::Scalar> = (0..self.num_scalars)
268+
let response: Vec<G::Scalar> = self
269+
.scalar_vars
270+
.iter()
267271
.map(|_| G::Scalar::random(&mut *rng))
268272
.collect();
269273
response
@@ -302,12 +306,14 @@ where
302306
challenge: &Self::Challenge,
303307
response: &Self::Response,
304308
) -> Result<Self::Commitment, Error> {
305-
if response.len() != self.num_scalars {
309+
if response.len() != self.scalar_vars.len() {
306310
return Err(Error::InvalidInstanceWitnessPair);
307311
}
308312

309313
let response_map = self
310-
.scalar_vars()
314+
.scalar_vars
315+
.iter()
316+
.copied()
311317
.zip(response.iter().copied())
312318
.collect::<ScalarMap<G>>();
313319

src/tests/spec/test_vectors.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,13 @@ fn test_spec_testvectors() {
6060
// Decode the witness from the test vector
6161
let witness_vec = crate::group::serialization::deserialize_scalars::<G>(
6262
&vector.witness,
63-
parsed_instance.num_scalars,
63+
parsed_instance.scalar_vars.len(),
6464
)
6565
.expect("Failed to deserialize witness");
6666
let witness = parsed_instance
67-
.scalar_vars()
67+
.scalar_vars
68+
.iter()
69+
.copied()
6870
.zip(witness_vec)
6971
.collect::<ScalarMap<G>>();
7072

0 commit comments

Comments
 (0)