Skip to content

Commit 1d8e7bf

Browse files
daxpeddacarloskiki
authored andcommitted
Add allocation-free EdwardsPoint::compress_batch() (#832)
1 parent 2690176 commit 1d8e7bf

File tree

4 files changed

+81
-28
lines changed

4 files changed

+81
-28
lines changed

curve25519-dalek/benches/dalek_benchmarks.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ mod edwards_benches {
3434
let mut rng = OsRng.unwrap_err();
3535
let points: Vec<EdwardsPoint> =
3636
(0..size).map(|_| EdwardsPoint::random(&mut rng)).collect();
37-
b.iter(|| EdwardsPoint::compress_batch(&points));
37+
b.iter(|| EdwardsPoint::compress_batch_alloc(&points));
3838
},
3939
);
4040
}

curve25519-dalek/src/edwards.rs

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ impl EdwardsPoint {
599599

600600
// Compute the denominators in a batch
601601
let mut denominators = eds.iter().map(|p| &p.Z - &p.Y).collect::<Vec<_>>();
602-
FieldElement::batch_invert(&mut denominators);
602+
FieldElement::invert_batch_alloc(&mut denominators);
603603

604604
// Now compute the Montgomery u coordinate for every point
605605
let mut ret = Vec::with_capacity(eds.len());
@@ -616,12 +616,24 @@ impl EdwardsPoint {
616616
self.to_affine().compress()
617617
}
618618

619+
/// Compress several `EdwardsPoint`s into `CompressedEdwardsY` format, using a batch inversion
620+
/// for a significant speedup.
621+
pub fn compress_batch<const N: usize>(inputs: &[EdwardsPoint; N]) -> [CompressedEdwardsY; N] {
622+
let mut zs: [_; N] = core::array::from_fn(|i| inputs[i].Z);
623+
FieldElement::invert_batch(&mut zs);
624+
625+
core::array::from_fn(|i| {
626+
let x = &inputs[i].X * &zs[i];
627+
let y = &inputs[i].Y * &zs[i];
628+
AffinePoint { x, y }.compress()
629+
})
630+
}
619631
/// Compress several `EdwardsPoint`s into `CompressedEdwardsY` format, using a batch inversion
620632
/// for a significant speedup.
621633
#[cfg(feature = "alloc")]
622-
pub fn compress_batch(inputs: &[EdwardsPoint]) -> Vec<CompressedEdwardsY> {
634+
pub fn compress_batch_alloc(inputs: &[EdwardsPoint]) -> Vec<CompressedEdwardsY> {
623635
let mut zs = inputs.iter().map(|input| input.Z).collect::<Vec<_>>();
624-
FieldElement::batch_invert(&mut zs);
636+
FieldElement::invert_batch_alloc(&mut zs);
625637

626638
inputs
627639
.iter()
@@ -2175,30 +2187,49 @@ mod test {
21752187
CompressedEdwardsY::identity()
21762188
);
21772189

2190+
assert_eq!(
2191+
EdwardsPoint::compress_batch(&[EdwardsPoint::identity()]),
2192+
[CompressedEdwardsY::identity()]
2193+
);
21782194
#[cfg(feature = "alloc")]
2179-
{
2180-
let compressed = EdwardsPoint::compress_batch(&[EdwardsPoint::identity()]);
2181-
assert_eq!(&compressed, &[CompressedEdwardsY::identity()]);
2182-
}
2195+
assert_eq!(
2196+
&EdwardsPoint::compress_batch_alloc(&[EdwardsPoint::identity()]),
2197+
&[CompressedEdwardsY::identity()]
2198+
);
21832199
}
21842200

2185-
#[cfg(all(feature = "alloc", feature = "rand_core"))]
2201+
#[cfg(feature = "rand_core")]
21862202
#[test]
21872203
fn compress_batch() {
21882204
let mut rng = rand::rng();
21892205

21902206
// TODO(tarcieri): proptests?
2191-
// Make some points deterministically then randomly
2192-
let mut points = (1u64..16)
2193-
.map(|n| constants::ED25519_BASEPOINT_POINT * Scalar::from(n))
2194-
.collect::<Vec<_>>();
2195-
points.extend(core::iter::repeat_with(|| EdwardsPoint::random(&mut rng)).take(100));
2196-
let compressed = EdwardsPoint::compress_batch(&points);
2207+
2208+
// Make some test points deterministically then randomly
2209+
const TEST_VEC_LEN: usize = 117;
2210+
let points: [EdwardsPoint; TEST_VEC_LEN] = core::array::from_fn(|i| {
2211+
if i < 17 {
2212+
// The first 17 are multiple of the basepoint
2213+
constants::ED25519_BASEPOINT_POINT * Scalar::from(i as u64)
2214+
} else {
2215+
// The rest are random
2216+
EdwardsPoint::random(&mut rng)
2217+
}
2218+
});
2219+
2220+
// Compress the points individually. This is our reference result
2221+
let expected_compressed = core::array::from_fn(|i| points[i].compress());
21972222

21982223
// Check that the batch-compressed points match the individually compressed ones
2199-
for (point, compressed) in points.iter().zip(&compressed) {
2200-
assert_eq!(&point.compress(), compressed);
2201-
}
2224+
assert_eq!(EdwardsPoint::compress_batch(&points), expected_compressed);
2225+
2226+
// Check that the batch-compressed (with alloc) points match the individually compressed
2227+
// ones
2228+
#[cfg(feature = "alloc")]
2229+
assert_eq!(
2230+
EdwardsPoint::compress_batch_alloc(&points),
2231+
expected_compressed
2232+
);
22022233
}
22032234

22042235
#[test]

curve25519-dalek/src/field.rs

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -209,17 +209,39 @@ impl FieldElement {
209209
(t19, t3)
210210
}
211211

212+
/// Given a slice of pub(crate)lic `FieldElements`, replace each with its inverse.
213+
///
214+
/// When an input `FieldElement` is zero, its value is unchanged.
215+
pub(crate) fn invert_batch<const N: usize>(inputs: &mut [FieldElement; N]) {
216+
let mut scratch = [FieldElement::ONE; N];
217+
218+
Self::internal_invert_batch(inputs, &mut scratch);
219+
}
220+
212221
/// Given a slice of pub(crate)lic `FieldElements`, replace each with its inverse.
213222
///
214223
/// When an input `FieldElement` is zero, its value is unchanged.
215224
#[cfg(feature = "alloc")]
216-
pub(crate) fn batch_invert(inputs: &mut [FieldElement]) {
225+
pub(crate) fn invert_batch_alloc(inputs: &mut [FieldElement]) {
226+
let n = inputs.len();
227+
let mut scratch = vec![FieldElement::ONE; n];
228+
229+
Self::internal_invert_batch(inputs, &mut scratch);
230+
}
231+
232+
/// Given a slice of pub(crate)lic `FieldElements`, replace each with its inverse. `scratch` can
233+
/// contain anything, so long as its length is the same as `inputs`.
234+
///
235+
/// When an input `FieldElement` is zero, its value is unchanged.
236+
///
237+
/// # Panics
238+
/// Panics when `scratch.len() != inputs.len()`
239+
fn internal_invert_batch(inputs: &mut [FieldElement], scratch: &mut [FieldElement]) {
217240
// Montgomery’s Trick and Fast Implementation of Masked AES
218241
// Genelle, Prouff and Quisquater
219242
// Section 3.2
220243

221-
let n = inputs.len();
222-
let mut scratch = vec![FieldElement::ONE; n];
244+
debug_assert_eq!(inputs.len(), scratch.len());
223245

224246
// Keep an accumulator of all of the previous products
225247
let mut acc = FieldElement::ONE;
@@ -240,12 +262,12 @@ impl FieldElement {
240262

241263
// Pass through the vector backwards to compute the inverses
242264
// in place
243-
for (input, scratch) in inputs.iter_mut().rev().zip(scratch.into_iter().rev()) {
265+
for (input, scratch) in inputs.iter_mut().rev().zip(scratch.iter_mut().rev()) {
244266
let tmp = &acc * input;
245267
// input <- acc * scratch, then acc <- tmp
246268
// Again, we skip zeros in a constant-time way
247269
let nz = !input.is_zero();
248-
input.conditional_assign(&(&acc * &scratch), nz);
270+
input.conditional_assign(&(&acc * scratch), nz);
249271
acc.conditional_assign(&tmp, nz);
250272
}
251273
}
@@ -559,7 +581,7 @@ mod test {
559581

560582
#[test]
561583
#[cfg(feature = "alloc")]
562-
fn batch_invert_a_matches_nonbatched() {
584+
fn invert_batch_a_matches_nonbatched() {
563585
let a = FieldElement::from_bytes(&A_BYTES);
564586
let ap58 = FieldElement::from_bytes(&AP58_BYTES);
565587
let asq = FieldElement::from_bytes(&ASQ_BYTES);
@@ -568,7 +590,7 @@ mod test {
568590
let a2 = &a + &a;
569591
let a_list = vec![a, ap58, asq, ainv, a0, a2];
570592
let mut ainv_list = a_list.clone();
571-
FieldElement::batch_invert(&mut ainv_list[..]);
593+
FieldElement::invert_batch_alloc(&mut ainv_list[..]);
572594
for i in 0..6 {
573595
assert_eq!(a_list[i].invert(), ainv_list[i]);
574596
}
@@ -677,8 +699,8 @@ mod test {
677699

678700
#[test]
679701
#[cfg(feature = "alloc")]
680-
fn batch_invert_empty() {
681-
FieldElement::batch_invert(&mut []);
702+
fn invert_batch_empty() {
703+
FieldElement::invert_batch_alloc(&mut []);
682704
}
683705

684706
// The following two consts were generated with the following sage script:

curve25519-dalek/src/ristretto.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ impl RistrettoPoint {
606606

607607
let mut invs: Vec<FieldElement> = states.iter().map(|state| state.efgh()).collect();
608608

609-
FieldElement::batch_invert(&mut invs[..]);
609+
FieldElement::invert_batch_alloc(&mut invs[..]);
610610

611611
states
612612
.iter()

0 commit comments

Comments
 (0)