Skip to content

Commit a032410

Browse files
committed
Merge branch 'main' into morru/compressed
Signed-off-by: Michele Orrù <[email protected]>
2 parents 129ac6c + e2cc260 commit a032410

File tree

6 files changed

+144
-8
lines changed

6 files changed

+144
-8
lines changed

.github/workflows/rust.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ jobs:
2525
run: cargo build --verbose
2626
- name: Run tests
2727
run: cargo test --verbose
28+
- name: Run benchmark test
29+
# Run the msm benchmark, just to ensure it isn't broken.
30+
run: cargo bench --bench msm -- --quick
2831

2932
no-std-check:
3033
runs-on: ubuntu-latest

Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ authors = [
55
"Nugzari Uzoevi <[email protected]>",
66
"Michele Orrù <[email protected]>",
77
"Ian Goldberg <[email protected]>",
8+
"Victor Snyder-Graf <[email protected]>",
89
"Lénaïck Gouriou <[email protected]>"
910
]
1011
edition = "2021"
@@ -41,15 +42,22 @@ ahash = { version = "0.8", default-features = false }
4142

4243
[dev-dependencies]
4344
bls12_381 = "0.8.0"
45+
criterion = { version = "0.7", features = ["html_reports"] }
4446
curve25519-dalek = { version = "4", default-features = false, features = ["serde", "rand_core", "alloc", "digest", "precomputed-tables", "group"] }
4547
hex = "0.4"
4648
hex-literal = "0.4"
4749
json = "0.12.4"
50+
k256 = { version = "0.13", features = ["arithmetic"] }
4851
libtest-mimic = "0.8.1"
52+
p256 = { version = "0.13", features = ["arithmetic"] }
4953
serde = { version = "1.0.219", features = ["derive"] }
5054
serde_json = "1.0.140"
5155
sha2 = "0.10"
5256

57+
[[bench]]
58+
name = "msm"
59+
harness = false
60+
5361
[profile.dev]
5462
# Makes tests run much faster at the cost of slightly longer builds and worse debug info.
5563
opt-level = 1

benches/msm.rs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
use std::hint::black_box;
2+
3+
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
4+
use ff::Field;
5+
use group::Group;
6+
use rand::thread_rng;
7+
use sigma_proofs::VariableMultiScalarMul;
8+
9+
fn bench_msm_curve25519_dalek(c: &mut Criterion) {
10+
use curve25519_dalek::{RistrettoPoint, Scalar};
11+
12+
let mut group = c.benchmark_group("MSM curve25519-dalek RistrettoPoint");
13+
let mut rng = thread_rng();
14+
15+
for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() {
16+
let scalars: Vec<Scalar> = (0..*size).map(|_| Scalar::random(&mut rng)).collect();
17+
let bases: Vec<RistrettoPoint> = (0..*size)
18+
.map(|_| RistrettoPoint::random(&mut rng))
19+
.collect();
20+
21+
group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| {
22+
b.iter(|| RistrettoPoint::msm(black_box(&scalars), black_box(&bases)))
23+
});
24+
}
25+
group.finish();
26+
}
27+
28+
fn bench_msm_k256(c: &mut Criterion) {
29+
use k256::{ProjectivePoint, Scalar};
30+
31+
let mut group = c.benchmark_group("MSM k256 ProjectivePoint");
32+
let mut rng = thread_rng();
33+
34+
for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() {
35+
let scalars: Vec<Scalar> = (0..*size).map(|_| Scalar::random(&mut rng)).collect();
36+
let bases: Vec<ProjectivePoint> = (0..*size)
37+
.map(|_| ProjectivePoint::random(&mut rng))
38+
.collect();
39+
40+
group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| {
41+
b.iter(|| ProjectivePoint::msm(black_box(&scalars), black_box(&bases)))
42+
});
43+
}
44+
group.finish();
45+
}
46+
47+
fn bench_msm_p256(c: &mut Criterion) {
48+
use p256::{ProjectivePoint, Scalar};
49+
50+
let mut group = c.benchmark_group("MSM p256 ProjectivePoint");
51+
let mut rng = thread_rng();
52+
53+
for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() {
54+
let scalars: Vec<Scalar> = (0..*size).map(|_| Scalar::random(&mut rng)).collect();
55+
let bases: Vec<ProjectivePoint> = (0..*size)
56+
.map(|_| ProjectivePoint::random(&mut rng))
57+
.collect();
58+
59+
group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| {
60+
b.iter(|| ProjectivePoint::msm(black_box(&scalars), black_box(&bases)))
61+
});
62+
}
63+
group.finish();
64+
}
65+
66+
fn bench_msm_bls12_381_g1(c: &mut Criterion) {
67+
use bls12_381::{G1Projective, Scalar};
68+
69+
let mut group = c.benchmark_group("MSM bls12_381 G1Projective");
70+
let mut rng = thread_rng();
71+
72+
for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() {
73+
let scalars: Vec<Scalar> = (0..*size).map(|_| Scalar::random(&mut rng)).collect();
74+
let bases: Vec<G1Projective> = (0..*size).map(|_| G1Projective::random(&mut rng)).collect();
75+
76+
group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| {
77+
b.iter(|| G1Projective::msm(black_box(&scalars), black_box(&bases)))
78+
});
79+
}
80+
group.finish();
81+
}
82+
83+
fn bench_msm_bls12_381_g2(c: &mut Criterion) {
84+
use bls12_381::{G2Projective, Scalar};
85+
86+
let mut group = c.benchmark_group("MSM bls12_381 G2Projective");
87+
let mut rng = thread_rng();
88+
89+
for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() {
90+
let scalars: Vec<Scalar> = (0..*size).map(|_| Scalar::random(&mut rng)).collect();
91+
let bases: Vec<G2Projective> = (0..*size).map(|_| G2Projective::random(&mut rng)).collect();
92+
93+
group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| {
94+
b.iter(|| G2Projective::msm(black_box(&scalars), black_box(&bases)))
95+
});
96+
}
97+
group.finish();
98+
}
99+
100+
criterion_group!(
101+
benches,
102+
bench_msm_curve25519_dalek,
103+
bench_msm_k256,
104+
bench_msm_p256,
105+
bench_msm_bls12_381_g1,
106+
bench_msm_bls12_381_g2,
107+
);
108+
criterion_main!(benches);

src/group/msm.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,29 @@ impl<G: PrimeGroup> VariableMultiScalarMul for G {
6767
/// # Panics
6868
/// Panics if `scalars.len() != bases.len()`.
6969
fn msm_unchecked(scalars: &[Self::Scalar], bases: &[Self::Point]) -> Self {
70-
msm_internal(bases, scalars)
70+
msm(bases, scalars)
71+
72+
fn msm(scalars: &[Self::Scalar], bases: &[Self::Point]) -> Self {
73+
assert_eq!(scalars.len(), bases.len());
74+
75+
// NOTE: Based on the msm benchmark in this repo, msm_pippenger provides improvements over
76+
// msm_naive past a small constant size, but is significantly slower for very small MSMs.
77+
match scalars.len() {
78+
0 => Self::identity(),
79+
1..16 => msm_naive(bases, scalars),
80+
16.. => msm_pippenger(bases, scalars),
81+
}
7182
}
7283
}
7384

74-
fn msm_internal<G: PrimeGroup>(bases: &[G], scalars: &[G::Scalar]) -> G {
85+
/// A naive MSM implementation.
86+
fn msm_naive<G: PrimeGroup>(bases: &[G], scalars: &[G::Scalar]) -> G {
87+
core::iter::zip(bases, scalars).map(|(g, x)| *g * x).sum()
88+
}
89+
90+
/// An MSM implementation that employ's Pippenger's algorithm and works for all groups that
91+
/// implement `PrimeGroup`.
92+
fn msm_pippenger<G: PrimeGroup>(bases: &[G], scalars: &[G::Scalar]) -> G {
7593
let c = ln_without_floats(scalars.len());
7694
let num_bits = <G::Scalar as PrimeField>::NUM_BITS as usize;
7795
// split `num_bits` into steps of `c`, but skip window 0.

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ pub(crate) mod schnorr_protocol;
8484
pub mod tests;
8585

8686
pub use fiat_shamir::Nizk;
87+
pub use group::msm::VariableMultiScalarMul;
8788
pub use linear_relation::LinearRelation;
8889

8990
#[deprecated = "Use sigma_proofs::group::serialization instead"]

src/tests/spec/test_vectors.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
use bls12_381::G1Projective as G;
2-
use core::str;
32
use hex::FromHex;
43
use json::JsonValue;
54
use std::collections::HashMap;
6-
use std::fs;
75

86
use crate::codec::KeccakByteSchnorrCodec;
97
use crate::fiat_shamir::Nizk;
@@ -27,7 +25,7 @@ struct TestVector {
2725
#[test]
2826
fn test_spec_testvectors() {
2927
let proof_generation_rng_seed = b"proof_generation_seed";
30-
let vectors = extract_vectors_new("src/tests/spec/vectors/testSigmaProtocols.json").unwrap();
28+
let vectors = extract_vectors_new().unwrap();
3129

3230
// Define supported ciphersuites
3331
let mut supported_ciphersuites = HashMap::new();
@@ -110,11 +108,11 @@ fn test_spec_testvectors() {
110108
}
111109
}
112110

113-
fn extract_vectors_new(path: &str) -> Result<HashMap<String, TestVector>, String> {
111+
fn extract_vectors_new() -> Result<HashMap<String, TestVector>, String> {
114112
use std::collections::HashMap;
115113

116-
let content = fs::read_to_string(path).map_err(|e| format!("Unable to read JSON file: {e}"))?;
117-
let root: JsonValue = json::parse(&content).map_err(|e| format!("JSON parsing error: {e}"))?;
114+
let content = include_str!("./vectors/testSigmaProtocols.json");
115+
let root: JsonValue = json::parse(content).map_err(|e| format!("JSON parsing error: {e}"))?;
118116

119117
let mut vectors = HashMap::new();
120118

0 commit comments

Comments
 (0)