Skip to content

Commit f182f1e

Browse files
authored
chore: pptimize small MSMs by using a naive implementation under size 16 (#93)
1 parent 5cfb48e commit f182f1e

File tree

5 files changed

+133
-5
lines changed

5 files changed

+133
-5
lines changed

.github/workflows/rust.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ jobs:
2525
run: cargo build --verbose
2626
- name: Run tests
2727
run: cargo test --verbose
28+
- name: Run benchmark test
29+
# Run the msm benchmark, just to ensure it isn't broken.
30+
run: cargo bench --bench msm -- --quick
2831

2932
no-std-check:
3033
runs-on: ubuntu-latest

Cargo.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,22 @@ ahash = { version = "0.8", default-features = false }
4141

4242
[dev-dependencies]
4343
bls12_381 = "0.8.0"
44+
criterion = { version = "0.7", features = ["html_reports"] }
4445
curve25519-dalek = { version = "4", default-features = false, features = ["serde", "rand_core", "alloc", "digest", "precomputed-tables", "group"] }
4546
hex = "0.4"
4647
hex-literal = "0.4"
4748
json = "0.12.4"
49+
k256 = { version = "0.13", features = ["arithmetic"] }
4850
libtest-mimic = "0.8.1"
51+
p256 = { version = "0.13", features = ["arithmetic"] }
4952
serde = { version = "1.0.219", features = ["derive"] }
5053
serde_json = "1.0.140"
5154
sha2 = "0.10"
5255

56+
[[bench]]
57+
name = "msm"
58+
harness = false
59+
5360
[profile.dev]
5461
# Makes tests run much faster at the cost of slightly longer builds and worse debug info.
5562
opt-level = 1

benches/msm.rs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
use std::hint::black_box;
2+
3+
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
4+
use ff::Field;
5+
use group::Group;
6+
use rand::thread_rng;
7+
use sigma_proofs::VariableMultiScalarMul;
8+
9+
fn bench_msm_curve25519_dalek(c: &mut Criterion) {
10+
use curve25519_dalek::{RistrettoPoint, Scalar};
11+
12+
let mut group = c.benchmark_group("MSM curve25519-dalek RistrettoPoint");
13+
let mut rng = thread_rng();
14+
15+
for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() {
16+
let scalars: Vec<Scalar> = (0..*size).map(|_| Scalar::random(&mut rng)).collect();
17+
let bases: Vec<RistrettoPoint> = (0..*size)
18+
.map(|_| RistrettoPoint::random(&mut rng))
19+
.collect();
20+
21+
group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| {
22+
b.iter(|| RistrettoPoint::msm(black_box(&scalars), black_box(&bases)))
23+
});
24+
}
25+
group.finish();
26+
}
27+
28+
fn bench_msm_k256(c: &mut Criterion) {
29+
use k256::{ProjectivePoint, Scalar};
30+
31+
let mut group = c.benchmark_group("MSM k256 ProjectivePoint");
32+
let mut rng = thread_rng();
33+
34+
for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() {
35+
let scalars: Vec<Scalar> = (0..*size).map(|_| Scalar::random(&mut rng)).collect();
36+
let bases: Vec<ProjectivePoint> = (0..*size)
37+
.map(|_| ProjectivePoint::random(&mut rng))
38+
.collect();
39+
40+
group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| {
41+
b.iter(|| ProjectivePoint::msm(black_box(&scalars), black_box(&bases)))
42+
});
43+
}
44+
group.finish();
45+
}
46+
47+
fn bench_msm_p256(c: &mut Criterion) {
48+
use p256::{ProjectivePoint, Scalar};
49+
50+
let mut group = c.benchmark_group("MSM p256 ProjectivePoint");
51+
let mut rng = thread_rng();
52+
53+
for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() {
54+
let scalars: Vec<Scalar> = (0..*size).map(|_| Scalar::random(&mut rng)).collect();
55+
let bases: Vec<ProjectivePoint> = (0..*size)
56+
.map(|_| ProjectivePoint::random(&mut rng))
57+
.collect();
58+
59+
group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| {
60+
b.iter(|| ProjectivePoint::msm(black_box(&scalars), black_box(&bases)))
61+
});
62+
}
63+
group.finish();
64+
}
65+
66+
fn bench_msm_bls12_381_g1(c: &mut Criterion) {
67+
use bls12_381::{G1Projective, Scalar};
68+
69+
let mut group = c.benchmark_group("MSM bls12_381 G1Projective");
70+
let mut rng = thread_rng();
71+
72+
for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() {
73+
let scalars: Vec<Scalar> = (0..*size).map(|_| Scalar::random(&mut rng)).collect();
74+
let bases: Vec<G1Projective> = (0..*size).map(|_| G1Projective::random(&mut rng)).collect();
75+
76+
group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| {
77+
b.iter(|| G1Projective::msm(black_box(&scalars), black_box(&bases)))
78+
});
79+
}
80+
group.finish();
81+
}
82+
83+
fn bench_msm_bls12_381_g2(c: &mut Criterion) {
84+
use bls12_381::{G2Projective, Scalar};
85+
86+
let mut group = c.benchmark_group("MSM bls12_381 G2Projective");
87+
let mut rng = thread_rng();
88+
89+
for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() {
90+
let scalars: Vec<Scalar> = (0..*size).map(|_| Scalar::random(&mut rng)).collect();
91+
let bases: Vec<G2Projective> = (0..*size).map(|_| G2Projective::random(&mut rng)).collect();
92+
93+
group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| {
94+
b.iter(|| G2Projective::msm(black_box(&scalars), black_box(&bases)))
95+
});
96+
}
97+
group.finish();
98+
}
99+
100+
criterion_group!(
101+
benches,
102+
bench_msm_curve25519_dalek,
103+
bench_msm_k256,
104+
bench_msm_p256,
105+
bench_msm_bls12_381_g1,
106+
bench_msm_bls12_381_g2,
107+
);
108+
criterion_main!(benches);

src/group/msm.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,24 @@ impl<G: PrimeGroup> VariableMultiScalarMul for G {
6464
fn msm(scalars: &[Self::Scalar], bases: &[Self::Point]) -> Self {
6565
assert_eq!(scalars.len(), bases.len());
6666

67-
if scalars.is_empty() {
68-
return Self::identity();
67+
// NOTE: Based on the msm benchmark in this repo, msm_pippenger provides improvements over
68+
// msm_naive past a small constant size, but is significantly slower for very small MSMs.
69+
match scalars.len() {
70+
0 => Self::identity(),
71+
1..16 => msm_naive(bases, scalars),
72+
16.. => msm_pippenger(bases, scalars),
6973
}
70-
71-
msm_internal(bases, scalars)
7274
}
7375
}
7476

75-
fn msm_internal<G: PrimeGroup>(bases: &[G], scalars: &[G::Scalar]) -> G {
77+
/// A naive MSM implementation.
78+
fn msm_naive<G: PrimeGroup>(bases: &[G], scalars: &[G::Scalar]) -> G {
79+
core::iter::zip(bases, scalars).map(|(g, x)| *g * x).sum()
80+
}
81+
82+
/// An MSM implementation that employ's Pippenger's algorithm and works for all groups that
83+
/// implement `PrimeGroup`.
84+
fn msm_pippenger<G: PrimeGroup>(bases: &[G], scalars: &[G::Scalar]) -> G {
7685
let c = ln_without_floats(scalars.len());
7786
let num_bits = <G::Scalar as PrimeField>::NUM_BITS as usize;
7887
// split `num_bits` into steps of `c`, but skip window 0.

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ pub(crate) mod schnorr_protocol;
8484
pub mod tests;
8585

8686
pub use fiat_shamir::Nizk;
87+
pub use group::msm::VariableMultiScalarMul;
8788
pub use linear_relation::LinearRelation;
8889

8990
#[deprecated = "Use sigma_proofs::group::serialization instead"]

0 commit comments

Comments
 (0)