Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 51 additions & 11 deletions icechunk-python/python/icechunk/credentials.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import inspect
import pickle
from collections.abc import Callable, Mapping
from collections.abc import Awaitable, Callable, Coroutine, Mapping
from datetime import datetime
from typing import cast
from typing import Any, TypeVar, cast

from icechunk._icechunk_python import (
AzureCredentials,
Expand Down Expand Up @@ -46,16 +48,50 @@

# Union of the S3, GCS and Azure variants of ``Credentials``.
AnyCredential = Credentials.S3 | Credentials.Gcs | Credentials.Azure

# Zero-argument callbacks used to fetch (and later refresh) credentials.
# A callback may return the credential object directly or an awaitable that
# resolves to it, so both plain functions and ``async def`` functions fit.
S3CredentialProvider = Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]]
GcsCredentialProvider = Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]]
CredentialType = TypeVar("CredentialType")


def _resolve_callback_once(
get_credentials: Callable[[], CredentialType | Awaitable[CredentialType]],
) -> CredentialType:
value = get_credentials()
if inspect.isawaitable(value):
if inspect.iscoroutine(value):
coroutine = cast(Coroutine[Any, Any, CredentialType], value)
else:
awaitable = cast(Awaitable[CredentialType], value)

async def _as_coroutine() -> CredentialType:
return await awaitable

coroutine = _as_coroutine()

try:
return asyncio.run(coroutine)
except RuntimeError as err:
if "asyncio.run() cannot be called from a running event loop" in str(err):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting.... this is a bit of a breaking change, because everybody (including Earthmover) uses True for scatter

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll think on this..

coroutine.close()
raise ValueError(
"scatter_initial_credentials=True cannot eagerly evaluate async "
"credential callbacks while an event loop is running. "
"Set scatter_initial_credentials=False in async contexts."
) from err
raise

return value


def s3_refreshable_credentials(
get_credentials: Callable[[], S3StaticCredentials],
get_credentials: S3CredentialProvider,
scatter_initial_credentials: bool = False,
) -> S3Credentials.Refreshable:
"""Create refreshable credentials for S3 and S3 compatible object stores.

Parameters
----------
get_credentials: Callable[[], S3StaticCredentials]
get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]]
Use this function to get and refresh the credentials. The function must be picklable.
scatter_initial_credentials: bool, optional
Immediately call and store the value returned by get_credentials. This is useful if the
Expand All @@ -64,7 +100,9 @@ def s3_refreshable_credentials(
set of credentials has expired, the cached value is no longer used. Notice that credentials
obtained are stored, and they can be sent over the network if you pickle the session/repo.
"""
current = get_credentials() if scatter_initial_credentials else None
current = (
_resolve_callback_once(get_credentials) if scatter_initial_credentials else None
)
return S3Credentials.Refreshable(pickle.dumps(get_credentials), current)


Expand Down Expand Up @@ -116,7 +154,7 @@ def s3_credentials(
expires_after: datetime | None = None,
anonymous: bool | None = None,
from_env: bool | None = None,
get_credentials: Callable[[], S3StaticCredentials] | None = None,
get_credentials: S3CredentialProvider | None = None,
scatter_initial_credentials: bool = False,
) -> AnyS3Credential:
"""Create credentials for S3 and S3 compatible object stores.
Expand All @@ -137,7 +175,7 @@ def s3_credentials(
If set to True requests to the object store will not be signed
from_env: bool | None
Fetch credentials from the operating system environment
get_credentials: Callable[[], S3StaticCredentials] | None
get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None
Use this function to get and refresh object store credentials
scatter_initial_credentials: bool, optional
Immediately call and store the value returned by get_credentials. This is useful if the
Expand Down Expand Up @@ -218,14 +256,14 @@ def gcs_static_credentials(


def gcs_refreshable_credentials(
get_credentials: Callable[[], GcsBearerCredential],
get_credentials: GcsCredentialProvider,
scatter_initial_credentials: bool = False,
) -> GcsCredentials.Refreshable:
"""Create refreshable credentials for Google Cloud Storage object store.

Parameters
----------
get_credentials: Callable[[], S3StaticCredentials]
get_credentials: Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]]
Use this function to get and refresh the credentials. The function must be picklable.
scatter_initial_credentials: bool, optional
Immediately call and store the value returned by get_credentials. This is useful if the
Expand All @@ -235,7 +273,9 @@ def gcs_refreshable_credentials(
obtained are stored, and they can be sent over the network if you pickle the session/repo.
"""

current = get_credentials() if scatter_initial_credentials else None
current = (
_resolve_callback_once(get_credentials) if scatter_initial_credentials else None
)
return GcsCredentials.Refreshable(pickle.dumps(get_credentials), current)


Expand All @@ -257,7 +297,7 @@ def gcs_credentials(
bearer_token: str | None = None,
from_env: bool | None = None,
anonymous: bool | None = None,
get_credentials: Callable[[], GcsBearerCredential] | None = None,
get_credentials: GcsCredentialProvider | None = None,
scatter_initial_credentials: bool = False,
) -> AnyGcsCredential:
"""Create credentials Google Cloud Storage object store.
Expand Down
Loading
Loading