Skip to content

Commit e137007

Browse files
authored
refactor(app): derive session_id from ACR header instead of _rollout payload (#22)
1 parent 05ae840 commit e137007

5 files changed

Lines changed: 12 additions & 17 deletions

File tree

AGENTS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ Since the client won't get results directly from HTTP:
204204
**AgentCoreRLApp** (`src/agentcore_rl_toolkit/app.py`)
205205
- Inherits `BedrockAgentCoreApp` - drop-in replacement
206206
- Provides `@app.rollout_entrypoint` decorator
207-
- Expects `_rollout` dict in payload following `RolloutConfig` model (experiment id, session id, input id, base_url, model_id)
207+
- Expects `_rollout` dict in payload following `RolloutConfig` model (experiment id, input id, s3_bucket, base_url, model_id)
208208
- Framework-agnostic: works with any agent framework, not just Strands
209209

210210
#### Utilities

src/agentcore_rl_toolkit/app.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@
22
import json
33
import logging
44
import os
5+
import uuid
56
from dataclasses import dataclass
67
from functools import wraps
78

89
import boto3
910
from bedrock_agentcore.runtime import BedrockAgentCoreApp
1011

11-
_S3_CONFIG_FIELDS = ("exp_id", "session_id", "input_id", "s3_bucket")
12+
_S3_CONFIG_FIELDS = ("exp_id", "input_id", "s3_bucket")
1213

1314

1415
@dataclass
1516
class RolloutConfig:
1617
"""Rollout configuration for rollout collection and storage."""
1718

1819
exp_id: str
19-
session_id: str
2020
input_id: str
2121
s3_bucket: str
2222

@@ -26,7 +26,6 @@ def from_dict(cls, data: dict) -> "RolloutConfig":
2626
try:
2727
return cls(
2828
exp_id=data["exp_id"],
29-
session_id=data["session_id"],
3029
input_id=data["input_id"],
3130
s3_bucket=data["s3_bucket"],
3231
)
@@ -115,7 +114,7 @@ def _validate_and_normalize_rollout(self, rollout_dict: dict) -> dict:
115114
rollout_dict["rewards"] = rewards
116115
return rollout_dict
117116

118-
def save_rollout(self, rollout_data: dict, rollout_config: dict, payload: dict = None, result_key: str = None):
117+
def save_rollout(self, rollout_data: dict, rollout_config: dict, result_key: str, payload: dict = None):
119118
"""
120119
Save rollout data to S3.
121120
@@ -124,7 +123,6 @@ def save_rollout(self, rollout_data: dict, rollout_config: dict, payload: dict =
124123
rollout_config: Rollout configuration dict containing:
125124
- s3_bucket: S3 bucket name
126125
- exp_id: Experiment ID for organizing data
127-
- session_id: Session id for the current task
128126
- input_id: id for discriminating different input data examples
129127
payload: Original request payload (included in saved result for debugging)
130128
result_key: S3 key for the result (computed externally for consistency)
@@ -136,10 +134,6 @@ def save_rollout(self, rollout_data: dict, rollout_config: dict, payload: dict =
136134
logging.error(f"Invalid rollout configuration: {e}")
137135
raise
138136

139-
# Use provided result_key or compute it
140-
if result_key is None:
141-
result_key = f"{config.exp_id}/{config.input_id}_{config.session_id}.json"
142-
143137
if "status_code" not in rollout_data:
144138
rollout_data["status_code"] = 200
145139

@@ -249,7 +243,10 @@ async def rollout_entrypoint_wrapper(payload, context):
249243
rollout_config = None
250244
if rollout_dict is not None and any(f in rollout_dict for f in _S3_CONFIG_FIELDS):
251245
rollout_config = RolloutConfig.from_dict(rollout_dict)
252-
result_key = f"{rollout_config.exp_id}/{rollout_config.input_id}_{rollout_config.session_id}.json"
246+
# session_id comes from ACR's HTTP header (set via runtimeSessionId),
247+
# fall back to UUID for local testing.
248+
session_id = context.session_id if context.session_id else str(uuid.uuid4())
249+
result_key = f"{rollout_config.exp_id}/{rollout_config.input_id}/{session_id}.json"
253250

254251
# Start background task without waiting
255252
asyncio.create_task(rollout_background_task(payload, context, result_key))

src/agentcore_rl_toolkit/client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,6 @@ def _rate_limited_invoke(self, payload: dict, session_id: str, input_id: str) ->
257257
# Build rollout config
258258
rollout_config = {
259259
"exp_id": self.exp_id,
260-
"session_id": session_id,
261260
"input_id": input_id,
262261
"s3_bucket": self.s3_bucket,
263262
**self.extra_config,

tests/test_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,6 @@ def test_invoke_builds_rollout_config(self):
410410

411411
assert payload["prompt"] == "test"
412412
assert payload["_rollout"]["exp_id"] == "exp-001"
413-
assert payload["_rollout"]["session_id"] == "sess-1"
414413
assert payload["_rollout"]["input_id"] == "input-1"
415414
assert payload["_rollout"]["s3_bucket"] == "test-bucket"
416415
assert payload["_rollout"]["base_url"] == "http://localhost:8000"

tests/test_rollout_entrypoint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,18 @@ async def handler(payload: dict):
100100
"prompt": "test",
101101
"_rollout": {
102102
"exp_id": "exp-123",
103-
"session_id": "sess-456",
104103
"input_id": "input-789",
105104
"s3_bucket": "my-bucket",
106105
},
107106
},
107+
headers={"X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": "sess-456"},
108108
)
109109

110110
assert response.status_code == 200
111111
result = response.json()
112112
assert result["status"] == "processing"
113113
assert result["s3_bucket"] == "my-bucket"
114-
assert result["result_key"] == "exp-123/input-789_sess-456.json"
114+
assert result["result_key"] == "exp-123/input-789/sess-456.json"
115115

116116

117117
def test_response_without_rollout_config():
@@ -180,7 +180,7 @@ async def handler(payload: dict):
180180
assert "result_key" not in result
181181

182182

183-
@pytest.mark.parametrize("missing_field", ["exp_id", "session_id", "input_id", "s3_bucket"])
183+
@pytest.mark.parametrize("missing_field", ["exp_id", "input_id", "s3_bucket"])
184184
def test_entrypoint_rejects_partial_s3_config(missing_field):
185185
"""Test that providing some but not all S3 fields returns HTTP 500."""
186186
app = AgentCoreRLApp()
@@ -191,7 +191,6 @@ async def handler(payload: dict):
191191

192192
complete_config = {
193193
"exp_id": "exp-123",
194-
"session_id": "sess-456",
195194
"input_id": "input-789",
196195
"s3_bucket": "my-bucket",
197196
}
@@ -201,6 +200,7 @@ async def handler(payload: dict):
201200
response = client.post(
202201
"/invocations",
203202
json={"prompt": "test", "_rollout": incomplete_config},
203+
headers={"X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": "sess-456"},
204204
)
205205

206206
assert response.status_code == 500

0 commit comments

Comments
 (0)