Skip to content

Commit 735f6aa

Browse files
committed
1 parent 57be20e commit 735f6aa

File tree

3 files changed

+289
-0
lines changed

3 files changed

+289
-0
lines changed

ff/src/fields/models/fp/inverse.rs

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
//! Implementation based on:
2+
//! - https://github.com/privacy-scaling-explorations/halo2curves/blob/3bfa6562f0ddcbac941091ba3c7c9b6c322efac1/src/ff_ext/inverse.rs
3+
//! - https://github.com/RustCrypto/crypto-bigint/blob/682f17a979c3a1886fde5426b26400dfc3f4775b/src/modular/safegcd.rs
4+
//!
5+
//! It is more than twice as fast as the default (algebra) implementation:
6+
//! Calling `Fp::invert` 1_000_000 times takes 2.05 seconds, instead of 5.23 seconds
7+
8+
use core::cmp::PartialEq;
9+
10+
/// Signed multi-limb integer made of `NLIMBS` limbs of `NBITS_PER_LIMB` (62) bits each.
///
/// Limbs are stored least-significant first, one limb per `u64`, and the value
/// is interpreted in two's complement: it is negative when the top bit of the
/// highest limb is set (see `is_neg`).
#[derive(Clone, Debug)]
pub struct Integer(pub [u64; NLIMBS]);

/// Number of limbs in an [`Integer`].
const NLIMBS: usize = 6;
/// Number of significant bits stored in each limb.
const NBITS_PER_LIMB: usize = 62;

impl Integer {
    /// Bit mask selecting the `NBITS_PER_LIMB` significant bits of a limb.
    pub const MASK: u64 = u64::MAX >> (64 - NBITS_PER_LIMB);
    /// The value `-1`: every limb has all of its significant bits set.
    pub const MINUS_ONE: Self = Self([Self::MASK; NLIMBS]);
    /// The value `0`.
    pub const ZERO: Self = Self([0; NLIMBS]);
    /// The value `1`.
    pub const ONE: Self = {
        let mut data = [0; NLIMBS];
        data[0] = 1;
        Self(data)
    };

    /// Arithmetic right shift by exactly one limb (`NBITS_PER_LIMB` bits).
    ///
    /// The lowest limb is discarded and the vacated highest limb is filled
    /// with ones when the value is negative, zeros otherwise.
    const fn shr(&self) -> Self {
        let mut data = [0; NLIMBS];
        if self.is_neg() {
            data[NLIMBS - 1] = Self::MASK;
        }
        let mut i = 0;
        while i < NLIMBS - 1 {
            data[i] = self.0[i + 1];
            i += 1;
        }
        Self(data)
    }

    /// Returns the least-significant limb.
    const fn lowest(&self) -> u64 {
        self.0[0]
    }

    /// Returns `true` when the top bit of the highest limb is set.
    const fn is_neg(&self) -> bool {
        self.0[NLIMBS - 1] > (Self::MASK >> 1)
    }

    /// Limb-wise addition; the carry out of the highest limb is discarded.
    const fn add(&self, other: &Self) -> Self {
        let (mut data, mut carry) = ([0; NLIMBS], 0);
        let mut i = 0;
        while i < NLIMBS {
            // Both operands are below 2^62 and the carry is at most 1, so the
            // sum cannot overflow a `u64`.
            let sum = self.0[i] + other.0[i] + carry;
            data[i] = sum & Integer::MASK;
            carry = sum >> NBITS_PER_LIMB;
            i += 1;
        }
        Self(data)
    }

    /// Multiplies by a signed 64-bit factor; high bits that do not fit in
    /// `NLIMBS` limbs are discarded.
    ///
    /// For a negative factor the product is computed as `(!self + 1) * |other|`:
    /// each limb is complemented through `mask`, and the `+ 1` term, already
    /// multiplied by `|other|`, is injected as the initial carry.
    const fn mul(&self, other: i64) -> Self {
        let mut data = [0; NLIMBS];
        // `unsigned_abs` is well-defined for every `i64`, including `i64::MIN`,
        // whereas `-other` overflows for that value.
        let magnitude = other.unsigned_abs();
        let (mut carry, mask) = if other < 0 {
            (magnitude, Integer::MASK)
        } else {
            (0, 0)
        };
        let mut i = 0;
        while i < NLIMBS {
            let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (magnitude as u128);
            data[i] = sum as u64 & Integer::MASK;
            carry = (sum >> NBITS_PER_LIMB) as u64;
            i += 1;
        }
        Self(data)
    }

    /// Two's complement negation: complements every limb and adds one.
    const fn neg(self) -> Self {
        let (mut data, mut carry) = ([0; NLIMBS], 1);
        let mut i = 0;
        while i < NLIMBS {
            let sum = (self.0[i] ^ Integer::MASK) + carry;
            data[i] = sum & Integer::MASK;
            carry = sum >> NBITS_PER_LIMB;
            i += 1;
        }
        Self(data)
    }
}

impl PartialEq for Integer {
    /// Limb-by-limb equality.
    ///
    /// Every limb is compared and the results are combined with `&`, so the
    /// comparison never exits early on the first mismatch.
    fn eq(&self, other: &Self) -> bool {
        self.0
            .iter()
            .zip(other.0.iter())
            .fold(true, |all_equal, (lhs, rhs)| all_equal & (lhs == rhs))
    }
}
101+
102+
/// Bernstein-Yang inverter
103+
pub struct BYInverter {
104+
pub modulus: Integer,
105+
/// Adjusting parameter
106+
pub adjuster: Integer,
107+
/// Multiplicative inverse of the modulus modulo 2^62
108+
pub inverse: i64,
109+
}
110+
111+
/// Type of the Bernstein-Yang transition matrix multiplied by 2^62
112+
type Matrix = [[i64; 2]; 2];
113+
114+
impl BYInverter {
115+
const fn jump(f: &Integer, g: &Integer, mut delta: i64) -> (i64, Matrix) {
116+
const fn min(a: i64, b: i64) -> i64 {
117+
if a > b {
118+
b
119+
} else {
120+
a
121+
}
122+
}
123+
let (mut steps, mut f, mut g) = (62, f.lowest() as i64, g.lowest() as i128);
124+
let mut t: Matrix = [[1, 0], [0, 1]];
125+
loop {
126+
let zeros = min(steps, g.trailing_zeros() as i64);
127+
(steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros);
128+
t[0] = [t[0][0] << zeros, t[0][1] << zeros];
129+
if steps == 0 {
130+
break;
131+
}
132+
if delta > 0 {
133+
(delta, f, g) = (-delta, g as i64, -f as i128);
134+
(t[0], t[1]) = (t[1], [-t[0][0], -t[0][1]]);
135+
}
136+
let mask = (1 << min(min(steps, 1 - delta), 5)) - 1;
137+
let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask;
138+
t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]];
139+
g += w as i128 * f as i128;
140+
}
141+
(delta, t)
142+
}
143+
144+
const fn fg(f: Integer, g: Integer, t: Matrix) -> (Integer, Integer) {
145+
(
146+
f.mul(t[0][0]).add(&g.mul(t[0][1])).shr(),
147+
f.mul(t[1][0]).add(&g.mul(t[1][1])).shr(),
148+
)
149+
}
150+
151+
const fn de(&self, d: Integer, e: Integer, t: Matrix) -> (Integer, Integer) {
152+
let mask = Integer::MASK as i64;
153+
let mut md = t[0][0] * d.is_neg() as i64 + t[0][1] * e.is_neg() as i64;
154+
let mut me = t[1][0] * d.is_neg() as i64 + t[1][1] * e.is_neg() as i64;
155+
let cd = t[0][0]
156+
.wrapping_mul(d.lowest() as i64)
157+
.wrapping_add(t[0][1].wrapping_mul(e.lowest() as i64))
158+
& mask;
159+
let ce = t[1][0]
160+
.wrapping_mul(d.lowest() as i64)
161+
.wrapping_add(t[1][1].wrapping_mul(e.lowest() as i64))
162+
& mask;
163+
md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask;
164+
me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask;
165+
let cd = d
166+
.mul(t[0][0])
167+
.add(&e.mul(t[0][1]))
168+
.add(&self.modulus.mul(md));
169+
let ce = d
170+
.mul(t[1][0])
171+
.add(&e.mul(t[1][1]))
172+
.add(&self.modulus.mul(me));
173+
(cd.shr(), ce.shr())
174+
}
175+
176+
const fn norm(&self, mut value: Integer, negate: bool) -> Integer {
177+
if value.is_neg() {
178+
value = value.add(&self.modulus);
179+
}
180+
if negate {
181+
value = value.neg();
182+
}
183+
if value.is_neg() {
184+
value = value.add(&self.modulus);
185+
}
186+
value
187+
}
188+
189+
const fn convert<const I: usize, const O: usize, const S: usize>(input: &[u64]) -> [u64; S] {
190+
const fn min(a: usize, b: usize) -> usize {
191+
if a > b {
192+
b
193+
} else {
194+
a
195+
}
196+
}
197+
let (total, mut output, mut bits) = (min(input.len() * I, S * O), [0; S], 0);
198+
while bits < total {
199+
let (i, o) = (bits % I, bits % O);
200+
output[bits / O] |= (input[bits / I] >> i) << o;
201+
bits += min(I - i, O - o);
202+
}
203+
let mask = u64::MAX >> (64 - O);
204+
let mut filled = total / O + if total % O > 0 { 1 } else { 0 };
205+
while filled > 0 {
206+
filled -= 1;
207+
output[filled] &= mask;
208+
}
209+
output
210+
}
211+
212+
pub fn invert<const N: usize>(&self, value: &[u64]) -> Option<[u64; N]> {
213+
assert_eq!(N, 4); // Other sizes not supported
214+
215+
let (mut d, mut e) = (Integer::ZERO, self.adjuster.clone());
216+
let mut g = Integer(Self::convert::<64, 62, NLIMBS>(value));
217+
let (mut delta, mut f) = (1, self.modulus.clone());
218+
let mut matrix;
219+
while g != Integer::ZERO {
220+
(delta, matrix) = Self::jump(&f, &g, delta);
221+
(f, g) = Self::fg(f, g, matrix);
222+
(d, e) = self.de(d, e, matrix);
223+
}
224+
let antiunit = f == Integer::MINUS_ONE;
225+
if (f != Integer::ONE) && !antiunit {
226+
return None;
227+
}
228+
Some(Self::convert::<62, 64, N>(&self.norm(d, antiunit).0))
229+
}
230+
}

ff/src/fields/models/fp/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use ark_std::{
1414
One, Zero,
1515
};
1616

17+
mod inverse;
1718
#[macro_use]
1819
mod montgomery_backend;
1920
pub use montgomery_backend::*;

ff/src/fields/models/fp/montgomery_backend.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,64 @@ pub trait MontConfig<const N: usize>: 'static + Sync + Send + Sized {
292292
}
293293
}
294294

295+
#[cfg(not(target_family = "wasm"))] // impl not optimized for wasm, because of some 128 bits products
296+
#[inline]
297+
fn inverse(this: &Fp<MontBackend<Self, N>, N>) -> Option<Fp<MontBackend<Self, N>, N>> {
298+
if this.is_zero() {
299+
None
300+
} else {
301+
// Check `Self::INV` to know which of `Fp` or `Fq` we are
302+
let inverter = if Self::INV == 11037532056220336127 {
303+
// Fp
304+
super::inverse::BYInverter {
305+
modulus: super::inverse::Integer([
306+
1814160019365560321,
307+
655946578803287150,
308+
2,
309+
0,
310+
64,
311+
0,
312+
]),
313+
adjuster: super::inverse::Integer([
314+
898728379203715087,
315+
2255237944337466270,
316+
4141791841069945229,
317+
1968191642266354973,
318+
9,
319+
0,
320+
]),
321+
inverse: 2797525999061827585,
322+
}
323+
} else if Self::INV == 10108024940646105087 {
324+
// Fq
325+
super::inverse::BYInverter {
326+
modulus: super::inverse::Integer([
327+
884652903791329281,
328+
655946578822079350,
329+
2,
330+
0,
331+
64,
332+
0,
333+
]),
334+
adjuster: super::inverse::Integer([
335+
4365809925394268175,
336+
2228451641930570639,
337+
4243007676294846726,
338+
1968191643554589279,
339+
9,
340+
0,
341+
]),
342+
inverse: 3727033114636058625,
343+
}
344+
} else {
345+
unimplemented!();
346+
};
347+
let inverted = inverter.invert(&this.0 .0)?;
348+
Some(Fp::new_unchecked(BigInt(inverted)))
349+
}
350+
}
351+
352+
#[cfg(target_family = "wasm")]
295353
fn inverse(a: &Fp<MontBackend<Self, N>, N>) -> Option<Fp<MontBackend<Self, N>, N>> {
296354
if a.is_zero() {
297355
None

0 commit comments

Comments
 (0)