Skip to content

Commit 775e0be

Browse files
fix(python): Load _expiry_time from botocore Credentials in CredentialProviderAWS (#23753)
1 parent 257714d commit 775e0be

File tree

3 files changed

+80
-2
lines changed

3 files changed

+80
-2
lines changed

py-polars/polars/dependencies.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
155155
import subprocess
156156

157157
import altair
158+
import boto3
158159
import deltalake
159160
import fsspec
160161
import gevent
@@ -179,6 +180,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
179180

180181
# heavy/optional third party libs
181182
altair, _ALTAIR_AVAILABLE = _lazy_import("altair")
183+
boto3, _BOTO3_AVAILABLE = _lazy_import("boto3")
182184
deltalake, _DELTALAKE_AVAILABLE = _lazy_import("deltalake")
183185
fsspec, _FSSPEC_AVAILABLE = _lazy_import("fsspec")
184186
gevent, _GEVENT_AVAILABLE = _lazy_import("gevent")
@@ -315,6 +317,7 @@ def import_optional(
315317
"subprocess",
316318
# lazy-load third party libs
317319
"altair",
320+
"boto3",
318321
"deltalake",
319322
"fsspec",
320323
"gevent",

py-polars/polars/io/cloud/credential_provider/_providers.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from polars.io.cloud._utils import NoPickleOption
2323

2424
if TYPE_CHECKING:
25+
from polars.dependencies import boto3
26+
2527
if sys.version_info >= (3, 10):
2628
from typing import TypeAlias
2729
else:
@@ -171,11 +173,17 @@ def retrieve_credentials_impl(self) -> CredentialProviderFunctionReturn:
171173
msg = "did not receive any credentials from boto3.Session.get_credentials()"
172174
raise self.EmptyCredentialError(msg)
173175

176+
expiry = (
177+
int(expiry.timestamp())
178+
if isinstance(expiry := getattr(creds, "_expiry_time", None), datetime)
179+
else None
180+
)
181+
174182
return {
175183
"aws_access_key_id": creds.access_key,
176184
"aws_secret_access_key": creds.secret_key,
177185
**({"aws_session_token": creds.token} if creds.token is not None else {}),
178-
}, None
186+
}, expiry
179187

180188
def _finish_assume_role(self, session: Any) -> CredentialProviderFunctionReturn:
181189
client = session.client("sts")
@@ -245,7 +253,7 @@ def _can_use_as_provider(self) -> bool:
245253

246254
return True
247255

248-
def _session(self) -> Any:
256+
def _session(self) -> boto3.Session:
249257
# Note: boto3 automatically sources the AWS_PROFILE env var
250258
import boto3
251259

py-polars/tests/unit/io/cloud/test_credential_provider.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import io
22
import pickle
3+
import sys
4+
from datetime import datetime, timezone
35
from functools import lru_cache
46
from pathlib import Path
57
from typing import Any
@@ -539,3 +541,68 @@ def cache(value: ZeroHashWrap[Any]) -> int:
539541
assert cache(ZeroHashWrap(3)) == 3
540542
assert cache(ZeroHashWrap(7)) == 3
541543
assert cache(ZeroHashWrap("A")) == 3
544+
545+
546+
@pytest.mark.write_disk
547+
def test_credential_provider_aws_expiry(
548+
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
549+
) -> None:
550+
credential_file_path = tmp_path / "credentials.json"
551+
552+
credential_file_path.write_text(
553+
"""\
554+
{
555+
"Version": 1,
556+
"AccessKeyId": "123",
557+
"SecretAccessKey": "456",
558+
"SessionToken": "789",
559+
"Expiration": "2099-01-01T00:00:00+00:00"
560+
}
561+
"""
562+
)
563+
564+
cfg_file_path = tmp_path / "config"
565+
566+
credential_file_path_str = str(credential_file_path).replace("\\", "/")
567+
568+
cfg_file_path.write_text(f"""\
569+
[profile cred_process]
570+
credential_process = "{sys.executable}" -c "from pathlib import Path; print(Path('{credential_file_path_str}').read_text())"
571+
""")
572+
573+
monkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))
574+
575+
creds, expiry = pl.CredentialProviderAWS(profile_name="cred_process")()
576+
577+
assert creds == {
578+
"aws_access_key_id": "123",
579+
"aws_secret_access_key": "456",
580+
"aws_session_token": "789",
581+
}
582+
583+
assert expiry is not None
584+
585+
assert datetime.fromtimestamp(expiry, tz=timezone.utc) == datetime.fromisoformat(
586+
"2099-01-01T00:00:00+00:00"
587+
)
588+
589+
credential_file_path.write_text(
590+
"""\
591+
{
592+
"Version": 1,
593+
"AccessKeyId": "...",
594+
"SecretAccessKey": "...",
595+
"SessionToken": "..."
596+
}
597+
"""
598+
)
599+
600+
creds, expiry = pl.CredentialProviderAWS(profile_name="cred_process")()
601+
602+
assert creds == {
603+
"aws_access_key_id": "...",
604+
"aws_secret_access_key": "...",
605+
"aws_session_token": "...",
606+
}
607+
608+
assert expiry is None

0 commit comments

Comments
 (0)