diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 80ec371..431ead3 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -25,6 +25,9 @@ jobs: run: cargo build --verbose - name: Run tests run: cargo test --verbose + - name: Run benchmark test + # Run the msm benchmark, just to ensure it isn't broken. + run: cargo bench --bench msm -- --quick no-std-check: runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index c8a9e6a..bdc7601 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,15 +41,22 @@ ahash = { version = "0.8", default-features = false } [dev-dependencies] bls12_381 = "0.8.0" +criterion = { version = "0.7", features = ["html_reports"] } curve25519-dalek = { version = "4", default-features = false, features = ["serde", "rand_core", "alloc", "digest", "precomputed-tables", "group"] } hex = "0.4" hex-literal = "0.4" json = "0.12.4" +k256 = { version = "0.13", features = ["arithmetic"] } libtest-mimic = "0.8.1" +p256 = { version = "0.13", features = ["arithmetic"] } serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" sha2 = "0.10" +[[bench]] +name = "msm" +harness = false + [profile.dev] # Makes tests run much faster at the cost of slightly longer builds and worse debug info. opt-level = 1 diff --git a/benches/msm.rs b/benches/msm.rs new file mode 100644 index 0000000..e470418 --- /dev/null +++ b/benches/msm.rs @@ -0,0 +1,108 @@ +use std::hint::black_box; + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use ff::Field; +use group::Group; +use rand::thread_rng; +use sigma_proofs::VariableMultiScalarMul; + +fn bench_msm_curve25519_dalek(c: &mut Criterion) { + use curve25519_dalek::{RistrettoPoint, Scalar}; + + let mut group = c.benchmark_group("MSM curve25519-dalek RistrettoPoint"); + let mut rng = thread_rng(); + + for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() { + let scalars: Vec = (0..*size).map(|_| Scalar::random(&mut rng)).collect(); + let bases: Vec = (0..*size) + .map(|_| RistrettoPoint::random(&mut rng)) + .collect(); + + group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| { + b.iter(|| RistrettoPoint::msm(black_box(&scalars), black_box(&bases))) + }); + } + group.finish(); +} + +fn bench_msm_k256(c: &mut Criterion) { + use k256::{ProjectivePoint, Scalar}; + + let mut group = c.benchmark_group("MSM k256 ProjectivePoint"); + let mut rng = thread_rng(); + + for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() { + let scalars: Vec = (0..*size).map(|_| Scalar::random(&mut rng)).collect(); + let bases: Vec = (0..*size) + .map(|_| ProjectivePoint::random(&mut rng)) + .collect(); + + group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| { + b.iter(|| ProjectivePoint::msm(black_box(&scalars), black_box(&bases))) + }); + } + group.finish(); +} + +fn bench_msm_p256(c: &mut Criterion) { + use p256::{ProjectivePoint, Scalar}; + + let mut group = c.benchmark_group("MSM p256 ProjectivePoint"); + let mut rng = thread_rng(); + + for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() { + let scalars: Vec = (0..*size).map(|_| Scalar::random(&mut rng)).collect(); + let bases: Vec = (0..*size) + .map(|_| ProjectivePoint::random(&mut rng)) + .collect(); + + group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| { + b.iter(|| ProjectivePoint::msm(black_box(&scalars), black_box(&bases))) + }); + } + group.finish(); +} + +fn bench_msm_bls12_381_g1(c: &mut Criterion) { + use bls12_381::{G1Projective, Scalar}; + + let mut group = c.benchmark_group("MSM bls12_381 G1Projective"); + let mut rng = thread_rng(); + + for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() { + let scalars: Vec = (0..*size).map(|_| Scalar::random(&mut rng)).collect(); + let bases: Vec = (0..*size).map(|_| G1Projective::random(&mut rng)).collect(); + + group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| { + b.iter(|| G1Projective::msm(black_box(&scalars), black_box(&bases))) + }); + } + group.finish(); +} + +fn bench_msm_bls12_381_g2(c: &mut Criterion) { + use bls12_381::{G2Projective, Scalar}; + + let mut group = c.benchmark_group("MSM bls12_381 G2Projective"); + let mut rng = thread_rng(); + + for size in [1, 2, 4, 8, 16, 64, 256, 1024].iter() { + let scalars: Vec = (0..*size).map(|_| Scalar::random(&mut rng)).collect(); + let bases: Vec = (0..*size).map(|_| G2Projective::random(&mut rng)).collect(); + + group.bench_with_input(BenchmarkId::new("size", size), size, |b, _| { + b.iter(|| G2Projective::msm(black_box(&scalars), black_box(&bases))) + }); + } + group.finish(); +} + +criterion_group!( + benches, + bench_msm_curve25519_dalek, + bench_msm_k256, + bench_msm_p256, + bench_msm_bls12_381_g1, + bench_msm_bls12_381_g2, +); +criterion_main!(benches); diff --git a/src/group/msm.rs b/src/group/msm.rs index ea4d86c..adebc78 100644 --- a/src/group/msm.rs +++ b/src/group/msm.rs @@ -64,15 +64,24 @@ impl VariableMultiScalarMul for G { fn msm(scalars: &[Self::Scalar], bases: &[Self::Point]) -> Self { assert_eq!(scalars.len(), bases.len()); - if scalars.is_empty() { - return Self::identity(); + // NOTE: Based on the msm benchmark in this repo, msm_pippenger provides improvements over + // msm_naive past a small constant size, but is significantly slower for very small MSMs. + match scalars.len() { + 0 => Self::identity(), + 1..16 => msm_naive(bases, scalars), + 16.. => msm_pippenger(bases, scalars), } - - msm_internal(bases, scalars) } } -fn msm_internal(bases: &[G], scalars: &[G::Scalar]) -> G { +/// A naive MSM implementation. +fn msm_naive(bases: &[G], scalars: &[G::Scalar]) -> G { + core::iter::zip(bases, scalars).map(|(g, x)| *g * x).sum() +} + +/// An MSM implementation that employ's Pippenger's algorithm and works for all groups that +/// implement `PrimeGroup`. +fn msm_pippenger(bases: &[G], scalars: &[G::Scalar]) -> G { let c = ln_without_floats(scalars.len()); let num_bits = ::NUM_BITS as usize; // split `num_bits` into steps of `c`, but skip window 0. diff --git a/src/lib.rs b/src/lib.rs index 2bb96ff..993d549 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -84,6 +84,7 @@ pub(crate) mod schnorr_protocol; pub mod tests; pub use fiat_shamir::Nizk; +pub use group::msm::VariableMultiScalarMul; pub use linear_relation::LinearRelation; #[deprecated = "Use sigma_proofs::group::serialization instead"]