Skip to content
96 changes: 76 additions & 20 deletions src/internet_identity/src/openid/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use super::{
};
use crate::openid::OpenIdCredential;
use crate::openid::OpenIdProvider;
use crate::openid::MINUTE_NS;
use crate::secs_to_nanos;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use base64::Engine;
Expand Down Expand Up @@ -38,14 +37,13 @@ const HTTP_STATUS_OK: u8 = 200;

// Fetch the certs every fifteen minutes, the responses are always
// valid for a couple of hours so that should be enough margin.
#[cfg(not(test))]
const FETCH_CERTS_INTERVAL: u64 = 60 * 15; // 15 minutes in seconds
const FETCH_CERTS_INTERVAL_SECONDS: u64 = 60 * 15; // 15 minutes in seconds

// A JWT is only valid for a very small window, even if the JWT itself says it's valid for longer,
// we only need it right after it's being issued to create a JWT delegation with its own expiry.
// As the JWT is also used for registration, which may include longer user interaction,
// we are using 10 minutes to account for potential clock offsets as well as users.
const MAX_VALIDITY_WINDOW: u64 = 10 * MINUTE_NS; // Same as ingress expiry
const MAX_VALIDITY_WINDOW_SECONDS: u64 = 10 * 60; // Same as ingress expiry

// Maximum length of the email claim in the JWT, in practice we expect the identity provider to
// already validate it on their end for a sane maximum length. This is an additional sanity check.
Expand Down Expand Up @@ -167,7 +165,7 @@ impl Provider {
let certs: Rc<RefCell<Vec<Jwk>>> = Rc::new(RefCell::new(vec![]));

#[cfg(not(test))]
schedule_fetch_certs(config.jwks_uri, Rc::clone(&certs), None);
schedule_fetch_certs(config.jwks_uri, Rc::clone(&certs), Some(0));

Provider {
client_id: config.client_id,
Expand All @@ -177,6 +175,29 @@ impl Provider {
}
}

fn compute_next_certs_fetch_delay<T, E>(
result: &Result<T, E>,
current_delay: Option<u64>,
) -> Option<u64> {
const MIN_DELAY_SECONDS: u64 = 60;
const MAX_DELAY_SECONDS: u64 = FETCH_CERTS_INTERVAL_SECONDS;
const BACKOFF_MULTIPLIER: u64 = 2;

match result {
// Reset delay to None so default (`FETCH_CERTS_INTERVAL`) delay is used.
Ok(_) => None,
// Try again earlier with backoff if fetch failed, the HTTP outcall responses
// aren't the same across nodes when we fetch at the moment of key rotation.
//
// The delay should be at most `MAX_DELAY` and at minimum `MIN_DELAY`.
Err(_) => Some(
current_delay
.map_or(MIN_DELAY_SECONDS, |d| d * BACKOFF_MULTIPLIER)
.clamp(MIN_DELAY_SECONDS, MAX_DELAY_SECONDS),
),
}
}

#[cfg(not(test))]
fn schedule_fetch_certs(
jwks_uri: String,
Expand All @@ -185,23 +206,21 @@ fn schedule_fetch_certs(
) {
use ic_cdk::spawn;
use ic_cdk_timers::set_timer;
use std::cmp::min;
use std::time::Duration;

set_timer(Duration::from_secs(delay.unwrap_or(0)), move || {
spawn(async move {
let new_delay = match fetch_certs(jwks_uri.clone()).await {
Ok(certs) => {
set_timer(
Duration::from_secs(delay.unwrap_or(FETCH_CERTS_INTERVAL_SECONDS)),
move || {
spawn(async move {
let result = fetch_certs(jwks_uri.clone()).await;
let next_delay = compute_next_certs_fetch_delay(&result, delay);
if let Ok(certs) = result {
certs_reference.replace(certs);
FETCH_CERTS_INTERVAL
}
// Try again earlier with backoff if fetch failed, the HTTP outcall responses
// aren't the same across nodes when we fetch at the moment of key rotation.
Err(_) => min(FETCH_CERTS_INTERVAL, delay.unwrap_or(60) * 2),
};
schedule_fetch_certs(jwks_uri, certs_reference, Some(new_delay));
});
});
schedule_fetch_certs(jwks_uri, certs_reference, next_delay);
});
},
);
}

#[cfg(not(test))]
Expand Down Expand Up @@ -358,7 +377,7 @@ fn verify_claims(
if now > secs_to_nanos(claims.exp) {
return Err(OpenIDJWTVerificationError::JWTExpired);
}
if now > secs_to_nanos(claims.iat) + MAX_VALIDITY_WINDOW {
if now > secs_to_nanos(claims.iat + MAX_VALIDITY_WINDOW_SECONDS) {
return Err(OpenIDJWTVerificationError::JWTExpired);
}
if now < secs_to_nanos(claims.iat) {
Expand Down Expand Up @@ -592,7 +611,7 @@ fn should_return_error_when_invalid_caller() {

#[test]
fn should_return_error_when_no_longer_valid() {
TEST_TIME.replace(time() + MAX_VALIDITY_WINDOW + 1);
TEST_TIME.replace(time() + secs_to_nanos(MAX_VALIDITY_WINDOW_SECONDS) + 1);
let (_, salt, config, claims) = test_data();

assert_eq!(
Expand Down Expand Up @@ -641,3 +660,40 @@ fn should_return_error_when_name_too_long() {
))
);
}
#[test]
fn should_compute_next_certs_fetch_delay() {
const MIN_DELAY_SECONDS: u64 = 60;
const MAX_DELAY_SECONDS: u64 = FETCH_CERTS_INTERVAL_SECONDS;

let success: Result<(), ()> = Ok(());
let error: Result<(), ()> = Err(());

for (current_delay, expected_next_delay_on_error) in [
// Should be at least `MIN_DELAY` (1 minute)
(None, Some(MIN_DELAY_SECONDS)),
(Some(0), Some(MIN_DELAY_SECONDS)),
(Some(1), Some(MIN_DELAY_SECONDS)),
(Some(MIN_DELAY_SECONDS / 2 - 1), Some(MIN_DELAY_SECONDS)),
// Should be multiplied by two
(Some(MIN_DELAY_SECONDS / 2), Some(MIN_DELAY_SECONDS)),
(Some(MIN_DELAY_SECONDS / 2 + 1), Some(MIN_DELAY_SECONDS + 2)),
(Some(120), Some(240)),
(Some(120), Some(240)),
(Some(240), Some(480)),
// Should be at most `MAX_DELAY` (15 minutes)
(Some(480), Some(MAX_DELAY_SECONDS)),
(Some(MAX_DELAY_SECONDS / 2 + 1), Some(MAX_DELAY_SECONDS)),
(Some(MAX_DELAY_SECONDS * 2), Some(MAX_DELAY_SECONDS)),
] {
// Should return `None` on success so default (`FETCH_CERTS_INTERVAL`) delay is used.
assert_eq!(
compute_next_certs_fetch_delay(&success, current_delay),
None
);
// Should return `expected_next_delay_on_error` on error as specified above.
assert_eq!(
compute_next_certs_fetch_delay(&error, current_delay),
expected_next_delay_on_error
);
}
}
Loading