Skip to content

Commit 1417937

Browse files
committed
feat: implement Pippenger for msm.
This code is stolen from arkworks-rs/gemini.
1 parent 7c4d11b commit 1417937

File tree

2 files changed

+126
-10
lines changed

2 files changed

+126
-10
lines changed

src/composition.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ use sha3::{Digest, Sha3_256};
2424
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
2525

2626
use crate::errors::InvalidInstance;
27+
use crate::group::serialization::{deserialize_scalars, serialize_scalars};
2728
use crate::{
2829
codec::Shake128DuplexSponge,
2930
errors::Error,
3031
fiat_shamir::Nizk,
3132
linear_relation::{CanonicalLinearRelation, LinearRelation},
3233
traits::{SigmaProtocol, SigmaProtocolSimulator},
3334
};
34-
use crate::group::serialization::{deserialize_scalars, serialize_scalars};
3535

3636
/// A protocol proving knowledge of a witness for a composition of linear relations.
3737
///

src/group/msm.rs

Lines changed: 125 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,22 @@
1+
use ff::PrimeField;
12
use group::prime::PrimeGroup;
23

4+
/// The result of this function is only approximately `ln(a)`. This is inherited from Zexe and libsnark.
5+
#[inline]
6+
const fn ln_without_floats(a: usize) -> usize {
7+
if a == 0 {
8+
1
9+
} else {
10+
// log2(a) * ln(2), ensure minimum value of 1
11+
let result = (64 - (a - 1).leading_zeros()) as usize * 69 / 100;
12+
if result == 0 {
13+
1
14+
} else {
15+
result
16+
}
17+
}
18+
}
19+
320
/// Trait for performing Multi-Scalar Multiplication (MSM).
421
///
522
/// MSM computes the sum:
@@ -43,16 +60,115 @@ impl<G: PrimeGroup> VariableMultiScalarMul for G {
4360
/// # Panics
4461
/// Panics if `scalars.len() != bases.len()`.
4562
fn msm(scalars: &[Self::Scalar], bases: &[Self::Point]) -> Self {
46-
assert_eq!(
47-
scalars.len(),
48-
bases.len(),
49-
"scalars and bases must have the same length"
50-
);
63+
assert_eq!(scalars.len(), bases.len());
64+
65+
if scalars.is_empty() {
66+
return Self::identity();
67+
}
68+
69+
msm_internal(bases, scalars)
70+
}
71+
}
72+
73+
fn msm_internal<G: PrimeGroup>(bases: &[G], scalars: &[G::Scalar]) -> G {
74+
let c = ln_without_floats(scalars.len());
75+
let num_bits = <G::Scalar as PrimeField>::NUM_BITS as usize;
76+
// split `num_bits` into steps of `c`, but skip window 0.
77+
let windows = (0..num_bits).step_by(c);
78+
let buckets_num = 1 << c;
79+
80+
let mut window_buckets = Vec::with_capacity(windows.len());
81+
for window in windows {
82+
window_buckets.push((window, vec![G::identity(); buckets_num]));
83+
}
84+
85+
for (scalar, base) in scalars.into_iter().zip(bases) {
86+
for (w, bucket) in window_buckets.iter_mut() {
87+
let scalar_repr = scalar.to_repr();
88+
let scalar_bytes = scalar_repr.as_ref();
89+
90+
// Extract the relevant bits for this window
91+
let window_start = *w;
92+
let window_end = (window_start + c).min(scalar_bytes.len() * 8);
93+
94+
if window_start >= scalar_bytes.len() * 8 {
95+
continue; // Window is beyond the scalar size
96+
}
5197

52-
let mut acc = Self::identity();
53-
for (s, p) in scalars.iter().zip(bases.iter()) {
54-
acc += *p * s;
98+
let mut scalar_bits = 0u64;
99+
100+
// Extract bits from the byte representation
101+
for bit_idx in window_start..window_end {
102+
let byte_idx = bit_idx / 8;
103+
let bit_in_byte = bit_idx % 8;
104+
105+
if byte_idx < scalar_bytes.len() {
106+
let bit = (scalar_bytes[byte_idx] >> bit_in_byte) & 1;
107+
scalar_bits |= (bit as u64) << (bit_idx - window_start);
108+
}
109+
}
110+
111+
// If the scalar is non-zero, we update the corresponding bucket.
112+
// (Recall that `buckets` doesn't have a zero bucket.)
113+
if scalar_bits != 0 {
114+
bucket[(scalar_bits - 1) as usize].add_assign(base);
115+
}
55116
}
56-
acc
117+
}
118+
119+
let mut window_sums = window_buckets.iter().rev().map(|(_w, bucket)| {
120+
// `running_sum` = sum_{j in i..num_buckets} bucket[j],
121+
// where we iterate backward from i = num_buckets to 0.
122+
let mut bucket_sum = G::identity();
123+
let mut bucket_running_sum = G::identity();
124+
bucket.iter().rev().for_each(|b| {
125+
bucket_running_sum += b;
126+
bucket_sum += &bucket_running_sum;
127+
});
128+
bucket_sum
129+
});
130+
131+
// We're traversing windows from high to low.
132+
let first = window_sums.next().unwrap();
133+
window_sums.fold(first, |mut total, sum_i| {
134+
for _ in 0..c {
135+
total = total.double();
136+
}
137+
total + sum_i
138+
})
139+
}
140+
141+
#[cfg(test)]
142+
mod tests {
143+
use super::*;
144+
use ff::Field;
145+
use group::Group;
146+
147+
#[test]
148+
fn test_msm() {
149+
use bls12_381::{G1Projective, Scalar};
150+
use rand::thread_rng;
151+
152+
let mut rng = thread_rng();
153+
const N: usize = 1024;
154+
155+
// Generate random scalars and bases
156+
let scalars: Vec<Scalar> = (0..N).map(|_| Scalar::random(&mut rng)).collect();
157+
let bases: Vec<G1Projective> = (0..N).map(|_| G1Projective::random(&mut rng)).collect();
158+
159+
// Compute MSM using our optimized implementation
160+
let msm_result = G1Projective::msm(&scalars, &bases);
161+
162+
// Compute reference result using naive scalar multiplication and sum
163+
let naive_result = scalars
164+
.iter()
165+
.zip(bases.iter())
166+
.map(|(scalar, base)| base * scalar)
167+
.fold(G1Projective::identity(), |acc, x| acc + x);
168+
169+
assert_eq!(
170+
msm_result, naive_result,
171+
"MSM result should equal naive computation"
172+
);
57173
}
58174
}

0 commit comments

Comments
 (0)