diff --git a/icechunk-python/src/config.rs b/icechunk-python/src/config.rs index 4b3128b2d..2d3368ee9 100644 --- a/icechunk-python/src/config.rs +++ b/icechunk-python/src/config.rs @@ -140,19 +140,19 @@ pub(crate) fn datetime_repr(d: &DateTime) -> String { #[derive(Clone, Debug, Serialize, Deserialize)] struct PythonCredentialsFetcher { pub pickled_function: Vec, - pub current: Option, + pub initial: Option, } impl PythonCredentialsFetcher { fn new(pickled_function: Vec) -> Self { - PythonCredentialsFetcher { pickled_function, current: None } + PythonCredentialsFetcher { pickled_function, initial: None } } - fn new_with_current(pickled_function: Vec, current: C) -> Self + fn new_with_initial(pickled_function: Vec, current: C) -> Self where C: Into, { - PythonCredentialsFetcher { pickled_function, current: Some(current.into()) } + PythonCredentialsFetcher { pickled_function, initial: Some(current.into()) } } } @@ -174,18 +174,22 @@ where #[typetag::serde] impl S3CredentialsFetcher for PythonCredentialsFetcher { async fn get(&self) -> Result { - Python::with_gil(|py| { - if let Some(static_creds) = self.current.as_ref() { - // avoid herding by adding some randomness - let delta = TimeDelta::seconds(rand::random_range(5..180)); - let expiration = static_creds - .expires_after - .map(|exp| exp - delta) - .unwrap_or(chrono::DateTime::::MAX_UTC); - if Utc::now() < expiration { + if let Some(static_creds) = self.initial.as_ref() { + match static_creds.expires_after { + None => { + return Ok(static_creds.clone()); + } + Some(expiration) + if expiration + > Utc::now() + + TimeDelta::seconds(rand::random_range(120..=180)) => + { return Ok(static_creds.clone()); } + _ => {} } + } + Python::with_gil(|py| { call_pickled::(py, self.pickled_function.clone()) .map(|c| c.into()) }) @@ -197,18 +201,22 @@ impl S3CredentialsFetcher for PythonCredentialsFetcher { #[typetag::serde] impl GcsCredentialsFetcher for PythonCredentialsFetcher { async fn get(&self) -> Result { - Python::with_gil(|py| { - if let Some(static_creds) = self.current.as_ref() { - // avoid herding by adding some randomness - let delta = TimeDelta::seconds(rand::random_range(5..180)); - let expiration = static_creds - .expires_after - .map(|exp| exp - delta) - .unwrap_or(chrono::DateTime::::MAX_UTC); - if Utc::now() < expiration { + if let Some(static_creds) = self.initial.as_ref() { + match static_creds.expires_after { + None => { + return Ok(static_creds.clone()); + } + Some(expiration) + if expiration + > Utc::now() + + TimeDelta::seconds(rand::random_range(120..=180)) => + { return Ok(static_creds.clone()); } + _ => {} } + } + Python::with_gil(|py| { call_pickled::(py, self.pickled_function.clone()) .map(|c| c.into()) }) @@ -233,7 +241,7 @@ impl From for S3Credentials { PyS3Credentials::Static(creds) => S3Credentials::Static(creds.into()), PyS3Credentials::Refreshable { pickled_function, current } => { let fetcher = if let Some(current) = current { - PythonCredentialsFetcher::new_with_current(pickled_function, current) + PythonCredentialsFetcher::new_with_initial(pickled_function, current) } else { PythonCredentialsFetcher::new(pickled_function) }; @@ -320,7 +328,7 @@ impl From for GcsCredentials { PyGcsCredentials::Static(creds) => GcsCredentials::Static(creds.into()), PyGcsCredentials::Refreshable { pickled_function, current } => { let fetcher = if let Some(current) = current { - PythonCredentialsFetcher::new_with_current(pickled_function, current) + PythonCredentialsFetcher::new_with_initial(pickled_function, current) } else { PythonCredentialsFetcher::new(pickled_function) }; diff --git a/icechunk/src/storage/object_store.rs b/icechunk/src/storage/object_store.rs index 3af865195..7a5033979 100644 --- a/icechunk/src/storage/object_store.rs +++ b/icechunk/src/storage/object_store.rs @@ -8,7 +8,7 @@ use crate::{ }; use async_trait::async_trait; use bytes::{Buf, Bytes}; -use chrono::{DateTime, Utc}; +use chrono::{DateTime, TimeDelta, Utc}; use futures::{ StreamExt, TryStreamExt, stream::{self, BoxStream}, @@ -38,7 +38,7 @@ use std::{ }; use tokio::{ io::AsyncRead, - sync::{Mutex, OnceCell}, + sync::{OnceCell, RwLock}, }; use tokio_util::compat::FuturesAsyncReadCompatExt; use tracing::instrument; @@ -1147,29 +1147,34 @@ impl ObjectStoreBackend for GcsObjectStoreBackend { #[derive(Debug)] pub struct GcsRefreshableCredentialProvider { - last_credential: Arc>>, + last_credential: Arc>>, refresher: Arc, } impl GcsRefreshableCredentialProvider { pub fn new(refresher: Arc) -> Self { - Self { last_credential: Arc::new(Mutex::new(None)), refresher } + Self { last_credential: Arc::new(RwLock::new(None)), refresher } } pub async fn get_or_update_credentials( &self, ) -> Result { - let mut last_credential = self.last_credential.lock().await; + let last_credential = self.last_credential.read().await; // If we have a credential and it hasn't expired, return it if let Some(creds) = last_credential.as_ref() { if let Some(expires_after) = creds.expires_after { - if expires_after > Utc::now() { + if expires_after + > Utc::now() + TimeDelta::seconds(rand::random_range(120..=180)) + { return Ok(creds.clone()); } } } + drop(last_credential); + let mut last_credential = self.last_credential.write().await; + // Otherwise, refresh the credential and cache it let creds = self .refresher diff --git a/icechunk/src/storage/s3.rs b/icechunk/src/storage/s3.rs index ae6e1db84..86fe0c20d 100644 --- a/icechunk/src/storage/s3.rs +++ b/icechunk/src/storage/s3.rs @@ -23,8 +23,8 @@ use aws_credential_types::provider::error::CredentialsError; use aws_sdk_s3::{ Client, config::{ - Builder, ConfigBag, Intercept, ProvideCredentials, Region, RuntimeComponents, - interceptors::BeforeTransmitInterceptorContextMut, + Builder, ConfigBag, IdentityCache, Intercept, ProvideCredentials, Region, + RuntimeComponents, interceptors::BeforeTransmitInterceptorContextMut, }, error::{BoxError, SdkError}, operation::put_object::PutObjectError, @@ -40,7 +40,7 @@ use futures::{ }; use serde::{Deserialize, Serialize}; use tokio::{io::AsyncRead, sync::OnceCell}; -use tracing::instrument; +use tracing::{error, instrument}; use super::{ CHUNK_PREFIX, CONFIG_PATH, DeleteObjectsResult, FetchConfigResult, GetRefResult, @@ -171,10 +171,19 @@ pub async fn mk_client( .with_max_backoff(core::time::Duration::from_millis( settings.retries().max_backoff_ms() as u64, )); + let mut s3_builder = Builder::from(&aws_config.load().await) .force_path_style(config.force_path_style) .retry_config(retry_config); + // credentials may take a while to refresh, defaults are too strict + let id_cache = IdentityCache::lazy() + .load_timeout(core::time::Duration::from_secs(120)) + .buffer_time(core::time::Duration::from_secs(120)) + .build(); + + s3_builder = s3_builder.identity_cache(id_cache); + if !extra_read_headers.is_empty() || !extra_write_headers.is_empty() { s3_builder = s3_builder.interceptor(ExtraHeadersInterceptor { extra_read_headers, @@ -1017,7 +1026,12 @@ impl ProvideRefreshableCredentials { async fn provide( &self, ) -> Result { - let creds = self.0.get().await.map_err(CredentialsError::not_loaded)?; + let creds = self + .0 + .get() + .await + .inspect_err(|err| error!(error = err, "Cannot load credentials")) + .map_err(CredentialsError::not_loaded)?; let creds = aws_credential_types::Credentials::new( creds.access_key_id, creds.secret_access_key,