Skip to content

Commit a929e7c

Browse files
authored
feat(app): relax rollout_entrypoint return value to accept any dict (#29)
1 parent ac4a1d3 commit a929e7c

3 files changed

Lines changed: 238 additions & 81 deletions

File tree

AGENTS.md

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,18 @@ See `examples/strands_math_agent` for a complete example adapting from `basic_ap
266266
+ response = agent(user_input)
267267
```
268268

269-
#### Step 3: Collect Token Data & Return Rollout
269+
#### Step 3: Collect Token Data & Return Result
270270

271271
The `@rollout_entrypoint` decorator automatically:
272272
- Executes the function in the background (works with both sync and async functions)
273-
- Saves rollout data to S3 with a predictable key
274-
- Handles errors and saves error rollouts for client awareness
273+
- Saves the returned dict to S3 with a predictable key
274+
- Handles errors and saves error results for client awareness
275+
276+
The return value must be a JSON-serializable dict when S3 save is configured. Any dict structure is accepted — there are no required keys. `rollout_data` is optional (e.g., when the gateway collects it server-side).
277+
278+
**Reserved keys**: The SDK injects metadata into the saved S3 JSON. Avoid using these keys in your return dict:
279+
- `status_code`, `stop_reason` — added only if not already present in your dict
280+
- `input_id`, `s3_bucket`, `result_key`, `payload` — always overwritten with SDK values
275281

276282
```diff
277283
- return response.message["content"][0]["text"]
@@ -280,6 +286,15 @@ The `@rollout_entrypoint` decorator automatically:
280286
+ return {"rollout_data": rollout_data, "rewards": rewards}
281287
```
282288

289+
Other valid return patterns:
290+
```python
291+
# Evaluation-only (no rollout_data needed)
292+
return {"rewards": rewards, "metrics": {"latency_ms": elapsed}}
293+
294+
# Custom artifacts
295+
return {"summary": "...", "artifacts": {...}}
296+
```
297+
283298
Each example in `/examples` contains `basic_app.py` and `rl_app.py` (or `dev_app.py`) to demonstrate this adaptation.
284299

285300
### Deployment to ACR
@@ -450,7 +465,7 @@ uv run pre-commit install
450465

451466
### Code Conventions
452467

453-
- Return dict with `rollout_data` and `rewards` keys from `@rollout_entrypoint`
468+
- Return a JSON-serializable dict from `@rollout_entrypoint` (any structure accepted — no required keys)
454469
- Create model and agent inside the entrypoint function (not at module level) so config comes from the `_rollout` payload
455470
- Use `vLLMModel.get_token_data()` to collect token IDs instead of hook-based rollout collection
456471
- Implement reward functions as classes inheriting `RewardFunction`

src/agentcore_rl_toolkit/app.py

Lines changed: 57 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -38,56 +38,24 @@ def __init__(self):
3838
super().__init__()
3939
self.s3_client = boto3.client("s3")
4040

41-
def _validate_and_normalize_rollout(self, rollout_dict: dict) -> dict:
41+
def save_result(self, result: dict, rollout_config: dict, result_key: str, payload: dict = None):
4242
"""
43-
Validate and normalize rollout data structure.
43+
Save result data to S3.
4444
45-
Ensures the return value from user functions has the expected format:
46-
{"rollout_data": [...], "rewards": [...]}
45+
The result dict is saved as-is with metadata added for correlation and debugging.
46+
Any JSON-serializable dict is accepted — there are no required keys.
4747
48-
Args:
49-
rollout_dict: Dictionary returned from user function
50-
51-
Returns:
52-
Normalized rollout dictionary with validated structure
53-
54-
Raises:
55-
ValueError: If structure is invalid or rewards don't match rollout length
56-
"""
57-
# Require both fields to exist
58-
if "rollout_data" not in rollout_dict:
59-
raise ValueError("Return value must include 'rollout_data' field")
60-
if "rewards" not in rollout_dict:
61-
raise ValueError("Return value must include 'rewards' field")
62-
63-
rollout_data = rollout_dict["rollout_data"]
64-
rewards = rollout_dict["rewards"]
65-
66-
# Validate rollout_data
67-
if not isinstance(rollout_data, list) or len(rollout_data) == 0:
68-
raise ValueError("rollout_data must be a list with length >= 1")
69-
70-
# Normalize rewards to list if not already
71-
if not isinstance(rewards, list):
72-
rewards = [rewards]
73-
74-
# Validate rewards length
75-
if len(rewards) != 1 and len(rewards) != len(rollout_data):
76-
raise ValueError(
77-
f"rewards must be length 1 (outcome reward) or "
78-
f"match rollout_data length {len(rollout_data)} (per-step reward)"
79-
)
80-
81-
# Update with normalized rewards
82-
rollout_dict["rewards"] = rewards
83-
return rollout_dict
84-
85-
def save_rollout(self, rollout_data: dict, rollout_config: dict, result_key: str, payload: dict = None):
86-
"""
87-
Save rollout data to S3.
48+
Reserved keys — the SDK injects the following keys into the saved JSON.
49+
Avoid using these in your return dict to prevent unexpected overwrites:
50+
- ``status_code``: Set to 200 if not already present in the user dict.
51+
- ``stop_reason``: Set to ``"end_turn"`` if not already present.
52+
- ``input_id``: Always overwritten with the value from rollout config.
53+
- ``s3_bucket``: Always overwritten with the value from rollout config.
54+
- ``result_key``: Always overwritten with the computed S3 key.
55+
- ``payload``: Overwritten with the original request payload whenever ``payload`` is not ``None``; left untouched otherwise.
8856
8957
Args:
90-
rollout_data: The prepared rollout data
58+
result: The result data to save (any JSON-serializable dict)
9159
rollout_config: Rollout configuration dict containing:
9260
- s3_bucket: S3 bucket name
9361
- exp_id: Experiment ID for organizing data
@@ -102,27 +70,27 @@ def save_rollout(self, rollout_data: dict, rollout_config: dict, result_key: str
10270
logging.error(f"Invalid rollout configuration: {e}")
10371
raise
10472

105-
if "status_code" not in rollout_data:
106-
rollout_data["status_code"] = 200
73+
if "status_code" not in result:
74+
result["status_code"] = 200
10775

108-
if "stop_reason" not in rollout_data:
109-
rollout_data["stop_reason"] = "end_turn"
76+
if "stop_reason" not in result:
77+
result["stop_reason"] = "end_turn"
11078

11179
# Include metadata for correlation and debugging
112-
rollout_data["input_id"] = config.input_id
113-
rollout_data["s3_bucket"] = config.s3_bucket
114-
rollout_data["result_key"] = result_key
80+
result["input_id"] = config.input_id
81+
result["s3_bucket"] = config.s3_bucket
82+
result["result_key"] = result_key
11583

11684
# Include full payload for debugging (with _rollout config for reproducibility)
11785
if payload is not None:
118-
rollout_data["payload"] = payload
86+
result["payload"] = payload
11987

12088
# Save to S3
12189
try:
12290
self.s3_client.put_object(
12391
Bucket=config.s3_bucket,
12492
Key=result_key,
125-
Body=json.dumps(rollout_data, indent=2),
93+
Body=json.dumps(result, indent=2),
12694
ContentType="application/json",
12795
)
12896
logging.info(f"Stored complete results at {result_key}")
@@ -132,30 +100,44 @@ def save_rollout(self, rollout_data: dict, rollout_config: dict, result_key: str
132100

133101
def rollout_entrypoint(self, func):
134102
"""
135-
Decorator for RL training that handles asyncio.create_task and rollout saving automatically.
103+
Decorator for RL training that handles asyncio.create_task and result saving automatically.
136104
137105
This decorator:
138106
1. Handles both sync and async user functions using BedrockAgentCoreApp's infrastructure
139-
2. Automatically saves rollout data when user returns it
140-
3. Handles errors and saves error rollouts for client notification
107+
2. Automatically saves the returned dict to S3 when S3 config is present
108+
3. Handles errors and saves error results for client notification
141109
4. Returns immediately with {"status": "processing"} for non-blocking behavior
142110
111+
The return value must be a JSON-serializable dict when S3 save is configured.
112+
Any dict structure is accepted — there are no required keys. Common patterns:
113+
- RL training: {"rollout_data": [...], "rewards": [...]}
114+
- Evaluation: {"rewards": [...], "metrics": {...}}
115+
- Custom: {"summary": "...", "artifacts": {...}}
116+
117+
Serialization note: saved via json.dumps() → S3 as application/json.
118+
Supported types: str, int, float, bool, None, list, dict.
119+
Non-serializable values (custom objects, bytes, datetime, numpy arrays, etc.)
120+
will trigger the error path and an error file will be saved to S3.
121+
122+
Reserved keys: ``save_result`` injects SDK metadata into the saved JSON.
123+
See ``save_result`` docstring for the full list of reserved keys.
124+
143125
Usage:
144126
@app.rollout_entrypoint
145127
def invoke_agent(payload, context): # Can be sync or async
146128
# Framework-specific result collection (e.g. rollout data, rewards, metrics)
147-
rollout_data = collect_rollout(...)
148-
return rollout_data # Automatically saved!
129+
result = collect_result(...)
130+
return result # Automatically saved!
149131
150132
Args:
151-
func: The user function that handles agent logic and rollout collection
133+
func: The user function that handles agent logic and result collection
152134
153135
Returns:
154136
Decorated function registered as entrypoint
155137
"""
156138

157139
async def rollout_background_task(payload, context, result_key):
158-
"""Background task that does the actual agent work and rollout saving."""
140+
"""Background task that does the actual agent work and result saving."""
159141
rollout_dict = payload.get("_rollout")
160142

161143
# Register with async task tracking system for logging and ping status
@@ -166,44 +148,42 @@ async def rollout_background_task(payload, context, result_key):
166148
# This automatically runs sync functions in thread pool to avoid blocking
167149
result = await self._invoke_handler(func, context, self._takes_context(func), payload)
168150

169-
# If this is an RL training run, validate and normalize the rollout structure
170-
if rollout_dict:
151+
# Save result to S3 if S3 config is present
152+
if result_key:
171153
if not isinstance(result, dict):
172-
raise ValueError("RL training runs must return a dictionary")
173-
result = self._validate_and_normalize_rollout(result)
174-
175-
# Save rollout data if we have S3 config
176-
if isinstance(result, dict) and result_key:
177-
self.save_rollout(
178-
rollout_data=result,
154+
raise ValueError(
155+
f"Return value must be a dict when S3 save is configured, got {type(result).__name__}"
156+
)
157+
self.save_result(
158+
result=result,
179159
rollout_config=rollout_dict,
180160
payload=payload,
181161
result_key=result_key,
182162
)
183-
logging.info(f"Rollout data saved for function: {func.__name__}")
163+
logging.info(f"Result saved for function: {func.__name__}")
184164

185165
return result
186166

187167
except BaseException as e:
188-
# Save error rollout for client notification when S3 is configured.
168+
# Save error result for client notification when S3 is configured.
189169
# Uses BaseException to also catch CancelledError, GeneratorExit, etc.
190170
# that can arise from task cancellation or deep async generator unwinding.
191171
if result_key:
192172
try:
193-
error_rollout = {
173+
error_result = {
194174
"status_code": 500,
195175
"stop_reason": str(e),
196176
"traceback": traceback.format_exc(),
197177
}
198-
self.save_rollout(
199-
rollout_data=error_rollout,
178+
self.save_result(
179+
result=error_result,
200180
rollout_config=rollout_dict,
201181
payload=payload,
202182
result_key=result_key,
203183
)
204-
logging.error(f"Error rollout saved for function: {func.__name__}: {e}")
184+
logging.error(f"Error result saved for function: {func.__name__}: {e}")
205185
except Exception:
206-
logging.error(f"Failed to save error rollout for function: {func.__name__}", exc_info=True)
186+
logging.error(f"Failed to save error result for function: {func.__name__}", exc_info=True)
207187
raise
208188
finally:
209189
# Complete the async task for logging and ping status

0 commit comments

Comments
 (0)