Skip to content

Commit 5bb0f7d

Browse files
authored
Credential refresh can be initialized with the first set of credentials (#921)
`s3_storage`-style functions got a new argument: ``` scatter_initial_credentials: bool = False ``` If True is passed, pickled repos and sessions will use the result returned by the first call to `get_credentials`, instead of having to run it again. This speeds up short distributed tasks, since they no longer need to spend time fetching credentials: tasks reuse the credentials gathered by the "coordinator".
1 parent f7f62be commit 5bb0f7d

File tree

8 files changed

+246
-63
lines changed

8 files changed

+246
-63
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

icechunk-python/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ typetag = "0.2.20"
3939
serde = { version = "1.0.219", features = ["derive", "rc"] }
4040
miette = { version = "7.5.0", features = ["fancy"] }
4141
clap = { version = "4.5", features = ["derive"], optional = true }
42+
rand = "0.9.0"
4243

4344
[features]
4445
cli = ["clap", "icechunk/cli"]

icechunk-python/python/icechunk/_icechunk_python.pyi

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,8 +1260,13 @@ class S3Credentials:
12601260
----------
12611261
pickled_function: bytes
12621262
The pickled function to use to provide credentials.
1263+
current: S3StaticCredentials
1264+
The initial credentials. They will be returned the first time credentials
1265+
are requested and then deleted.
12631266
"""
1264-
def __init__(self, pickled_function: bytes) -> None: ...
1267+
def __init__(
1268+
self, pickled_function: bytes, current: S3StaticCredentials | None = None
1269+
) -> None: ...
12651270

12661271
AnyS3Credential = (
12671272
S3Credentials.Static
@@ -1359,7 +1364,9 @@ class GcsCredentials:
13591364
13601365
This is useful for credentials that have an expiration time, or are otherwise not known ahead of time.
13611366
"""
1362-
def __init__(self, pickled_function: bytes) -> None: ...
1367+
def __init__(
1368+
self, pickled_function: bytes, current: GcsBearerCredential | None = None
1369+
) -> None: ...
13631370

13641371
AnyGcsCredential = (
13651372
GcsCredentials.FromEnv | GcsCredentials.Static | GcsCredentials.Refreshable

icechunk-python/python/icechunk/credentials.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,23 @@
4545

4646
def s3_refreshable_credentials(
4747
get_credentials: Callable[[], S3StaticCredentials],
48+
scatter_initial_credentials: bool = False,
4849
) -> S3Credentials.Refreshable:
4950
"""Create refreshable credentials for S3 and S3 compatible object stores.
5051
51-
5252
Parameters
5353
----------
5454
get_credentials: Callable[[], S3StaticCredentials]
5555
Use this function to get and refresh the credentials. The function must be picklable.
56+
scatter_initial_credentials: bool, optional
57+
Immediately call and store the value returned by get_credentials. This is useful if the
58+
repo or session will be pickled to generate many copies. Passing scatter_initial_credentials=True will
59+
ensure all those copies don't need to call get_credentials immediately. After the initial
60+
set of credentials has expired, the cached value is no longer used. Notice that credentials
61+
obtained are stored, and they can be sent over the network if you pickle the session/repo.
5662
"""
57-
return S3Credentials.Refreshable(pickle.dumps(get_credentials))
63+
current = get_credentials() if scatter_initial_credentials else None
64+
return S3Credentials.Refreshable(pickle.dumps(get_credentials), current)
5865

5966

6067
def s3_static_credentials(
@@ -106,6 +113,7 @@ def s3_credentials(
106113
anonymous: bool | None = None,
107114
from_env: bool | None = None,
108115
get_credentials: Callable[[], S3StaticCredentials] | None = None,
116+
scatter_initial_credentials: bool = False,
109117
) -> AnyS3Credential:
110118
"""Create credentials for S3 and S3 compatible object stores.
111119
@@ -127,6 +135,12 @@ def s3_credentials(
127135
Fetch credentials from the operating system environment
128136
get_credentials: Callable[[], S3StaticCredentials] | None
129137
Use this function to get and refresh object store credentials
138+
scatter_initial_credentials: bool, optional
139+
Immediately call and store the value returned by get_credentials. This is useful if the
140+
repo or session will be pickled to generate many copies. Passing scatter_initial_credentials=True will
141+
ensure all those copies don't need to call get_credentials immediately. After the initial
142+
set of credentials has expired, the cached value is no longer used. Notice that credentials
143+
obtained are stored, and they can be sent over the network if you pickle the session/repo.
130144
"""
131145
if (
132146
(from_env is None or from_env)
@@ -159,7 +173,9 @@ def s3_credentials(
159173
and not from_env
160174
and not anonymous
161175
):
162-
return s3_refreshable_credentials(get_credentials)
176+
return s3_refreshable_credentials(
177+
get_credentials, scatter_initial_credentials=scatter_initial_credentials
178+
)
163179

164180
if (
165181
access_key_id
@@ -199,9 +215,24 @@ def gcs_static_credentials(
199215

200216
def gcs_refreshable_credentials(
201217
get_credentials: Callable[[], GcsBearerCredential],
218+
scatter_initial_credentials: bool = False,
202219
) -> GcsCredentials.Refreshable:
203-
"""Create refreshable credentials for Google Cloud Storage object store."""
204-
return GcsCredentials.Refreshable(pickle.dumps(get_credentials))
220+
"""Create refreshable credentials for Google Cloud Storage object store.
221+
222+
Parameters
223+
----------
224+
get_credentials: Callable[[], GcsBearerCredential]
225+
Use this function to get and refresh the credentials. The function must be picklable.
226+
scatter_initial_credentials: bool, optional
227+
Immediately call and store the value returned by get_credentials. This is useful if the
228+
repo or session will be pickled to generate many copies. Passing scatter_initial_credentials=True will
229+
ensure all those copies don't need to call get_credentials immediately. After the initial
230+
set of credentials has expired, the cached value is no longer used. Notice that credentials
231+
obtained are stored, and they can be sent over the network if you pickle the session/repo.
232+
"""
233+
234+
current = get_credentials() if scatter_initial_credentials else None
235+
return GcsCredentials.Refreshable(pickle.dumps(get_credentials), current)
205236

206237

207238
def gcs_from_env_credentials() -> GcsCredentials.FromEnv:
@@ -217,6 +248,7 @@ def gcs_credentials(
217248
bearer_token: str | None = None,
218249
from_env: bool | None = None,
219250
get_credentials: Callable[[], GcsBearerCredential] | None = None,
251+
scatter_initial_credentials: bool = False,
220252
) -> AnyGcsCredential:
221253
"""Create credentials Google Cloud Storage object store.
222254
@@ -246,7 +278,9 @@ def gcs_credentials(
246278
)
247279

248280
if get_credentials is not None:
249-
return gcs_refreshable_credentials(get_credentials)
281+
return gcs_refreshable_credentials(
282+
get_credentials, scatter_initial_credentials=scatter_initial_credentials
283+
)
250284

251285
raise ValueError("Conflicting arguments to gcs_credentials function")
252286

icechunk-python/python/icechunk/storage.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def s3_store(
4949
force_path_style: bool = False,
5050
) -> ObjectStoreConfig.S3Compatible | ObjectStoreConfig.S3:
5151
"""Build an ObjectStoreConfig instance for S3 or S3 compatible object stores."""
52+
5253
options = S3Options(
5354
region=region,
5455
endpoint_url=endpoint_url,
@@ -76,6 +77,7 @@ def s3_storage(
7677
anonymous: bool | None = None,
7778
from_env: bool | None = None,
7879
get_credentials: Callable[[], S3StaticCredentials] | None = None,
80+
scatter_initial_credentials: bool = False,
7981
force_path_style: bool = False,
8082
) -> Storage:
8183
"""Create a Storage instance that saves data in S3 or S3 compatible object stores.
@@ -106,6 +108,12 @@ def s3_storage(
106108
Fetch credentials from the operating system environment
107109
get_credentials: Callable[[], S3StaticCredentials] | None
108110
Use this function to get and refresh object store credentials
111+
scatter_initial_credentials: bool, optional
112+
Immediately call and store the value returned by get_credentials. This is useful if the
113+
repo or session will be pickled to generate many copies. Passing scatter_initial_credentials=True will
114+
ensure all those copies don't need to call get_credentials immediately. After the initial
115+
set of credentials has expired, the cached value is no longer used. Notice that credentials
116+
obtained are stored, and they can be sent over the network if you pickle the session/repo.
109117
force_path_style: bool
110118
Whether to force using path-style addressing for buckets
111119
"""
@@ -118,6 +126,7 @@ def s3_storage(
118126
anonymous=anonymous,
119127
from_env=from_env,
120128
get_credentials=get_credentials,
129+
scatter_initial_credentials=scatter_initial_credentials,
121130
)
122131
options = S3Options(
123132
region=region,
@@ -186,6 +195,7 @@ def tigris_storage(
186195
anonymous: bool | None = None,
187196
from_env: bool | None = None,
188197
get_credentials: Callable[[], S3StaticCredentials] | None = None,
198+
scatter_initial_credentials: bool = False,
189199
) -> Storage:
190200
"""Create a Storage instance that saves data in Tigris object store.
191201
@@ -219,6 +229,12 @@ def tigris_storage(
219229
Fetch credentials from the operating system environment
220230
get_credentials: Callable[[], S3StaticCredentials] | None
221231
Use this function to get and refresh object store credentials
232+
scatter_initial_credentials: bool, optional
233+
Immediately call and store the value returned by get_credentials. This is useful if the
234+
repo or session will be pickled to generate many copies. Passing scatter_initial_credentials=True will
235+
ensure all those copies don't need to call get_credentials immediately. After the initial
236+
set of credentials has expired, the cached value is no longer used. Notice that credentials
237+
obtained are stored, and they can be sent over the network if you pickle the session/repo.
222238
"""
223239
credentials = s3_credentials(
224240
access_key_id=access_key_id,
@@ -228,6 +244,7 @@ def tigris_storage(
228244
anonymous=anonymous,
229245
from_env=from_env,
230246
get_credentials=get_credentials,
247+
scatter_initial_credentials=scatter_initial_credentials,
231248
)
232249
options = S3Options(region=region, endpoint_url=endpoint_url, allow_http=allow_http)
233250
return Storage.new_tigris(
@@ -254,6 +271,7 @@ def r2_storage(
254271
anonymous: bool | None = None,
255272
from_env: bool | None = None,
256273
get_credentials: Callable[[], S3StaticCredentials] | None = None,
274+
scatter_initial_credentials: bool = False,
257275
) -> Storage:
258276
"""Create a Storage instance that saves data in Tigris object store.
259277
@@ -287,6 +305,12 @@ def r2_storage(
287305
Fetch credentials from the operating system environment
288306
get_credentials: Callable[[], S3StaticCredentials] | None
289307
Use this function to get and refresh object store credentials
308+
scatter_initial_credentials: bool, optional
309+
Immediately call and store the value returned by get_credentials. This is useful if the
310+
repo or session will be pickled to generate many copies. Passing scatter_initial_credentials=True will
311+
ensure all those copies don't need to call get_credentials immediately. After the initial
312+
set of credentials has expired, the cached value is no longer used. Notice that credentials
313+
obtained are stored, and they can be sent over the network if you pickle the session/repo.
290314
"""
291315
credentials = s3_credentials(
292316
access_key_id=access_key_id,
@@ -296,6 +320,7 @@ def r2_storage(
296320
anonymous=anonymous,
297321
from_env=from_env,
298322
get_credentials=get_credentials,
323+
scatter_initial_credentials=scatter_initial_credentials,
299324
)
300325
options = S3Options(region=region, endpoint_url=endpoint_url, allow_http=allow_http)
301326
return Storage.new_r2(
@@ -318,6 +343,7 @@ def gcs_storage(
318343
from_env: bool | None = None,
319344
config: dict[str, str] | None = None,
320345
get_credentials: Callable[[], GcsBearerCredential] | None = None,
346+
scatter_initial_credentials: bool = False,
321347
) -> Storage:
322348
"""Create a Storage instance that saves data in Google Cloud Storage object store.
323349
@@ -333,6 +359,12 @@ def gcs_storage(
333359
The bearer token to use for the object store
334360
get_credentials: Callable[[], GcsBearerCredential] | None
335361
Use this function to get and refresh object store credentials
362+
scatter_initial_credentials: bool, optional
363+
Immediately call and store the value returned by get_credentials. This is useful if the
364+
repo or session will be pickled to generate many copies. Passing scatter_initial_credentials=True will
365+
ensure all those copies don't need to call get_credentials immediately. After the initial
366+
set of credentials has expired, the cached value is no longer used. Notice that credentials
367+
obtained are stored, and they can be sent over the network if you pickle the session/repo.
336368
"""
337369
credentials = gcs_credentials(
338370
service_account_file=service_account_file,
@@ -341,6 +373,7 @@ def gcs_storage(
341373
bearer_token=bearer_token,
342374
from_env=from_env,
343375
get_credentials=get_credentials,
376+
scatter_initial_credentials=scatter_initial_credentials,
344377
)
345378
return Storage.new_gcs(
346379
bucket=bucket,

0 commit comments

Comments
 (0)