@@ -38,56 +38,24 @@ def __init__(self):
3838 super ().__init__ ()
3939 self .s3_client = boto3 .client ("s3" )
4040
41- def _validate_and_normalize_rollout (self , rollout_dict : dict ) -> dict :
41+ def save_result (self , result : dict , rollout_config : dict , result_key : str , payload : dict = None ) :
4242 """
43- Validate and normalize rollout data structure .
43+ Save result data to S3 .
4444
45- Ensures the return value from user functions has the expected format:
46- {"rollout_data": [...], "rewards": [...]}
45+ The result dict is saved as-is with metadata added for correlation and debugging.
46+ Any JSON-serializable dict is accepted — there are no required keys.
4747
48- Args:
49- rollout_dict: Dictionary returned from user function
50-
51- Returns:
52- Normalized rollout dictionary with validated structure
53-
54- Raises:
55- ValueError: If structure is invalid or rewards don't match rollout length
56- """
57- # Require both fields to exist
58- if "rollout_data" not in rollout_dict :
59- raise ValueError ("Return value must include 'rollout_data' field" )
60- if "rewards" not in rollout_dict :
61- raise ValueError ("Return value must include 'rewards' field" )
62-
63- rollout_data = rollout_dict ["rollout_data" ]
64- rewards = rollout_dict ["rewards" ]
65-
66- # Validate rollout_data
67- if not isinstance (rollout_data , list ) or len (rollout_data ) == 0 :
68- raise ValueError ("rollout_data must be a list with length >= 1" )
69-
70- # Normalize rewards to list if not already
71- if not isinstance (rewards , list ):
72- rewards = [rewards ]
73-
74- # Validate rewards length
75- if len (rewards ) != 1 and len (rewards ) != len (rollout_data ):
76- raise ValueError (
77- f"rewards must be length 1 (outcome reward) or "
78- f"match rollout_data length { len (rollout_data )} (per-step reward)"
79- )
80-
81- # Update with normalized rewards
82- rollout_dict ["rewards" ] = rewards
83- return rollout_dict
84-
85- def save_rollout (self , rollout_data : dict , rollout_config : dict , result_key : str , payload : dict = None ):
86- """
87- Save rollout data to S3.
48+ Reserved keys — the SDK injects the following keys into the saved JSON.
49+ Avoid using these in your return dict to prevent unexpected overwrites:
50+ - ``status_code``: Set to 200 if not already present in the user dict.
51+ - ``stop_reason``: Set to ``"end_turn"`` if not already present.
52+ - ``input_id``: Always overwritten with the value from rollout config.
53+ - ``s3_bucket``: Always overwritten with the value from rollout config.
54+ - ``result_key``: Always overwritten with the computed S3 key.
55+ - ``payload``: Overwritten with the original request payload when one is provided (left untouched if ``payload`` is None).
8856
8957 Args:
90- rollout_data : The prepared rollout data
58+ result : The result data to save (any JSON-serializable dict)
9159 rollout_config: Rollout configuration dict containing:
9260 - s3_bucket: S3 bucket name
9361 - exp_id: Experiment ID for organizing data
@@ -102,27 +70,27 @@ def save_rollout(self, rollout_data: dict, rollout_config: dict, result_key: str
10270 logging .error (f"Invalid rollout configuration: { e } " )
10371 raise
10472
105- if "status_code" not in rollout_data :
106- rollout_data ["status_code" ] = 200
73+ if "status_code" not in result :
74+ result ["status_code" ] = 200
10775
108- if "stop_reason" not in rollout_data :
109- rollout_data ["stop_reason" ] = "end_turn"
76+ if "stop_reason" not in result :
77+ result ["stop_reason" ] = "end_turn"
11078
11179 # Include metadata for correlation and debugging
112- rollout_data ["input_id" ] = config .input_id
113- rollout_data ["s3_bucket" ] = config .s3_bucket
114- rollout_data ["result_key" ] = result_key
80+ result ["input_id" ] = config .input_id
81+ result ["s3_bucket" ] = config .s3_bucket
82+ result ["result_key" ] = result_key
11583
11684 # Include full payload for debugging (with _rollout config for reproducibility)
11785 if payload is not None :
118- rollout_data ["payload" ] = payload
86+ result ["payload" ] = payload
11987
12088 # Save to S3
12189 try :
12290 self .s3_client .put_object (
12391 Bucket = config .s3_bucket ,
12492 Key = result_key ,
125- Body = json .dumps (rollout_data , indent = 2 ),
93+ Body = json .dumps (result , indent = 2 ),
12694 ContentType = "application/json" ,
12795 )
12896 logging .info (f"Stored complete results at { result_key } " )
@@ -132,30 +100,44 @@ def save_rollout(self, rollout_data: dict, rollout_config: dict, result_key: str
132100
133101 def rollout_entrypoint (self , func ):
134102 """
135- Decorator for RL training that handles asyncio.create_task and rollout saving automatically.
103+ Decorator for RL training that handles asyncio.create_task and result saving automatically.
136104
137105 This decorator:
138106 1. Handles both sync and async user functions using BedrockAgentCoreApp's infrastructure
139- 2. Automatically saves rollout data when user returns it
140- 3. Handles errors and saves error rollouts for client notification
107+ 2. Automatically saves the returned dict to S3 when S3 config is present
108+ 3. Handles errors and saves error results for client notification
141109 4. Returns immediately with {"status": "processing"} for non-blocking behavior
142110
111+ The return value must be a JSON-serializable dict when S3 save is configured.
112+ Any dict structure is accepted — there are no required keys. Common patterns:
113+ - RL training: {"rollout_data": [...], "rewards": [...]}
114+ - Evaluation: {"rewards": [...], "metrics": {...}}
115+ - Custom: {"summary": "...", "artifacts": {...}}
116+
117+ Serialization note: saved via json.dumps() → S3 as application/json.
118+ Supported types: str, int, float, bool, None, list, dict.
119+ Non-serializable values (custom objects, bytes, datetime, numpy arrays, etc.)
120+ will trigger the error path and an error file will be saved to S3.
121+
122+ Reserved keys: ``save_result`` injects SDK metadata into the saved JSON.
123+ See ``save_result`` docstring for the full list of reserved keys.
124+
143125 Usage:
144126 @app.rollout_entrypoint
145127 def invoke_agent(payload, context): # Can be sync or async
146128 # Framework-specific rollout collection
147- rollout_data = collect_rollout (...)
148- return rollout_data # Automatically saved!
129+ result = collect_result (...)
130+ return result # Automatically saved!
149131
150132 Args:
151- func: The user function that handles agent logic and rollout collection
133+ func: The user function that handles agent logic and result collection
152134
153135 Returns:
154136 Decorated function registered as entrypoint
155137 """
156138
157139 async def rollout_background_task (payload , context , result_key ):
158- """Background task that does the actual agent work and rollout saving."""
140+ """Background task that does the actual agent work and result saving."""
159141 rollout_dict = payload .get ("_rollout" )
160142
161143 # Register with async task tracking system for logging and ping status
@@ -166,44 +148,42 @@ async def rollout_background_task(payload, context, result_key):
166148 # This automatically runs sync functions in thread pool to avoid blocking
167149 result = await self ._invoke_handler (func , context , self ._takes_context (func ), payload )
168150
169- # If this is an RL training run, validate and normalize the rollout structure
170- if rollout_dict :
151+ # Save result to S3 if S3 config is present
152+ if result_key :
171153 if not isinstance (result , dict ):
172- raise ValueError ("RL training runs must return a dictionary" )
173- result = self ._validate_and_normalize_rollout (result )
174-
175- # Save rollout data if we have S3 config
176- if isinstance (result , dict ) and result_key :
177- self .save_rollout (
178- rollout_data = result ,
154+ raise ValueError (
155+ f"Return value must be a dict when S3 save is configured, got { type (result ).__name__ } "
156+ )
157+ self .save_result (
158+ result = result ,
179159 rollout_config = rollout_dict ,
180160 payload = payload ,
181161 result_key = result_key ,
182162 )
183- logging .info (f"Rollout data saved for function: { func .__name__ } " )
163+ logging .info (f"Result saved for function: { func .__name__ } " )
184164
185165 return result
186166
187167 except BaseException as e :
188- # Save error rollout for client notification when S3 is configured.
168+ # Save error result for client notification when S3 is configured.
189169 # Uses BaseException to also catch CancelledError, GeneratorExit, etc.
190170 # that can arise from task cancellation or deep async generator unwinding.
191171 if result_key :
192172 try :
193- error_rollout = {
173+ error_result = {
194174 "status_code" : 500 ,
195175 "stop_reason" : str (e ),
196176 "traceback" : traceback .format_exc (),
197177 }
198- self .save_rollout (
199- rollout_data = error_rollout ,
178+ self .save_result (
179+ result = error_result ,
200180 rollout_config = rollout_dict ,
201181 payload = payload ,
202182 result_key = result_key ,
203183 )
204- logging .error (f"Error rollout saved for function: { func .__name__ } : { e } " )
184+ logging .error (f"Error result saved for function: { func .__name__ } : { e } " )
205185 except Exception :
206- logging .error (f"Failed to save error rollout for function: { func .__name__ } " , exc_info = True )
186+ logging .error (f"Failed to save error result for function: { func .__name__ } " , exc_info = True )
207187 raise
208188 finally :
209189 # Complete the async task for logging and ping status
0 commit comments