Skip to content
87 changes: 72 additions & 15 deletions src/internet_identity/src/openid/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ const HTTP_STATUS_OK: u8 = 200;

// Fetch the certs every fifteen minutes, the responses are always
// valid for a couple of hours so that should be enough margin.
#[cfg(not(test))]
const FETCH_CERTS_INTERVAL: u64 = 60 * 15; // 15 minutes in seconds

// A JWT is only valid for a very small window, even if the JWT itself says it's valid for longer,
Expand Down Expand Up @@ -167,7 +166,7 @@ impl Provider {
let certs: Rc<RefCell<Vec<Jwk>>> = Rc::new(RefCell::new(vec![]));

#[cfg(not(test))]
schedule_fetch_certs(config.jwks_uri, Rc::clone(&certs), None);
schedule_fetch_certs(config.jwks_uri, Rc::clone(&certs), Some(0));

Provider {
client_id: config.client_id,
Expand All @@ -177,6 +176,29 @@ impl Provider {
}
}

fn compute_next_certs_fetch_delay<T, E>(
result: &Result<T, E>,
current_delay: Option<u64>,
) -> Option<u64> {
const MIN_DELAY: u64 = 60;
Comment thread
sea-snake marked this conversation as resolved.
Outdated
const MAX_DELAY: u64 = FETCH_CERTS_INTERVAL;
const MULTIPLIER: u64 = 2;

match result {
// Reset delay to None so default (`FETCH_CERTS_INTERVAL`) delay is used.
Ok(_) => None,
// Try again earlier with backoff if fetch failed, the HTTP outcall responses
// aren't the same across nodes when we fetch at the moment of key rotation.
//
// The delay should be at most `MAX_DELAY` and at minimum `MIN_DELAY`.
Err(_) => Some(
current_delay
.map_or(MIN_DELAY, |d| d * MULTIPLIER)
.clamp(MIN_DELAY, MAX_DELAY),
),
}
}

#[cfg(not(test))]
fn schedule_fetch_certs(
jwks_uri: String,
Expand All @@ -185,23 +207,21 @@ fn schedule_fetch_certs(
) {
use ic_cdk::spawn;
use ic_cdk_timers::set_timer;
use std::cmp::min;
use std::time::Duration;

set_timer(Duration::from_secs(delay.unwrap_or(0)), move || {
spawn(async move {
let new_delay = match fetch_certs(jwks_uri.clone()).await {
Ok(certs) => {
set_timer(
Duration::from_secs(delay.unwrap_or(FETCH_CERTS_INTERVAL)),
move || {
spawn(async move {
let result = fetch_certs(jwks_uri.clone()).await;
let next_delay = compute_next_certs_fetch_delay(&result, delay);
if let Ok(certs) = result {
certs_reference.replace(certs);
FETCH_CERTS_INTERVAL
}
// Try again earlier with backoff if fetch failed, the HTTP outcall responses
// aren't the same across nodes when we fetch at the moment of key rotation.
Err(_) => min(FETCH_CERTS_INTERVAL, delay.unwrap_or(60) * 2),
};
schedule_fetch_certs(jwks_uri, certs_reference, Some(new_delay));
});
});
schedule_fetch_certs(jwks_uri, certs_reference, next_delay);
});
},
);
}

#[cfg(not(test))]
Expand Down Expand Up @@ -641,3 +661,40 @@ fn should_return_error_when_name_too_long() {
))
);
}
#[test]
fn should_compute_next_certs_fetch_delay() {
const MIN_DELAY: u64 = 60;
Comment thread
sea-snake marked this conversation as resolved.
Outdated
const MAX_DELAY: u64 = FETCH_CERTS_INTERVAL;

let success: Result<(), ()> = Ok(());
let error: Result<(), ()> = Err(());

for (current_delay, expected_next_delay_on_error) in [
// Should be at least `MIN_DELAY` (1 minute)
(None, Some(MIN_DELAY)),
(Some(0), Some(MIN_DELAY)),
(Some(1), Some(MIN_DELAY)),
(Some(MIN_DELAY / 2 - 1), Some(MIN_DELAY)),
// Should be multiplied by two
(Some(MIN_DELAY / 2), Some(MIN_DELAY)),
(Some(MIN_DELAY / 2 + 1), Some(MIN_DELAY + 2)),
(Some(120), Some(240)),
(Some(120), Some(240)),
(Some(240), Some(480)),
// Should be at most `MAX_DELAY` (15 minutes)
(Some(480), Some(MAX_DELAY)),
(Some(MAX_DELAY / 2 + 1), Some(MAX_DELAY)),
(Some(MAX_DELAY * 2), Some(MAX_DELAY)),
] {
// Should return `None` on success so default (`FETCH_CERTS_INTERVAL`) delay is used.
assert_eq!(
compute_next_certs_fetch_delay(&success, current_delay),
None
);
// Should return `expected_next_delay_on_error` on error as specified above.
assert_eq!(
compute_next_certs_fetch_delay(&error, current_delay),
expected_next_delay_on_error
);
}
}
Loading