|
| 1 | +//! Implementation based on: |
| 2 | +//! - https://github.com/privacy-scaling-explorations/halo2curves/blob/3bfa6562f0ddcbac941091ba3c7c9b6c322efac1/src/ff_ext/inverse.rs |
| 3 | +//! - https://github.com/RustCrypto/crypto-bigint/blob/682f17a979c3a1886fde5426b26400dfc3f4775b/src/modular/safegcd.rs |
| 4 | +//! |
| 5 | +//! It's more than twice faster than the default (algebra) implementation: |
| 6 | +//! Calling `Fp::invert` 1_000_000 times takes 2.05 seconds, instead of 5.23 seconds |
| 7 | +
|
| 8 | +use core::cmp::PartialEq; |
| 9 | + |
| 10 | +/// Integer using 62 bits limbs |
| 11 | +#[derive(Clone)] |
| 12 | +pub struct Integer(pub [u64; NLIMBS]); |
| 13 | + |
| 14 | +const NLIMBS: usize = 6; |
| 15 | +const NBITS_PER_LIMB: usize = 62; |
| 16 | + |
| 17 | +impl Integer { |
| 18 | + pub const MASK: u64 = u64::MAX >> (64 - NBITS_PER_LIMB); |
| 19 | + pub const MINUS_ONE: Self = Self([Self::MASK; NLIMBS]); |
| 20 | + pub const ZERO: Self = Self([0; NLIMBS]); |
| 21 | + pub const ONE: Self = { |
| 22 | + let mut data = [0; NLIMBS]; |
| 23 | + data[0] = 1; |
| 24 | + Self(data) |
| 25 | + }; |
| 26 | + |
| 27 | + const fn shr(&self) -> Self { |
| 28 | + let mut data = [0; NLIMBS]; |
| 29 | + if self.is_neg() { |
| 30 | + data[NLIMBS - 1] = Self::MASK; |
| 31 | + } |
| 32 | + let mut i = 0; |
| 33 | + while i < NLIMBS - 1 { |
| 34 | + data[i] = self.0[i + 1]; |
| 35 | + i += 1; |
| 36 | + } |
| 37 | + Self(data) |
| 38 | + } |
| 39 | + |
| 40 | + const fn lowest(&self) -> u64 { |
| 41 | + self.0[0] |
| 42 | + } |
| 43 | + |
| 44 | + const fn is_neg(&self) -> bool { |
| 45 | + self.0[NLIMBS - 1] > (Self::MASK >> 1) |
| 46 | + } |
| 47 | + |
| 48 | + const fn add(&self, other: &Self) -> Self { |
| 49 | + let (mut data, mut carry) = ([0; NLIMBS], 0); |
| 50 | + let mut i = 0; |
| 51 | + while i < NLIMBS { |
| 52 | + let sum = self.0[i] + other.0[i] + carry; |
| 53 | + data[i] = sum & Integer::MASK; |
| 54 | + carry = sum >> NBITS_PER_LIMB; |
| 55 | + i += 1; |
| 56 | + } |
| 57 | + Self(data) |
| 58 | + } |
| 59 | + |
| 60 | + const fn mul(&self, other: i64) -> Self { |
| 61 | + let mut data = [0; NLIMBS]; |
| 62 | + let (other, mut carry, mask) = if other < 0 { |
| 63 | + (-other, -other as u64, Integer::MASK) |
| 64 | + } else { |
| 65 | + (other, 0, 0) |
| 66 | + }; |
| 67 | + let mut i = 0; |
| 68 | + while i < NLIMBS { |
| 69 | + let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (other as u128); |
| 70 | + data[i] = sum as u64 & Integer::MASK; |
| 71 | + carry = (sum >> NBITS_PER_LIMB) as u64; |
| 72 | + i += 1; |
| 73 | + } |
| 74 | + Self(data) |
| 75 | + } |
| 76 | + |
| 77 | + const fn neg(self) -> Self { |
| 78 | + let (mut data, mut carry) = ([0; NLIMBS], 1); |
| 79 | + let mut i = 0; |
| 80 | + while i < NLIMBS { |
| 81 | + let sum = (self.0[i] ^ Integer::MASK) + carry; |
| 82 | + data[i] = sum & Integer::MASK; |
| 83 | + carry = sum >> NBITS_PER_LIMB; |
| 84 | + i += 1; |
| 85 | + } |
| 86 | + Self(data) |
| 87 | + } |
| 88 | +} |
| 89 | + |
| 90 | +impl PartialEq for Integer { |
| 91 | + fn eq(&self, other: &Self) -> bool { |
| 92 | + let mut is_eq = true; |
| 93 | + let mut i = 0; |
| 94 | + while i < NLIMBS { |
| 95 | + is_eq &= self.0[i] == other.0[i]; |
| 96 | + i += 1; |
| 97 | + } |
| 98 | + is_eq |
| 99 | + } |
| 100 | +} |
| 101 | + |
| 102 | +/// Bernstein-Yang inverter |
| 103 | +pub struct BYInverter { |
| 104 | + pub modulus: Integer, |
| 105 | + /// Adjusting parameter |
| 106 | + pub adjuster: Integer, |
| 107 | + /// Multiplicative inverse of the modulus modulo 2^62 |
| 108 | + pub inverse: i64, |
| 109 | +} |
| 110 | + |
| 111 | +/// Type of the Bernstein-Yang transition matrix multiplied by 2^62 |
| 112 | +type Matrix = [[i64; 2]; 2]; |
| 113 | + |
| 114 | +impl BYInverter { |
| 115 | + const fn jump(f: &Integer, g: &Integer, mut delta: i64) -> (i64, Matrix) { |
| 116 | + const fn min(a: i64, b: i64) -> i64 { |
| 117 | + if a > b { |
| 118 | + b |
| 119 | + } else { |
| 120 | + a |
| 121 | + } |
| 122 | + } |
| 123 | + let (mut steps, mut f, mut g) = (62, f.lowest() as i64, g.lowest() as i128); |
| 124 | + let mut t: Matrix = [[1, 0], [0, 1]]; |
| 125 | + loop { |
| 126 | + let zeros = min(steps, g.trailing_zeros() as i64); |
| 127 | + (steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros); |
| 128 | + t[0] = [t[0][0] << zeros, t[0][1] << zeros]; |
| 129 | + if steps == 0 { |
| 130 | + break; |
| 131 | + } |
| 132 | + if delta > 0 { |
| 133 | + (delta, f, g) = (-delta, g as i64, -f as i128); |
| 134 | + (t[0], t[1]) = (t[1], [-t[0][0], -t[0][1]]); |
| 135 | + } |
| 136 | + let mask = (1 << min(min(steps, 1 - delta), 5)) - 1; |
| 137 | + let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask; |
| 138 | + t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]]; |
| 139 | + g += w as i128 * f as i128; |
| 140 | + } |
| 141 | + (delta, t) |
| 142 | + } |
| 143 | + |
| 144 | + const fn fg(f: Integer, g: Integer, t: Matrix) -> (Integer, Integer) { |
| 145 | + ( |
| 146 | + f.mul(t[0][0]).add(&g.mul(t[0][1])).shr(), |
| 147 | + f.mul(t[1][0]).add(&g.mul(t[1][1])).shr(), |
| 148 | + ) |
| 149 | + } |
| 150 | + |
| 151 | + const fn de(&self, d: Integer, e: Integer, t: Matrix) -> (Integer, Integer) { |
| 152 | + let mask = Integer::MASK as i64; |
| 153 | + let mut md = t[0][0] * d.is_neg() as i64 + t[0][1] * e.is_neg() as i64; |
| 154 | + let mut me = t[1][0] * d.is_neg() as i64 + t[1][1] * e.is_neg() as i64; |
| 155 | + let cd = t[0][0] |
| 156 | + .wrapping_mul(d.lowest() as i64) |
| 157 | + .wrapping_add(t[0][1].wrapping_mul(e.lowest() as i64)) |
| 158 | + & mask; |
| 159 | + let ce = t[1][0] |
| 160 | + .wrapping_mul(d.lowest() as i64) |
| 161 | + .wrapping_add(t[1][1].wrapping_mul(e.lowest() as i64)) |
| 162 | + & mask; |
| 163 | + md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask; |
| 164 | + me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask; |
| 165 | + let cd = d |
| 166 | + .mul(t[0][0]) |
| 167 | + .add(&e.mul(t[0][1])) |
| 168 | + .add(&self.modulus.mul(md)); |
| 169 | + let ce = d |
| 170 | + .mul(t[1][0]) |
| 171 | + .add(&e.mul(t[1][1])) |
| 172 | + .add(&self.modulus.mul(me)); |
| 173 | + (cd.shr(), ce.shr()) |
| 174 | + } |
| 175 | + |
| 176 | + const fn norm(&self, mut value: Integer, negate: bool) -> Integer { |
| 177 | + if value.is_neg() { |
| 178 | + value = value.add(&self.modulus); |
| 179 | + } |
| 180 | + if negate { |
| 181 | + value = value.neg(); |
| 182 | + } |
| 183 | + if value.is_neg() { |
| 184 | + value = value.add(&self.modulus); |
| 185 | + } |
| 186 | + value |
| 187 | + } |
| 188 | + |
| 189 | + const fn convert<const I: usize, const O: usize, const S: usize>(input: &[u64]) -> [u64; S] { |
| 190 | + const fn min(a: usize, b: usize) -> usize { |
| 191 | + if a > b { |
| 192 | + b |
| 193 | + } else { |
| 194 | + a |
| 195 | + } |
| 196 | + } |
| 197 | + let (total, mut output, mut bits) = (min(input.len() * I, S * O), [0; S], 0); |
| 198 | + while bits < total { |
| 199 | + let (i, o) = (bits % I, bits % O); |
| 200 | + output[bits / O] |= (input[bits / I] >> i) << o; |
| 201 | + bits += min(I - i, O - o); |
| 202 | + } |
| 203 | + let mask = u64::MAX >> (64 - O); |
| 204 | + let mut filled = total / O + if total % O > 0 { 1 } else { 0 }; |
| 205 | + while filled > 0 { |
| 206 | + filled -= 1; |
| 207 | + output[filled] &= mask; |
| 208 | + } |
| 209 | + output |
| 210 | + } |
| 211 | + |
| 212 | + pub fn invert<const N: usize>(&self, value: &[u64]) -> Option<[u64; N]> { |
| 213 | + assert_eq!(N, 4); // Other sizes not supported |
| 214 | + |
| 215 | + let (mut d, mut e) = (Integer::ZERO, self.adjuster.clone()); |
| 216 | + let mut g = Integer(Self::convert::<64, 62, NLIMBS>(value)); |
| 217 | + let (mut delta, mut f) = (1, self.modulus.clone()); |
| 218 | + let mut matrix; |
| 219 | + while g != Integer::ZERO { |
| 220 | + (delta, matrix) = Self::jump(&f, &g, delta); |
| 221 | + (f, g) = Self::fg(f, g, matrix); |
| 222 | + (d, e) = self.de(d, e, matrix); |
| 223 | + } |
| 224 | + let antiunit = f == Integer::MINUS_ONE; |
| 225 | + if (f != Integer::ONE) && !antiunit { |
| 226 | + return None; |
| 227 | + } |
| 228 | + Some(Self::convert::<62, 64, N>(&self.norm(d, antiunit).0)) |
| 229 | + } |
| 230 | +} |
0 commit comments