Skip to content

Commit dd9a057

Browse files
committed
feat(app): return S3 result location in rollout response
1 parent b563a61 commit dd9a057

2 files changed

Lines changed: 95 additions & 8 deletions

File tree

src/agentcore_rl_toolkit/app.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def _validate_and_normalize_rollout(self, rollout_dict: dict) -> dict:
117117
rollout_dict["rewards"] = rewards
118118
return rollout_dict
119119

120-
def save_rollout_and_notify(self, rollout_data: dict, training_config: dict):
120+
def save_rollout_and_notify(
121+
self, rollout_data: dict, training_config: dict, payload: dict = None, result_key: str = None
122+
):
121123
"""
122124
Save rollout data to S3 and notify SQS queue.
123125
@@ -129,6 +131,8 @@ def save_rollout_and_notify(self, rollout_data: dict, training_config: dict):
129131
- exp_id: Experiment ID for organizing data
130132
- session_id: Session id for the current task
131133
- input_id: id for discriminating different input data examples
134+
payload: Original request payload (included in saved result for debugging)
135+
result_key: S3 key for the result (computed externally for consistency)
132136
"""
133137
# Validate and extract training configuration
134138
try:
@@ -137,17 +141,24 @@ def save_rollout_and_notify(self, rollout_data: dict, training_config: dict):
137141
logging.error(f"Invalid training configuration: {e}")
138142
raise
139143

140-
result_key = f"{config.exp_id}/{config.input_id}_{config.session_id}.json"
144+
# Use provided result_key or compute it
145+
if result_key is None:
146+
result_key = f"{config.exp_id}/{config.input_id}_{config.session_id}.json"
141147

142148
if "status_code" not in rollout_data:
143149
rollout_data["status_code"] = 200
144150

145151
if "stop_reason" not in rollout_data:
146152
rollout_data["stop_reason"] = "end_turn"
147153

148-
# Return the input id identifying rollouts of the same input data (prompt) example
149-
# for advantage computation.
154+
# Include metadata for correlation and debugging
150155
rollout_data["input_id"] = config.input_id
156+
rollout_data["s3_bucket"] = config.s3_bucket
157+
rollout_data["result_key"] = result_key
158+
159+
# Include full payload for debugging (with _training config for reproducibility)
160+
if payload is not None:
161+
rollout_data["payload"] = payload
151162

152163
# Save to S3
153164
try:
@@ -205,7 +216,7 @@ def invoke_agent(payload, context): # Can be sync or async
205216
Decorated function registered as entrypoint
206217
"""
207218

208-
async def rollout_background_task(payload, context):
219+
async def rollout_background_task(payload, context, result_key):
209220
"""Background task that does the actual agent work and rollout saving."""
210221
training_config = payload.get("_training")
211222

@@ -225,7 +236,12 @@ async def rollout_background_task(payload, context):
225236

226237
# Save rollout data if we have training config
227238
if isinstance(result, dict) and training_config:
228-
self.save_rollout_and_notify(rollout_data=result, training_config=training_config)
239+
self.save_rollout_and_notify(
240+
rollout_data=result,
241+
training_config=training_config,
242+
payload=payload,
243+
result_key=result_key,
244+
)
229245
logging.info(f"Rollout data saved for function: {func.__name__}")
230246

231247
return result
@@ -234,7 +250,12 @@ async def rollout_background_task(payload, context):
234250
# Always save error rollout for client notification
235251
if training_config:
236252
error_rollout = {"status_code": 500, "stop_reason": str(e)}
237-
self.save_rollout_and_notify(rollout_data=error_rollout, training_config=training_config)
253+
self.save_rollout_and_notify(
254+
rollout_data=error_rollout,
255+
training_config=training_config,
256+
payload=payload,
257+
result_key=result_key,
258+
)
238259
logging.error(f"Error rollout saved for function: {func.__name__}: {e}")
239260
raise
240261
finally:
@@ -244,8 +265,26 @@ async def rollout_background_task(payload, context):
244265
@wraps(func)
245266
async def rollout_entrypoint_wrapper(payload, context):
246267
"""Entrypoint that starts background task and returns immediately."""
268+
training_config = payload.get("_training")
269+
270+
# Compute result_key upfront so we can return it to the client
271+
result_key = None
272+
if training_config:
273+
exp_id = training_config.get("exp_id", "")
274+
input_id = training_config.get("input_id", "")
275+
session_id = training_config.get("session_id", "")
276+
result_key = f"{exp_id}/{input_id}_{session_id}.json"
277+
247278
# Start background task without waiting
248-
asyncio.create_task(rollout_background_task(payload, context))
279+
asyncio.create_task(rollout_background_task(payload, context, result_key))
280+
281+
# Return result location so client can poll S3 for completion
282+
if training_config:
283+
return {
284+
"status": "processing",
285+
"s3_bucket": training_config.get("s3_bucket"),
286+
"result_key": result_key,
287+
}
249288
return {"status": "processing"}
250289

251290
# Remove __wrapped__ so inspect.signature() sees the wrapper's actual signature

tests/test_rollout_entrypoint.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,51 @@ def handler(payload: dict):
8282

8383
assert response.status_code == 200
8484
assert response.json() == {"status": "processing"}
85+
86+
87+
def test_response_includes_result_location_with_training_config():
88+
"""Test that response includes s3_bucket and result_key when _training config is provided."""
89+
app = AgentCoreRLApp()
90+
91+
@app.rollout_entrypoint
92+
async def handler(payload: dict):
93+
return {"rollout_data": [{"test": True}], "rewards": [1.0]}
94+
95+
client = TestClient(app)
96+
response = client.post(
97+
"/invocations",
98+
json={
99+
"prompt": "test",
100+
"_training": {
101+
"exp_id": "exp-123",
102+
"session_id": "sess-456",
103+
"input_id": "input-789",
104+
"s3_bucket": "my-bucket",
105+
"sqs_url": "https://sqs.us-east-1.amazonaws.com/123/queue",
106+
},
107+
},
108+
)
109+
110+
assert response.status_code == 200
111+
result = response.json()
112+
assert result["status"] == "processing"
113+
assert result["s3_bucket"] == "my-bucket"
114+
assert result["result_key"] == "exp-123/input-789_sess-456.json"
115+
116+
117+
def test_response_without_training_config():
118+
"""Test that response is minimal when no _training config is provided."""
119+
app = AgentCoreRLApp()
120+
121+
@app.rollout_entrypoint
122+
async def handler(payload: dict):
123+
return {"rollout_data": [{"test": True}], "rewards": [1.0]}
124+
125+
client = TestClient(app)
126+
response = client.post("/invocations", json={"prompt": "test"})
127+
128+
assert response.status_code == 200
129+
result = response.json()
130+
assert result == {"status": "processing"}
131+
assert "s3_bucket" not in result
132+
assert "result_key" not in result

0 commit comments

Comments
 (0)