Skip to content

Commit c2fca84

Browse files
authored
ZKVM-1383: fix: inverse in reciprocal (#8)
* Revert "don't accelerate modinv (#7)" This reverts commit 8922ef3. * set ret to zero * compute inv directly * fix sizeof
1 parent 8922ef3 commit c2fca84

1 file changed

Lines changed: 55 additions & 0 deletions

File tree

src/recip.c

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,39 @@
44
* SPDX-License-Identifier: Apache-2.0
55
*/
66

7+
#ifdef __ZKVM__
8+
#include "risczero_utils.h"
9+
10+
#ifdef __RISC0_UNCHECKED__
11+
extern void risc0_bigint_modmul_256_unchecked(const limb_t*, const limb_t*,
12+
const limb_t*, limb_t*);
13+
extern void risc0_bigint_modmul_384_unchecked(const limb_t*, const limb_t*,
14+
const limb_t*, limb_t*);
15+
extern void risc0_bigint_modinv_256_unchecked(const limb_t*, const limb_t*,
16+
limb_t*);
17+
extern void risc0_bigint_modinv_384_unchecked(const limb_t*, const limb_t*,
18+
limb_t*);
19+
#define r0bigint_modmul_256 risc0_bigint_modmul_256_unchecked
20+
#define r0bigint_modmul_384 risc0_bigint_modmul_384_unchecked
21+
#define r0bigint_modinv_256 risc0_bigint_modinv_256_unchecked
22+
#define r0bigint_modinv_384 risc0_bigint_modinv_384_unchecked
23+
#else // __RISC0_UNCHECKED__
24+
extern void risc0_bigint_modmul_256(const limb_t*, const limb_t*,
25+
const limb_t*, limb_t*);
26+
extern void risc0_bigint_modmul_384(const limb_t*, const limb_t*,
27+
const limb_t*, limb_t*);
28+
extern void risc0_bigint_modinv_256(const limb_t*, const limb_t*,
29+
limb_t*);
30+
extern void risc0_bigint_modinv_384(const limb_t*, const limb_t*,
31+
limb_t*);
32+
#define r0bigint_modmul_256 risc0_bigint_modmul_256
33+
#define r0bigint_modmul_384 risc0_bigint_modmul_384
34+
#define r0bigint_modinv_256 risc0_bigint_modinv_256
35+
#define r0bigint_modinv_384 risc0_bigint_modinv_384
36+
#endif // __RISC0_UNCHECKED__
37+
38+
#endif //__ZKVM__
39+
740
#include "fields.h"
841

942
#ifdef __OPTIMIZE_SIZE__
@@ -57,6 +90,16 @@ static void flt_reciprocal_fp2(vec384x out, const vec384x inp)
5790

5891
static void reciprocal_fp(vec384 out, const vec384 inp)
5992
{
93+
#ifdef __ZKVM__
94+
if (vec_is_zero(inp, sizeof(vec384))) {
95+
vec_zero(out, sizeof(vec384));
96+
} else {
97+
// compute imp^-1 = (v * R)^-1
98+
r0bigint_modinv_384(inp, BLS12_381_P, out);
99+
// compute (v * R)^-1 * R^2 = v^-1 * R
100+
r0bigint_modmul_384(out, BLS12_381_RR, BLS12_381_P, out);
101+
}
102+
#else //__ZKVM__
60103
static const vec384 Px8 = { /* left-aligned value of the modulus */
61104
TO_LIMB_T(0xcff7fffffffd5558), TO_LIMB_T(0xf55ffff58a9ffffd),
62105
TO_LIMB_T(0x39869507b587b120), TO_LIMB_T(0x23ba5c279c2895fb),
@@ -89,6 +132,7 @@ static void reciprocal_fp(vec384 out, const vec384 inp)
89132
vec_copy(out, temp.r[0], sizeof(vec384));
90133
#endif
91134
#undef RRx4
135+
#endif //__ZKVM__
92136
}
93137

94138
void blst_fp_inverse(vec384 out, const vec384 inp)
@@ -121,6 +165,16 @@ void blst_fp2_eucl_inverse(vec384x out, const vec384x inp)
121165

122166
static void reciprocal_fr(vec256 out, const vec256 inp)
123167
{
168+
#ifdef __ZKVM__
169+
if (vec_is_zero(inp, sizeof(vec256))) {
170+
vec_zero(out, sizeof(vec256));
171+
} else {
172+
// compute imp^-1 = (v * R)^-1
173+
r0bigint_modinv_256(inp, BLS12_381_r, out);
174+
// compute (v * R)^-1 * R^2 = v^-1 * R
175+
r0bigint_modmul_256(out, BLS12_381_rRR, BLS12_381_r, out);
176+
}
177+
#else //__ZKVM__
124178
static const vec256 rx2 = { /* left-aligned value of the modulus */
125179
TO_LIMB_T(0xfffffffe00000002), TO_LIMB_T(0xa77b4805fffcb7fd),
126180
TO_LIMB_T(0x6673b0101343b00a), TO_LIMB_T(0xe7db4ea6533afa90),
@@ -130,6 +184,7 @@ static void reciprocal_fr(vec256 out, const vec256 inp)
130184
ct_inverse_mod_256(temp, inp, BLS12_381_r, rx2);
131185
redc_mont_256(out, temp, BLS12_381_r, r0);
132186
mul_mont_sparse_256(out, out, BLS12_381_rRR, BLS12_381_r, r0);
187+
#endif //__ZKVM__
133188
}
134189

135190
void blst_fr_inverse(vec256 out, const vec256 inp)

0 commit comments

Comments
 (0)