Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

s3: Load aws env vars from local environment #22008

Merged
merged 3 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/notes/2.26.x.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ The `experimental_test_shell_command` target type now supports the same `runnab

For the `tfsec` linter, the deprecation of support for leading `v`s in the `version` and `known_versions` field has expired and been removed. Write `1.28.13` instead of `v1.28.13`.

#### S3

The S3 backend now creates new AWS credentials when the `AWS_` environment variables change. This allows credentials to be updated without restarting the Pants daemon.

### Plugin API changes


Expand Down
85 changes: 70 additions & 15 deletions src/python/pants/backend/url_handlers/s3/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
import pytest

from pants.backend.url_handlers.s3.register import (
AWSCredentials,
DownloadS3AuthorityPathStyleURL,
DownloadS3AuthorityVirtualHostedStyleURL,
DownloadS3SchemeURL,
)
from pants.backend.url_handlers.s3.register import rules as s3_rules
from pants.engine.env_vars import EnvironmentVars, EnvironmentVarsRequest
from pants.engine.fs import Digest, FileDigest, NativeDownloadFile, Snapshot
from pants.engine.rules import QueryRule
from pants.testutil.rule_runner import RuleRunner
Expand All @@ -35,6 +37,8 @@ def rule_runner() -> RuleRunner:
QueryRule(Snapshot, [DownloadS3SchemeURL]),
QueryRule(Snapshot, [DownloadS3AuthorityVirtualHostedStyleURL]),
QueryRule(Snapshot, [DownloadS3AuthorityPathStyleURL]),
QueryRule(AWSCredentials, []),
QueryRule(EnvironmentVars, [EnvironmentVarsRequest]),
],
isolated_local_store=True,
)
Expand All @@ -45,28 +49,49 @@ def monkeypatch_botocore(monkeypatch):
def do_patching(expected_url):
botocore = SimpleNamespace()
botocore.exceptions = SimpleNamespace(NoCredentialsError=Exception)
fake_session = object()
fake_creds = SimpleNamespace(access_key="ACCESS", secret_key="SECRET", token=None)
botocore.session = SimpleNamespace(get_session=lambda: fake_session)
# NB: HTTPHeaders is just a simple subclass of HTTPMessage
botocore.compat = SimpleNamespace(HTTPHeaders=HTTPMessage)

def fake_resolver_creator(session):
assert session is fake_session
return SimpleNamespace(load_credentials=lambda: fake_creds)
class FakeSession:
def __init__(self):
self.config_vars = {}
self.creds = None

def set_config_variable(self, key, value):
self.config_vars.update({key: value})

def get_credentials(self):
if self.creds:
return self.creds

key = "ACCESS"
secret = "SECRET"
# suffix the access key with the profile name to make testing easier
if self.config_vars.get("profile"):
key = f"ACCESS_{self.config_vars.get('profile')}"
return FakeCredentials.create(access_key=key, secret_key=secret)

def set_credentials(self, creds):
self.creds = creds

def fake_creds_ctor(access_key, secret_key):
assert access_key == fake_creds.access_key
assert secret_key == fake_creds.secret_key
return fake_creds
class FakeCredentials:
@staticmethod
def create(access_key, secret_key, token=None):
return SimpleNamespace(access_key=access_key, secret_key=secret_key, token=token)

class FakeCredentialsResolver:
def __init__(self, session):
self.session = session

def load_credentials(self):
return self.session.get_credentials()

botocore.session = SimpleNamespace(Session=lambda: FakeSession())
botocore.compat = SimpleNamespace(HTTPHeaders=HTTPMessage)
botocore.credentials = SimpleNamespace(
create_credential_resolver=fake_resolver_creator, Credentials=fake_creds_ctor
create_credential_resolver=lambda session: FakeCredentialsResolver(session),
Credentials=FakeCredentials.create,
)

def fake_auth_ctor(creds):
assert creds is fake_creds

def add_auth(request):
request.url == expected_url
request.headers["AUTH"] = "TOKEN"
Expand Down Expand Up @@ -197,3 +222,33 @@ def send_headers(self):
)
assert snapshot.files == ("file.txt",)
assert snapshot.digest == DOWNLOADS_EXPECTED_DIRECTORY_DIGEST


def test_aws_credentials_caching(rule_runner: RuleRunner, monkeypatch_botocore) -> None:
    """Check that `AWSCredentials` results track the `AWS_PROFILE` environment variable.

    Repeated requests under an unchanged environment must yield the same object, a
    different profile must yield different credentials, and switching back to the
    first profile must yield credentials equal to the first result.
    """
    monkeypatch_botocore("https://example.com")

    def switch_profile(profile: str) -> None:
        # Replace the runner's environment so that only AWS_PROFILE is set.
        rule_runner.set_options(args=[], env={"AWS_PROFILE": profile})

    def resolve() -> AWSCredentials:
        return rule_runner.request(AWSCredentials, [])

    switch_profile("profile1")
    first = resolve()
    repeated = resolve()
    assert first is repeated
    assert first.creds.access_key == "ACCESS_profile1"

    # A different environment should produce different credentials.
    switch_profile("profile2")
    other_profile = resolve()
    assert first is not other_profile
    assert other_profile.creds.access_key == "ACCESS_profile2"

    # Returning to the original environment should produce the original credentials.
    switch_profile("profile1")
    restored = resolve()
    # N.B. Identity (`is`) does not hold here, apparently because of how set_options
    # operates (root cause not confirmed), so this compares by equality instead.
    assert first == restored
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not totally happy with these tests: I couldn't get assertions working that verify the exact same object is returned. In local testing with pantsd, however, it appears to work correctly.

54 changes: 47 additions & 7 deletions src/python/pants/backend/url_handlers/s3/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from urllib.parse import urlsplit

from pants.engine.download_file import URLDownloadHandler
from pants.engine.env_vars import EnvironmentVars, EnvironmentVarsRequest
from pants.engine.environment import ChosenLocalEnvironmentName, EnvironmentName
from pants.engine.fs import Digest, NativeDownloadFile
from pants.engine.internals.native_engine import FileDigest
from pants.engine.internals.selectors import Get
Expand All @@ -29,9 +31,12 @@ class AWSCredentials:


@rule
async def access_aws_credentials() -> AWSCredentials:
async def access_aws_credentials(
local_environment_name: ChosenLocalEnvironmentName,
) -> AWSCredentials:
try:
from botocore import credentials, session
from botocore import credentials
from botocore import session as boto_session
except ImportError:
logger.warning(
softwrap(
Expand All @@ -48,7 +53,44 @@ async def access_aws_credentials() -> AWSCredentials:
)
raise

session = session.get_session()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable was previously shadowing the import, so I changed the import to use an alias to avoid confusion.

env_vars = await Get(
EnvironmentVars,
{
EnvironmentVarsRequest(
[
"AWS_PROFILE",
"AWS_REGION",
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
]
): EnvironmentVarsRequest,
local_environment_name.val: EnvironmentName,
},
)

session = boto_session.Session()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


aws_profile = env_vars.get("AWS_PROFILE")
if aws_profile:
session.set_config_variable("profile", aws_profile)

aws_region = env_vars.get("AWS_REGION")
if aws_region:
session.set_config_variable("region", aws_region)

aws_access_key = env_vars.get("AWS_ACCESS_KEY_ID")
aws_secret_key = env_vars.get("AWS_SECRET_ACCESS_KEY")
aws_session_token = env_vars.get("AWS_SESSION_TOKEN")
if aws_access_key and aws_secret_key:
session.set_credentials(
credentials.Credentials(
access_key=aws_access_key,
secret_key=aws_secret_key,
token=aws_session_token,
)
)

creds = credentials.create_credential_resolver(session).load_credentials()

return AWSCredentials(creds)
Expand Down Expand Up @@ -136,7 +178,7 @@ class DownloadS3AuthorityVirtualHostedStyleURL(URLDownloadHandler):

@rule
async def download_file_from_virtual_hosted_s3_authority(
request: DownloadS3AuthorityVirtualHostedStyleURL, aws_credentials: AWSCredentials
request: DownloadS3AuthorityVirtualHostedStyleURL,
) -> Digest:
split = urlsplit(request.url)
bucket, aws_netloc = split.netloc.split(".", 1)
Expand All @@ -157,9 +199,7 @@ class DownloadS3AuthorityPathStyleURL(URLDownloadHandler):


@rule
async def download_file_from_path_s3_authority(
request: DownloadS3AuthorityPathStyleURL, aws_credentials: AWSCredentials
) -> Digest:
async def download_file_from_path_s3_authority(request: DownloadS3AuthorityPathStyleURL) -> Digest:
split = urlsplit(request.url)
_, bucket, key = split.path.split("/", 2)
return await Get(
Expand Down
Loading