Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 32 additions & 24 deletions icechunk-python/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ pub(crate) fn datetime_repr(d: &DateTime<Utc>) -> String {
#[derive(Clone, Debug, Serialize, Deserialize)]
struct PythonCredentialsFetcher<CredType> {
pub pickled_function: Vec<u8>,
pub current: Option<CredType>,
pub initial: Option<CredType>,
}

impl<CredType> PythonCredentialsFetcher<CredType> {
fn new(pickled_function: Vec<u8>) -> Self {
PythonCredentialsFetcher { pickled_function, current: None }
PythonCredentialsFetcher { pickled_function, initial: None }
}

fn new_with_current<C>(pickled_function: Vec<u8>, current: C) -> Self
fn new_with_initial<C>(pickled_function: Vec<u8>, current: C) -> Self
where
C: Into<CredType>,
{
PythonCredentialsFetcher { pickled_function, current: Some(current.into()) }
PythonCredentialsFetcher { pickled_function, initial: Some(current.into()) }
}
}

Expand All @@ -174,18 +174,22 @@ where
#[typetag::serde]
impl S3CredentialsFetcher for PythonCredentialsFetcher<S3StaticCredentials> {
async fn get(&self) -> Result<S3StaticCredentials, String> {
Python::with_gil(|py| {
if let Some(static_creds) = self.current.as_ref() {
// avoid herding by adding some randomness
let delta = TimeDelta::seconds(rand::random_range(5..180));
let expiration = static_creds
.expires_after
.map(|exp| exp - delta)
.unwrap_or(chrono::DateTime::<Utc>::MAX_UTC);
if Utc::now() < expiration {
if let Some(static_creds) = self.initial.as_ref() {
match static_creds.expires_after {
None => {
return Ok(static_creds.clone());
}
Some(expiration)
if expiration
> Utc::now()
+ TimeDelta::seconds(rand::random_range(120..=180)) =>
{
return Ok(static_creds.clone());
}
_ => {}
}
}
Python::with_gil(|py| {
call_pickled::<PyS3StaticCredentials>(py, self.pickled_function.clone())
.map(|c| c.into())
})
Expand All @@ -197,18 +201,22 @@ impl S3CredentialsFetcher for PythonCredentialsFetcher<S3StaticCredentials> {
#[typetag::serde]
impl GcsCredentialsFetcher for PythonCredentialsFetcher<GcsBearerCredential> {
async fn get(&self) -> Result<GcsBearerCredential, String> {
Python::with_gil(|py| {
if let Some(static_creds) = self.current.as_ref() {
// avoid herding by adding some randomness
let delta = TimeDelta::seconds(rand::random_range(5..180));
let expiration = static_creds
.expires_after
.map(|exp| exp - delta)
.unwrap_or(chrono::DateTime::<Utc>::MAX_UTC);
if Utc::now() < expiration {
if let Some(static_creds) = self.initial.as_ref() {
match static_creds.expires_after {
None => {
return Ok(static_creds.clone());
}
Some(expiration)
if expiration
> Utc::now()
+ TimeDelta::seconds(rand::random_range(120..=180)) =>
{
return Ok(static_creds.clone());
}
_ => {}
}
}
Python::with_gil(|py| {
call_pickled::<PyGcsBearerCredential>(py, self.pickled_function.clone())
.map(|c| c.into())
})
Expand All @@ -233,7 +241,7 @@ impl From<PyS3Credentials> for S3Credentials {
PyS3Credentials::Static(creds) => S3Credentials::Static(creds.into()),
PyS3Credentials::Refreshable { pickled_function, current } => {
let fetcher = if let Some(current) = current {
PythonCredentialsFetcher::new_with_current(pickled_function, current)
PythonCredentialsFetcher::new_with_initial(pickled_function, current)
} else {
PythonCredentialsFetcher::new(pickled_function)
};
Expand Down Expand Up @@ -320,7 +328,7 @@ impl From<PyGcsCredentials> for GcsCredentials {
PyGcsCredentials::Static(creds) => GcsCredentials::Static(creds.into()),
PyGcsCredentials::Refreshable { pickled_function, current } => {
let fetcher = if let Some(current) = current {
PythonCredentialsFetcher::new_with_current(pickled_function, current)
PythonCredentialsFetcher::new_with_initial(pickled_function, current)
} else {
PythonCredentialsFetcher::new(pickled_function)
};
Expand Down
17 changes: 11 additions & 6 deletions icechunk/src/storage/object_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
};
use async_trait::async_trait;
use bytes::{Buf, Bytes};
use chrono::{DateTime, Utc};
use chrono::{DateTime, TimeDelta, Utc};
use futures::{
StreamExt, TryStreamExt,
stream::{self, BoxStream},
Expand Down Expand Up @@ -38,7 +38,7 @@ use std::{
};
use tokio::{
io::AsyncRead,
sync::{Mutex, OnceCell},
sync::{OnceCell, RwLock},
};
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tracing::instrument;
Expand Down Expand Up @@ -1147,29 +1147,34 @@ impl ObjectStoreBackend for GcsObjectStoreBackend {

#[derive(Debug)]
pub struct GcsRefreshableCredentialProvider {
last_credential: Arc<Mutex<Option<GcsBearerCredential>>>,
last_credential: Arc<RwLock<Option<GcsBearerCredential>>>,
refresher: Arc<dyn GcsCredentialsFetcher>,
}

impl GcsRefreshableCredentialProvider {
pub fn new(refresher: Arc<dyn GcsCredentialsFetcher>) -> Self {
Self { last_credential: Arc::new(Mutex::new(None)), refresher }
Self { last_credential: Arc::new(RwLock::new(None)), refresher }
}

pub async fn get_or_update_credentials(
&self,
) -> Result<GcsBearerCredential, StorageError> {
let mut last_credential = self.last_credential.lock().await;
let last_credential = self.last_credential.read().await;

// If we have a credential and it hasn't expired, return it
if let Some(creds) = last_credential.as_ref() {
if let Some(expires_after) = creds.expires_after {
if expires_after > Utc::now() {
if expires_after
> Utc::now() + TimeDelta::seconds(rand::random_range(120..=180))
{
return Ok(creds.clone());
}
}
}

drop(last_credential);
let mut last_credential = self.last_credential.write().await;

// Otherwise, refresh the credential and cache it
let creds = self
.refresher
Expand Down
22 changes: 18 additions & 4 deletions icechunk/src/storage/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use aws_credential_types::provider::error::CredentialsError;
use aws_sdk_s3::{
Client,
config::{
Builder, ConfigBag, Intercept, ProvideCredentials, Region, RuntimeComponents,
interceptors::BeforeTransmitInterceptorContextMut,
Builder, ConfigBag, IdentityCache, Intercept, ProvideCredentials, Region,
RuntimeComponents, interceptors::BeforeTransmitInterceptorContextMut,
},
error::{BoxError, SdkError},
operation::put_object::PutObjectError,
Expand All @@ -40,7 +40,7 @@ use futures::{
};
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncRead, sync::OnceCell};
use tracing::instrument;
use tracing::{error, instrument};

use super::{
CHUNK_PREFIX, CONFIG_PATH, DeleteObjectsResult, FetchConfigResult, GetRefResult,
Expand Down Expand Up @@ -171,10 +171,19 @@ pub async fn mk_client(
.with_max_backoff(core::time::Duration::from_millis(
settings.retries().max_backoff_ms() as u64,
));

let mut s3_builder = Builder::from(&aws_config.load().await)
.force_path_style(config.force_path_style)
.retry_config(retry_config);

// credentials may take a while to refresh, defaults are too strict
let id_cache = IdentityCache::lazy()
.load_timeout(core::time::Duration::from_secs(120))
.buffer_time(core::time::Duration::from_secs(120))
.build();

s3_builder = s3_builder.identity_cache(id_cache);

if !extra_read_headers.is_empty() || !extra_write_headers.is_empty() {
s3_builder = s3_builder.interceptor(ExtraHeadersInterceptor {
extra_read_headers,
Expand Down Expand Up @@ -1017,7 +1026,12 @@ impl ProvideRefreshableCredentials {
async fn provide(
&self,
) -> Result<aws_credential_types::Credentials, CredentialsError> {
let creds = self.0.get().await.map_err(CredentialsError::not_loaded)?;
let creds = self
.0
.get()
.await
.inspect_err(|err| error!(error = err, "Cannot load credentials"))
.map_err(CredentialsError::not_loaded)?;
let creds = aws_credential_types::Credentials::new(
creds.access_key_id,
creds.secret_access_key,
Expand Down
Loading