Skip to content

Commit 9594c19

Browse files
committed
add caching for jwk.n hash
1 parent d109a5c commit 9594c19

3 files changed

Lines changed: 154 additions & 31 deletions

File tree

fastcrypto-zkp/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ harness = false
1515
[[bench]]
1616
name = "zklogin"
1717
harness = false
18+
required-features = ["test-utils"]
1819

1920
[[bench]]
2021
name = "poseidon"
@@ -64,3 +65,4 @@ proptest = "1.1.0"
6465

6566
[features]
6667
e2e = []
68+
test-utils = []

fastcrypto-zkp/benches/zklogin.rs

Lines changed: 117 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ mod zklogin_benches {
88

99
use ark_std::rand::rngs::StdRng;
1010
use ark_std::rand::SeedableRng;
11-
use criterion::Criterion;
11+
use criterion::{BatchSize, Criterion};
1212
use fastcrypto::ed25519::Ed25519KeyPair;
1313
use fastcrypto::error::FastCryptoError;
1414
use fastcrypto::rsa::{Base64UrlUnpadded, Encoding};
1515
use fastcrypto::traits::KeyPair;
1616
use fastcrypto_zkp::bn254::utils::gen_address_seed;
17+
use fastcrypto_zkp::bn254::zk_login::clear_cache_for_testing;
1718
use fastcrypto_zkp::bn254::zk_login::ZkLoginInputs;
1819
use fastcrypto_zkp::bn254::zk_login::JWK;
1920
use fastcrypto_zkp::bn254::zk_login::{JwkId, OIDCProvider};
@@ -70,11 +71,12 @@ mod zklogin_benches {
7071
b.iter(|| input_clone.get_proof().as_arkworks().unwrap())
7172
});
7273

73-
// Benchmark the `calculate_all_inputs_hash` function called by `verify_zk_login`.
74+
// Benchmark `calculate_all_inputs_hash` with a WARM modulus-hash cache (all
75+
// iterations hit).
7476
let eph_pubkey_clone = eph_pubkey.clone();
7577
let input_clone = input.clone();
7678
let modulus_clone = modulus.clone();
77-
c.bench_function("verify_zk_login/calculate_all_inputs_hash", move |b| {
79+
c.bench_function("verify_zk_login/calculate_all_inputs_hash/warm", move |b| {
7880
b.iter(|| {
7981
input_clone
8082
.calculate_all_inputs_hash(
@@ -86,6 +88,28 @@ mod zklogin_benches {
8688
.unwrap()
8789
});
8890
});
91+
92+
// Benchmark `calculate_all_inputs_hash` with a COLD cache (cleared before
93+
// each iteration, so each timed call recomputes the modulus hash).
94+
let eph_pubkey_clone = eph_pubkey.clone();
95+
let input_clone = input.clone();
96+
let modulus_clone = modulus.clone();
97+
c.bench_function("verify_zk_login/calculate_all_inputs_hash/cold", move |b| {
98+
b.iter_batched(
99+
clear_cache_for_testing,
100+
|_| {
101+
input_clone
102+
.calculate_all_inputs_hash(
103+
&eph_pubkey_clone,
104+
&modulus_clone,
105+
max_epoch,
106+
&CIRCUIT_CONFIG_V1,
107+
)
108+
.unwrap()
109+
},
110+
BatchSize::PerIteration,
111+
)
112+
});
89113
let input_hashes = input
90114
.calculate_all_inputs_hash(&eph_pubkey, &modulus, max_epoch, &CIRCUIT_CONFIG_V1)
91115
.unwrap();
@@ -106,18 +130,38 @@ mod zklogin_benches {
106130
},
107131
);
108132

109-
// Benchmark the entire `verify_zk_login` function.
110-
c.bench_function("verify_zk_login", move |b| {
133+
// Benchmark the entire `verify_zk_login` function (warm: modulus hash cache hit).
134+
let input_warm = input.clone();
135+
let eph_warm = eph_pubkey.clone();
136+
let map_warm = map.clone();
137+
c.bench_function("verify_zk_login/warm", move |b| {
111138
b.iter(|| {
112139
fastcrypto_zkp::bn254::zk_login_api::verify_zk_login(
113-
&input,
140+
&input_warm,
114141
max_epoch,
115-
&eph_pubkey,
116-
&map,
142+
&eph_warm,
143+
&map_warm,
117144
&ZkLoginEnv::Test,
118145
)
119146
})
120147
});
148+
149+
// Benchmark `verify_zk_login` on a cold cache (first call after a JWK refresh).
150+
c.bench_function("verify_zk_login/cold", move |b| {
151+
b.iter_batched(
152+
clear_cache_for_testing,
153+
|_| {
154+
fastcrypto_zkp::bn254::zk_login_api::verify_zk_login(
155+
&input,
156+
max_epoch,
157+
&eph_pubkey,
158+
&map,
159+
&ZkLoginEnv::Test,
160+
)
161+
},
162+
BatchSize::PerIteration,
163+
)
164+
});
121165
}
122166

123167
/// Benchmark V2 proof verification for 8192-bit RSA keys
@@ -163,22 +207,49 @@ mod zklogin_benches {
163207
b.iter(|| input_clone.get_proof().as_arkworks().unwrap())
164208
});
165209

166-
// Benchmark the `calculate_all_inputs_hash` function called by `verify_zk_login`.
210+
// Benchmark `calculate_all_inputs_hash` with a WARM modulus-hash cache.
167211
let eph_pubkey_clone = eph_pubkey.clone();
168212
let input_clone = input.clone();
169213
let modulus_clone = modulus.clone();
170-
c.bench_function("verify_zk_login_v2/calculate_all_inputs_hash", move |b| {
171-
b.iter(|| {
172-
input_clone
173-
.calculate_all_inputs_hash(
174-
&eph_pubkey_clone,
175-
&modulus_clone,
176-
max_epoch,
177-
&CIRCUIT_CONFIG_V2,
178-
)
179-
.unwrap()
180-
});
181-
});
214+
c.bench_function(
215+
"verify_zk_login_v2/calculate_all_inputs_hash/warm",
216+
move |b| {
217+
b.iter(|| {
218+
input_clone
219+
.calculate_all_inputs_hash(
220+
&eph_pubkey_clone,
221+
&modulus_clone,
222+
max_epoch,
223+
&CIRCUIT_CONFIG_V2,
224+
)
225+
.unwrap()
226+
});
227+
},
228+
);
229+
230+
// Benchmark `calculate_all_inputs_hash` with a COLD cache (cleared each iteration).
231+
let eph_pubkey_clone = eph_pubkey.clone();
232+
let input_clone = input.clone();
233+
let modulus_clone = modulus.clone();
234+
c.bench_function(
235+
"verify_zk_login_v2/calculate_all_inputs_hash/cold",
236+
move |b| {
237+
b.iter_batched(
238+
clear_cache_for_testing,
239+
|_| {
240+
input_clone
241+
.calculate_all_inputs_hash(
242+
&eph_pubkey_clone,
243+
&modulus_clone,
244+
max_epoch,
245+
&CIRCUIT_CONFIG_V2,
246+
)
247+
.unwrap()
248+
},
249+
BatchSize::PerIteration,
250+
)
251+
},
252+
);
182253
let input_hashes = input
183254
.calculate_all_inputs_hash(&eph_pubkey, &modulus, max_epoch, &CIRCUIT_CONFIG_V2)
184255
.unwrap();
@@ -199,18 +270,38 @@ mod zklogin_benches {
199270
},
200271
);
201272

202-
// Benchmark the entire `verify_zk_login` function.
203-
c.bench_function("verify_zk_login_v2", move |b| {
273+
// Benchmark the entire `verify_zk_login` function (warm: modulus hash cache hit).
274+
let input_warm = input.clone();
275+
let eph_warm = eph_pubkey.clone();
276+
let map_warm = map.clone();
277+
c.bench_function("verify_zk_login_v2/warm", move |b| {
204278
b.iter(|| {
205279
fastcrypto_zkp::bn254::zk_login_api::verify_zk_login(
206-
&input,
280+
&input_warm,
207281
max_epoch,
208-
&eph_pubkey,
209-
&map,
282+
&eph_warm,
283+
&map_warm,
210284
&ZkLoginEnv::Test,
211285
)
212286
})
213287
});
288+
289+
// Benchmark `verify_zk_login` on a cold cache (first call after a JWK refresh).
290+
c.bench_function("verify_zk_login_v2/cold", move |b| {
291+
b.iter_batched(
292+
clear_cache_for_testing,
293+
|_| {
294+
fastcrypto_zkp::bn254::zk_login_api::verify_zk_login(
295+
&input,
296+
max_epoch,
297+
&eph_pubkey,
298+
&map,
299+
&ZkLoginEnv::Test,
300+
)
301+
},
302+
BatchSize::PerIteration,
303+
)
304+
});
214305
}
215306

216307
criterion_group! {

fastcrypto-zkp/src/bn254/zk_login.rs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,47 @@ pub use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
2121
use fastcrypto::error::FastCryptoError;
2222
use itertools::Itertools;
2323
use num_bigint::BigUint;
24+
use once_cell::sync::Lazy;
2425
use regex::Regex;
2526
use schemars::JsonSchema;
2627
use serde::{Deserialize, Serialize};
2728
use std::cmp::Ordering::{Equal, Greater, Less};
29+
use std::collections::HashMap;
2830
use std::error::Error;
2931
use std::fmt::Display;
3032
use std::str::FromStr;
33+
use std::sync::RwLock;
34+
35+
/// Key for the modulus hash cache: (modulus bytes, max_rsa_bits).
36+
type ModulusHashKey = (Vec<u8>, u16);
37+
38+
/// JWKs rotate occasionally, so caching by (modulus bytes, max_rsa_bits) avoids recomputing
39+
/// bit-packing + poseidon hash on every verification.
40+
static MODULUS_HASH_CACHE: Lazy<RwLock<HashMap<ModulusHashKey, Bn254Fr>>> =
41+
Lazy::new(|| RwLock::new(HashMap::new()));
42+
43+
fn cached_modulus_hash(modulus: &[u8], max_rsa_bits: u16) -> Result<Bn254Fr, FastCryptoError> {
44+
if let Some(f) = MODULUS_HASH_CACHE
45+
.read()
46+
.ok()
47+
.and_then(|m| m.get(&(modulus.to_vec(), max_rsa_bits)).copied())
48+
{
49+
return Ok(f);
50+
}
51+
let f = hash_to_field(&[BigUint::from_bytes_be(modulus)], max_rsa_bits, PACK_WIDTH)?;
52+
if let Ok(mut m) = MODULUS_HASH_CACHE.write() {
53+
m.insert((modulus.to_vec(), max_rsa_bits), f);
54+
}
55+
Ok(f)
56+
}
57+
58+
/// Clear the modulus hash cache for testing only benchmark.
59+
#[cfg(any(test, feature = "test-utils"))]
60+
pub fn clear_cache_for_testing() {
61+
if let Ok(mut m) = MODULUS_HASH_CACHE.write() {
62+
m.clear();
63+
}
64+
}
3165

3266
#[cfg(test)]
3367
#[path = "unit_tests/zk_login_tests.rs"]
@@ -589,11 +623,7 @@ impl ZkLoginInputs {
589623
let iss_base64_f =
590624
hash_ascii_str_to_field(&self.iss_base64_details.value, config.max_iss_len_b64)?;
591625
let header_f = hash_ascii_str_to_field(&self.header_base64, config.max_header_len_b64)?;
592-
let modulus_f = hash_to_field(
593-
&[BigUint::from_bytes_be(modulus)],
594-
config.max_rsa_bits,
595-
PACK_WIDTH,
596-
)?;
626+
let modulus_f = cached_modulus_hash(modulus, config.max_rsa_bits)?;
597627
poseidon_zk_login(&[
598628
first,
599629
second,

0 commit comments

Comments
 (0)