Skip to content

Commit a703a55

Browse files
Lawhyclaude
andcommitted
refactor(utils): unify AWS session functions into single get_session
- Merge get_boto3_session and get_assumed_role_session into get_session(region, profile_name, role_arn) - Rename clear_sessions to clear_session_cache - Remove top-level exports from utils/__init__.py for clearer imports - Users now import from specific modules: strands_env.utils.aws, strands_env.utils.sglang - Update all usages in CLI, environments, tools, and documentation Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 009c4ec commit a703a55

File tree

8 files changed

+64
-87
lines changed

8 files changed

+64
-87
lines changed

CLAUDE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ The package lives in `src/strands_env/` with these modules:
8585

8686
**sglang.py** — SGLang client caching with `@lru_cache`. `get_cached_client(base_url, max_connections)` for connection pooling. `get_cached_client_from_slime_args(args)` for slime RL training integration. `check_server_health(base_url)` for early validation.
8787

88-
**aws.py** — AWS boto3 session caching. `get_boto3_session(region, profile_name)` with `@lru_cache` (boto3 handles credential refresh). `get_assumed_role_session(role_arn, region)` uses `RefreshableCredentials` for programmatic role assumption with auto-refresh.
88+
**aws.py** — AWS boto3 session caching. `get_session(region, profile_name, role_arn)` with `@lru_cache`. If `role_arn` provided, uses `RefreshableCredentials` for programmatic role assumption with auto-refresh; otherwise returns basic session.
8989

9090
### `tools/`
9191

docs/rl-training.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@ Customize the `generate` and `reward_func` methods to replace single generation
1717
```python
1818
from strands_env.core import Action, TaskContext
1919
from strands_env.core.models import sglang_model_factory
20-
from strands_env.utils import get_cached_client_from_slime_args
20+
from strands_env.utils.sglang import get_cached_client_from_slime_args
2121

2222
async def generate(args, sample, sampling_params):
2323
# Build model factory with cached client
2424
factory = sglang_model_factory(
25-
model_id=args.hf_checkpoint,
2625
tokenizer=tokenizer,
2726
client=get_cached_client_from_slime_args(args),
2827
sampling_params=sampling_params,

src/strands_env/cli/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,14 +273,15 @@ def _build_sglang_model_factory(config: ModelConfig, max_concurrency: int, sampl
273273

274274
def _build_bedrock_model_factory(config: ModelConfig, sampling: dict) -> ModelFactory:
275275
"""Build Bedrock model factory."""
276-
from strands_env.utils.aws import get_assumed_role_session, get_boto3_session
276+
from strands_env.utils.aws import get_session
277277

278278
if not config.model_id:
279279
raise click.ClickException("--model-id is required for Bedrock backend")
280280

281-
if config.role_arn:
282-
boto_session = get_assumed_role_session(config.role_arn, config.region)
283-
else:
284-
boto_session = get_boto3_session(config.region, config.profile_name)
281+
boto_session = get_session(
282+
region=config.region,
283+
profile_name=config.profile_name,
284+
role_arn=config.role_arn,
285+
)
285286

286287
return bedrock_model_factory(model_id=config.model_id, boto_session=boto_session, sampling_params=sampling)

src/strands_env/environments/code_sandbox/env.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from strands_env.core.environment import Environment
2626
from strands_env.tools import CodeInterpreterToolkit
27-
from strands_env.utils.aws import get_boto3_session
27+
from strands_env.utils.aws import get_session
2828

2929
if TYPE_CHECKING:
3030
import boto3
@@ -53,9 +53,9 @@ class CodeSandboxEnv(Environment):
5353
5454
Example:
5555
from strands_env.environments.code_sandbox import CodeSandboxEnv, CodeMode
56-
from strands_env.utils import get_boto3_session
56+
from strands_env.utils.aws import get_session
5757
58-
session = get_boto3_session(region="us-east-1")
58+
session = get_session(region="us-east-1")
5959
env = CodeSandboxEnv(
6060
boto3_session=session,
6161
model_factory=model_factory,
@@ -99,7 +99,7 @@ def __init__(
9999
)
100100
self.mode = mode
101101
self._toolkit = CodeInterpreterToolkit(
102-
boto3_session=boto3_session or get_boto3_session(), session_name="strands-env-code-sandbox"
102+
boto3_session=boto3_session or get_session(), session_name="strands-env-code-sandbox"
103103
)
104104

105105
@override

src/strands_env/tools/code_interpreter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ class CodeInterpreterToolkit:
3333
and shell commands in a sandboxed environment.
3434
3535
Example:
36-
from strands_env.utils import get_boto3_session
36+
from strands_env.utils.aws import get_session
3737
38-
session = get_boto3_session(region="us-east-1")
38+
session = get_session(region="us-east-1")
3939
toolkit = CodeInterpreterToolkit(boto3_session=session)
4040
4141
# In environment:

src/strands_env/utils/__init__.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,3 @@
1313
# limitations under the License.
1414

1515
"""Utilities for `strands_env`."""
16-
17-
from .aws import clear_sessions, get_assumed_role_session, get_boto3_session
18-
from .sglang import clear_clients, get_cached_client, get_cached_client_from_slime_args
19-
20-
__all__ = [
21-
# AWS
22-
"clear_sessions",
23-
"get_assumed_role_session",
24-
"get_boto3_session",
25-
# SGLang
26-
"clear_clients",
27-
"get_cached_client",
28-
"get_cached_client_from_slime_args",
29-
]

src/strands_env/utils/aws.py

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -25,48 +25,40 @@
2525

2626

2727
@lru_cache(maxsize=None)
28-
def get_boto3_session(region: str = "us-east-1", profile_name: str | None = None) -> boto3.Session:
28+
def get_session(
29+
region: str = "us-east-1",
30+
profile_name: str | None = None,
31+
role_arn: str | None = None,
32+
session_name: str = "StrandsEnvSession",
33+
) -> boto3.Session:
2934
"""Get a cached boto3 session.
3035
3136
Credentials are managed by boto3's provider chain (env vars, ~/.aws/credentials,
3237
IAM instance role, etc.) and auto-refresh automatically.
3338
34-
For role assumption, either:
35-
1. Configure a profile in ~/.aws/config with `role_arn` and `source_profile`, or
36-
2. Use `get_assumed_role_session()` for programmatic role assumption
39+
If `role_arn` is provided, assumes the role using STS with auto-refreshing
40+
credentials via botocore's `RefreshableCredentials`.
3741
3842
Args:
3943
region: AWS region name.
4044
profile_name: Optional AWS profile name from ~/.aws/config.
45+
role_arn: Optional ARN of the IAM role to assume.
46+
session_name: Session name for assumed role (only used if role_arn provided).
4147
4248
Returns:
4349
Cached boto3 Session instance.
4450
"""
45-
logger.info(f"Creating boto3 session: region={region}, profile={profile_name}")
46-
return boto3.Session(region_name=region, profile_name=profile_name)
47-
48-
49-
@lru_cache(maxsize=None)
50-
def get_assumed_role_session(
51-
role_arn: str,
52-
region: str = "us-east-1",
53-
session_name: str = "StrandsEnvSession",
54-
) -> boto3.Session:
55-
"""Get a cached boto3 session with assumed role credentials.
51+
if role_arn:
52+
return _create_assumed_role_session(role_arn, region, session_name)
53+
else:
54+
logger.info(f"Creating boto3 session: region={region}, profile={profile_name}")
55+
return boto3.Session(region_name=region, profile_name=profile_name)
5656

57-
Uses botocore's `RefreshableCredentials` so credentials auto-refresh when expired.
58-
The session is cached by (role_arn, region, session_name).
5957

60-
Args:
61-
role_arn: ARN of the IAM role to assume.
62-
region: AWS region name.
63-
session_name: Session name for the assumed role.
64-
65-
Returns:
66-
Cached boto3 Session with auto-refreshing credentials.
67-
"""
58+
def _create_assumed_role_session(role_arn: str, region: str, session_name: str) -> boto3.Session:
59+
"""Create a boto3 session with assumed role credentials."""
6860
from botocore.credentials import RefreshableCredentials
69-
from botocore.session import get_session
61+
from botocore.session import get_session as get_botocore_session
7062

7163
logger.info(f"Creating boto3 session with assumed role: role={role_arn}, region={region}")
7264

@@ -87,12 +79,11 @@ def refresh() -> dict:
8779
method="sts-assume-role",
8880
)
8981

90-
botocore_session = get_session()
82+
botocore_session = get_botocore_session()
9183
botocore_session._credentials = session_credentials
9284
return boto3.Session(botocore_session=botocore_session, region_name=region)
9385

9486

95-
def clear_sessions() -> None:
87+
def clear_session_cache() -> None:
9688
"""Clear all cached boto3 sessions."""
97-
get_boto3_session.cache_clear()
98-
get_assumed_role_session.cache_clear()
89+
get_session.cache_clear()

tests/unit/test_aws.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,40 +18,40 @@
1818

1919
import boto3
2020

21-
from strands_env.utils.aws import clear_sessions, get_assumed_role_session, get_boto3_session
21+
from strands_env.utils.aws import clear_session_cache, get_session
2222

2323

24-
class TestGetBoto3Session:
25-
"""Tests for get_boto3_session."""
24+
class TestGetSession:
25+
"""Tests for get_session (basic mode without role assumption)."""
2626

2727
def setup_method(self):
2828
"""Clear cache before each test."""
29-
clear_sessions()
29+
clear_session_cache()
3030

3131
def test_returns_session(self):
3232
"""Should return a boto3 Session."""
33-
session = get_boto3_session(region="us-west-2")
33+
session = get_session(region="us-west-2")
3434
assert isinstance(session, boto3.Session)
3535
assert session.region_name == "us-west-2"
3636

3737
def test_cached_by_region(self):
3838
"""Same region should return cached session."""
39-
session1 = get_boto3_session(region="us-east-1")
40-
session2 = get_boto3_session(region="us-east-1")
39+
session1 = get_session(region="us-east-1")
40+
session2 = get_session(region="us-east-1")
4141
assert session1 is session2
4242

4343
def test_different_regions_different_sessions(self):
4444
"""Different regions should return different sessions."""
45-
session1 = get_boto3_session(region="us-east-1")
46-
session2 = get_boto3_session(region="us-west-2")
45+
session1 = get_session(region="us-east-1")
46+
session2 = get_session(region="us-west-2")
4747
assert session1 is not session2
4848

4949
@patch("strands_env.utils.aws.boto3.Session")
5050
def test_cached_by_profile(self, mock_session_cls):
5151
"""Same profile should return cached session."""
5252
mock_session_cls.return_value = MagicMock()
53-
session1 = get_boto3_session(region="us-east-1", profile_name="test-profile")
54-
session2 = get_boto3_session(region="us-east-1", profile_name="test-profile")
53+
session1 = get_session(region="us-east-1", profile_name="test-profile")
54+
session2 = get_session(region="us-east-1", profile_name="test-profile")
5555
assert session1 is session2
5656
# Session should only be created once due to caching
5757
assert mock_session_cls.call_count == 1
@@ -60,22 +60,22 @@ def test_cached_by_profile(self, mock_session_cls):
6060
def test_different_profiles_different_sessions(self, mock_session_cls):
6161
"""Different profiles should return different sessions."""
6262
mock_session_cls.side_effect = [MagicMock(), MagicMock()]
63-
session1 = get_boto3_session(region="us-east-1", profile_name="profile-a")
64-
session2 = get_boto3_session(region="us-east-1", profile_name="profile-b")
63+
session1 = get_session(region="us-east-1", profile_name="profile-a")
64+
session2 = get_session(region="us-east-1", profile_name="profile-b")
6565
assert session1 is not session2
6666

6767

68-
class TestGetAssumedRoleSession:
69-
"""Tests for get_assumed_role_session."""
68+
class TestGetSessionWithRoleAssumption:
69+
"""Tests for get_session with role_arn (role assumption mode)."""
7070

7171
def setup_method(self):
7272
"""Clear cache before each test."""
73-
clear_sessions()
73+
clear_session_cache()
7474

7575
@patch("strands_env.utils.aws.boto3.client")
7676
@patch("botocore.session.get_session")
7777
def test_assumes_role(self, mock_get_session, mock_boto3_client):
78-
"""Should call STS assume_role."""
78+
"""Should call STS assume_role when role_arn provided."""
7979
from datetime import datetime, timezone
8080

8181
# Mock STS response
@@ -95,7 +95,7 @@ def test_assumes_role(self, mock_get_session, mock_boto3_client):
9595
mock_get_session.return_value = mock_botocore_session
9696

9797
role_arn = "arn:aws:iam::123456789:role/TestRole"
98-
session = get_assumed_role_session(role_arn=role_arn, region="us-east-1")
98+
session = get_session(region="us-east-1", role_arn=role_arn)
9999

100100
# Verify STS was called
101101
mock_boto3_client.assert_called_with("sts", region_name="us-east-1")
@@ -123,8 +123,8 @@ def test_cached_by_role_arn(self, mock_get_session, mock_boto3_client):
123123
mock_get_session.return_value = MagicMock()
124124

125125
role_arn = "arn:aws:iam::123456789:role/TestRole"
126-
session1 = get_assumed_role_session(role_arn=role_arn)
127-
session2 = get_assumed_role_session(role_arn=role_arn)
126+
session1 = get_session(role_arn=role_arn)
127+
session2 = get_session(role_arn=role_arn)
128128

129129
assert session1 is session2
130130
# assume_role should only be called once due to caching
@@ -149,7 +149,7 @@ def test_has_refreshable_credentials(self, mock_boto3_client):
149149
mock_boto3_client.return_value = mock_sts
150150

151151
role_arn = "arn:aws:iam::123456789:role/TestRole"
152-
session = get_assumed_role_session(role_arn=role_arn)
152+
session = get_session(role_arn=role_arn)
153153

154154
# Get the underlying botocore credentials directly
155155
botocore_creds = session._session._credentials
@@ -159,17 +159,17 @@ def test_has_refreshable_credentials(self, mock_boto3_client):
159159
assert botocore_creds._refresh_using is not None
160160

161161

162-
class TestClearSessions:
163-
"""Tests for clear_sessions."""
162+
class TestClearSessionCache:
163+
"""Tests for clear_session_cache."""
164164

165-
def test_clears_all_caches(self):
166-
"""Should clear both session caches."""
167-
# Create some cached sessions
168-
session1 = get_boto3_session(region="us-east-1")
165+
def test_clears_cache(self):
166+
"""Should clear the session cache."""
167+
# Create a cached session
168+
session1 = get_session(region="us-east-1")
169169

170170
# Clear
171-
clear_sessions()
171+
clear_session_cache()
172172

173173
# New call should create new session
174-
session2 = get_boto3_session(region="us-east-1")
174+
session2 = get_session(region="us-east-1")
175175
assert session1 is not session2

0 commit comments

Comments
 (0)