Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 51 additions & 11 deletions icechunk-python/python/icechunk/credentials.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import inspect
import pickle
from collections.abc import Callable, Mapping
from collections.abc import Awaitable, Callable, Coroutine, Mapping
from datetime import datetime
from typing import cast
from typing import Any, TypeVar, cast

from icechunk._icechunk_python import (
AzureCredentials,
Expand Down Expand Up @@ -46,16 +48,50 @@

AnyCredential = Credentials.S3 | Credentials.Gcs | Credentials.Azure

S3CredentialProvider = Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]]
GcsCredentialProvider = Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]]
CredentialType = TypeVar("CredentialType")


def _resolve_callback_once(
get_credentials: Callable[[], CredentialType | Awaitable[CredentialType]],
) -> CredentialType:
value = get_credentials()
if inspect.isawaitable(value):
if inspect.iscoroutine(value):
coroutine = cast(Coroutine[Any, Any, CredentialType], value)
else:
awaitable = cast(Awaitable[CredentialType], value)

async def _as_coroutine() -> CredentialType:
return await awaitable

coroutine = _as_coroutine()

try:
return asyncio.run(coroutine)
except RuntimeError as err:
coroutine.close()
if "asyncio.run() cannot be called from a running event loop" in str(err):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting.... this is a bit of a breaking change, because everybody (including Earthmover) uses True for scatter

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll think on this..

raise ValueError(
"scatter_initial_credentials=True cannot eagerly evaluate async "
"credential callbacks while an event loop is running. "
"Set scatter_initial_credentials=False in async contexts."
) from err
raise

return value


def s3_refreshable_credentials(
get_credentials: Callable[[], S3StaticCredentials],
get_credentials: S3CredentialProvider,
scatter_initial_credentials: bool = False,
) -> S3Credentials.Refreshable:
"""Create refreshable credentials for S3 and S3 compatible object stores.

Parameters
----------
get_credentials: Callable[[], S3StaticCredentials]
get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]]
Use this function to get and refresh the credentials. The function must be picklable.
scatter_initial_credentials: bool, optional
Immediately call and store the value returned by get_credentials. This is useful if the
Expand All @@ -64,7 +100,9 @@ def s3_refreshable_credentials(
set of credentials has expired, the cached value is no longer used. Notice that credentials
obtained are stored, and they can be sent over the network if you pickle the session/repo.
"""
current = get_credentials() if scatter_initial_credentials else None
current = (
_resolve_callback_once(get_credentials) if scatter_initial_credentials else None
)
return S3Credentials.Refreshable(pickle.dumps(get_credentials), current)


Expand Down Expand Up @@ -116,7 +154,7 @@ def s3_credentials(
expires_after: datetime | None = None,
anonymous: bool | None = None,
from_env: bool | None = None,
get_credentials: Callable[[], S3StaticCredentials] | None = None,
get_credentials: S3CredentialProvider | None = None,
scatter_initial_credentials: bool = False,
) -> AnyS3Credential:
"""Create credentials for S3 and S3 compatible object stores.
Expand All @@ -137,7 +175,7 @@ def s3_credentials(
If set to True requests to the object store will not be signed
from_env: bool | None
Fetch credentials from the operating system environment
get_credentials: Callable[[], S3StaticCredentials] | None
get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None
Use this function to get and refresh object store credentials
scatter_initial_credentials: bool, optional
Immediately call and store the value returned by get_credentials. This is useful if the
Expand Down Expand Up @@ -218,14 +256,14 @@ def gcs_static_credentials(


def gcs_refreshable_credentials(
get_credentials: Callable[[], GcsBearerCredential],
get_credentials: GcsCredentialProvider,
scatter_initial_credentials: bool = False,
) -> GcsCredentials.Refreshable:
"""Create refreshable credentials for Google Cloud Storage object store.

Parameters
----------
get_credentials: Callable[[], S3StaticCredentials]
get_credentials: Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]]
Use this function to get and refresh the credentials. The function must be picklable.
scatter_initial_credentials: bool, optional
Immediately call and store the value returned by get_credentials. This is useful if the
Expand All @@ -235,7 +273,9 @@ def gcs_refreshable_credentials(
obtained are stored, and they can be sent over the network if you pickle the session/repo.
"""

current = get_credentials() if scatter_initial_credentials else None
current = (
_resolve_callback_once(get_credentials) if scatter_initial_credentials else None
)
return GcsCredentials.Refreshable(pickle.dumps(get_credentials), current)


Expand All @@ -257,7 +297,7 @@ def gcs_credentials(
bearer_token: str | None = None,
from_env: bool | None = None,
anonymous: bool | None = None,
get_credentials: Callable[[], GcsBearerCredential] | None = None,
get_credentials: GcsCredentialProvider | None = None,
scatter_initial_credentials: bool = False,
) -> AnyGcsCredential:
"""Create credentials Google Cloud Storage object store.
Expand Down
26 changes: 17 additions & 9 deletions icechunk-python/python/icechunk/storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from datetime import datetime

from icechunk._icechunk_python import (
Expand Down Expand Up @@ -125,7 +125,9 @@ def s3_storage(
expires_after: datetime | None = None,
anonymous: bool | None = None,
from_env: bool | None = None,
get_credentials: Callable[[], S3StaticCredentials] | None = None,
get_credentials: (
Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None
) = None,
scatter_initial_credentials: bool = False,
force_path_style: bool = False,
network_stream_timeout_seconds: int = 60,
Expand Down Expand Up @@ -157,7 +159,7 @@ def s3_storage(
If set to True requests to the object store will not be signed
from_env: bool | None
Fetch credentials from the operating system environment
get_credentials: Callable[[], S3StaticCredentials] | None
get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None
Use this function to get and refresh object store credentials
scatter_initial_credentials: bool, optional
Immediately call and store the value returned by get_credentials. This is useful if the
Expand Down Expand Up @@ -254,7 +256,9 @@ def tigris_storage(
expires_after: datetime | None = None,
anonymous: bool | None = None,
from_env: bool | None = None,
get_credentials: Callable[[], S3StaticCredentials] | None = None,
get_credentials: (
Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None
) = None,
scatter_initial_credentials: bool = False,
network_stream_timeout_seconds: int = 60,
) -> Storage:
Expand Down Expand Up @@ -288,7 +292,7 @@ def tigris_storage(
If set to True requests to the object store will not be signed
from_env: bool | None
Fetch credentials from the operating system environment
get_credentials: Callable[[], S3StaticCredentials] | None
get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None
Use this function to get and refresh object store credentials
scatter_initial_credentials: bool, optional
Immediately call and store the value returned by get_credentials. This is useful if the
Expand Down Expand Up @@ -340,7 +344,9 @@ def r2_storage(
expires_after: datetime | None = None,
anonymous: bool | None = None,
from_env: bool | None = None,
get_credentials: Callable[[], S3StaticCredentials] | None = None,
get_credentials: (
Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None
) = None,
scatter_initial_credentials: bool = False,
network_stream_timeout_seconds: int = 60,
) -> Storage:
Expand Down Expand Up @@ -374,7 +380,7 @@ def r2_storage(
If set to True requests to the object store will not be signed
from_env: bool | None
Fetch credentials from the operating system environment
get_credentials: Callable[[], S3StaticCredentials] | None
get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None
Use this function to get and refresh object store credentials
scatter_initial_credentials: bool, optional
Immediately call and store the value returned by get_credentials. This is useful if the
Expand Down Expand Up @@ -436,7 +442,9 @@ def gcs_storage(
anonymous: bool | None = None,
from_env: bool | None = None,
config: dict[str, str] | None = None,
get_credentials: Callable[[], GcsBearerCredential] | None = None,
get_credentials: (
Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]] | None
) = None,
scatter_initial_credentials: bool = False,
) -> Storage:
"""Create a Storage instance that saves data in Google Cloud Storage object store.
Expand All @@ -461,7 +469,7 @@ def gcs_storage(
Fetch credentials from the operating system environment
config: dict[str, str] | None
A dictionary of options for the Google Cloud Storage object store. See https://docs.rs/object_store/latest/object_store/gcp/enum.GoogleConfigKey.html#variants for a list of possible configuration keys.
get_credentials: Callable[[], GcsBearerCredential] | None
get_credentials: Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]] | None
Use this function to get and refresh object store credentials
scatter_initial_credentials: bool, optional
Immediately call and store the value returned by get_credentials. This is useful if the
Expand Down
64 changes: 64 additions & 0 deletions icechunk-python/src/asyncio_bridge.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use pyo3::{
PyErr, PyResult, Python,
exceptions::PyRuntimeError,
types::{PyAnyMethods, PyModule},
};
use std::sync::OnceLock;

/// Looks up the `TaskLocals` available to the calling context, if any.
///
/// Attaches to the Python interpreter and asks `pyo3_async_runtimes` for the
/// current task locals. A failed lookup is not treated as an error here: the
/// underlying `PyErr` is discarded and `None` is returned instead.
pub(crate) fn current_task_locals() -> Option<pyo3_async_runtimes::TaskLocals> {
    Python::attach(|py| {
        let lookup = pyo3_async_runtimes::tokio::get_current_locals(py);
        lookup.ok()
    })
}

/// Process-wide cache for the fallback task locals.
///
/// Stores the whole `Result`: either the `TaskLocals` that were created, or the
/// error message from a failed attempt. Because the cell is a `OnceLock`, whichever
/// value is stored first is kept for the rest of the process.
static FALLBACK_TASK_LOCALS: OnceLock<Result<pyo3_async_runtimes::TaskLocals, String>> =
OnceLock::new();

/// Spawns a dedicated thread that owns a fresh asyncio event loop and returns
/// `TaskLocals` bound to that loop.
///
/// The thread (named `icechunk-python-asyncio-fallback`) creates the loop with
/// `asyncio.new_event_loop()`, registers it through `asyncio.set_event_loop(...)`,
/// sends the resulting `TaskLocals` back over a channel, and then calls
/// `run_forever()` on the loop. This function blocks until the thread reports
/// either the locals or a setup error.
///
/// # Errors
///
/// Returns a message when the thread cannot be spawned, when loop setup fails
/// before the locals were sent, or when the thread exits without reporting.
///
/// NOTE(review): the spawned thread has to attach to the interpreter before it
/// can report back, while this function blocks on the channel. If a caller can
/// reach this while attached itself, that looks like a potential deadlock —
/// confirm callers are detached (or release the interpreter around the wait).
fn create_fallback_task_locals() -> Result<pyo3_async_runtimes::TaskLocals, String> {
    type Handshake = Result<pyo3_async_runtimes::TaskLocals, String>;
    let (ready_tx, ready_rx) = std::sync::mpsc::sync_channel::<Handshake>(1);

    let loop_thread =
        std::thread::Builder::new().name(String::from("icechunk-python-asyncio-fallback"));

    loop_thread
        .spawn(move || {
            // Kept in an `Option` so the sender can be moved out exactly once inside
            // the closure below and still be inspected on the error path afterwards.
            let mut pending = Some(ready_tx);

            let outcome = Python::attach(|py| -> PyResult<()> {
                let asyncio = PyModule::import(py, "asyncio")?;
                let event_loop = asyncio.getattr("new_event_loop")?.call0()?;
                asyncio.getattr("set_event_loop")?.call1((event_loop.clone(),))?;

                let task_locals = pyo3_async_runtimes::TaskLocals::new(event_loop.clone());
                match pending.take() {
                    // A failed send means the receiver is gone; nobody is left to notify.
                    Some(reply) => {
                        let _ = reply.send(Ok(task_locals));
                    }
                    None => {
                        return Err(PyRuntimeError::new_err(
                            "fallback asyncio loop sender was unexpectedly missing",
                        ));
                    }
                }

                event_loop.call_method0("run_forever")?;
                Ok(())
            });

            // Only report a failure if the locals were never handed over.
            if let (Err(failure), Some(reply)) = (outcome, pending) {
                let _ = reply.send(Err(failure.to_string()));
            }
        })
        .map_err(|err| format!("failed to start fallback asyncio loop thread: {err}"))?;

    ready_rx
        .recv()
        .map_err(|err| {
            format!("fallback asyncio loop thread terminated before initialization: {err}")
        })?
        .map_err(|err| format!("failed to initialize fallback asyncio loop: {err}"))
}

/// Returns the shared fallback `TaskLocals`, creating them on first use.
///
/// The first call runs `create_fallback_task_locals` and stores its result in
/// `FALLBACK_TASK_LOCALS`; later calls hand out clones of the stored value.
///
/// # Errors
///
/// If initialization failed, the stored message is returned as a
/// `PyRuntimeError`. The failure is cached just like a success, so
/// initialization is not attempted again.
pub(crate) fn fallback_task_locals() -> Result<pyo3_async_runtimes::TaskLocals, PyErr> {
    FALLBACK_TASK_LOCALS
        .get_or_init(create_fallback_task_locals)
        .as_ref()
        .cloned()
        .map_err(|message| PyRuntimeError::new_err(message.clone()))
}
Loading
Loading