@@ -117,7 +117,9 @@ def _validate_and_normalize_rollout(self, rollout_dict: dict) -> dict:
117117 rollout_dict ["rewards" ] = rewards
118118 return rollout_dict
119119
120- def save_rollout_and_notify (self , rollout_data : dict , training_config : dict ):
120+ def save_rollout_and_notify (
121+ self , rollout_data : dict , training_config : dict , payload : dict = None , result_key : str = None
122+ ):
121123 """
122124 Save rollout data to S3 and notify SQS queue.
123125
@@ -129,6 +131,8 @@ def save_rollout_and_notify(self, rollout_data: dict, training_config: dict):
129131 - exp_id: Experiment ID for organizing data
130132 - session_id: Session id for the current task
131133 - input_id: id for discriminating different input data examples
134+ payload: Original request payload (included in saved result for debugging)
135+ result_key: S3 key for the result (computed externally for consistency)
132136 """
133137 # Validate and extract training configuration
134138 try :
@@ -137,17 +141,24 @@ def save_rollout_and_notify(self, rollout_data: dict, training_config: dict):
137141 logging .error (f"Invalid training configuration: { e } " )
138142 raise
139143
140- result_key = f"{ config .exp_id } /{ config .input_id } _{ config .session_id } .json"
144+ # Use provided result_key or compute it
145+ if result_key is None :
146+ result_key = f"{ config .exp_id } /{ config .input_id } _{ config .session_id } .json"
141147
142148 if "status_code" not in rollout_data :
143149 rollout_data ["status_code" ] = 200
144150
145151 if "stop_reason" not in rollout_data :
146152 rollout_data ["stop_reason" ] = "end_turn"
147153
148- # Return the input id identifying rollouts of the same input data (prompt) example
149- # for advantage computation.
154+ # Include metadata for correlation and debugging
150155 rollout_data ["input_id" ] = config .input_id
156+ rollout_data ["s3_bucket" ] = config .s3_bucket
157+ rollout_data ["result_key" ] = result_key
158+
159+ # Include full payload for debugging (with _training config for reproducibility)
160+ if payload is not None :
161+ rollout_data ["payload" ] = payload
151162
152163 # Save to S3
153164 try :
@@ -205,7 +216,7 @@ def invoke_agent(payload, context): # Can be sync or async
205216 Decorated function registered as entrypoint
206217 """
207218
208- async def rollout_background_task (payload , context ):
219+ async def rollout_background_task (payload , context , result_key ):
209220 """Background task that does the actual agent work and rollout saving."""
210221 training_config = payload .get ("_training" )
211222
@@ -225,7 +236,12 @@ async def rollout_background_task(payload, context):
225236
226237 # Save rollout data if we have training config
227238 if isinstance (result , dict ) and training_config :
228- self .save_rollout_and_notify (rollout_data = result , training_config = training_config )
239+ self .save_rollout_and_notify (
240+ rollout_data = result ,
241+ training_config = training_config ,
242+ payload = payload ,
243+ result_key = result_key ,
244+ )
229245 logging .info (f"Rollout data saved for function: { func .__name__ } " )
230246
231247 return result
@@ -234,7 +250,12 @@ async def rollout_background_task(payload, context):
234250 # Always save error rollout for client notification
235251 if training_config :
236252 error_rollout = {"status_code" : 500 , "stop_reason" : str (e )}
237- self .save_rollout_and_notify (rollout_data = error_rollout , training_config = training_config )
253+ self .save_rollout_and_notify (
254+ rollout_data = error_rollout ,
255+ training_config = training_config ,
256+ payload = payload ,
257+ result_key = result_key ,
258+ )
238259 logging .error (f"Error rollout saved for function: { func .__name__ } : { e } " )
239260 raise
240261 finally :
@@ -244,8 +265,26 @@ async def rollout_background_task(payload, context):
@wraps(func)
async def rollout_entrypoint_wrapper(payload, context):
    """Start the rollout in the background and return immediately.

    When the payload carries a ``_training`` config, the S3 result key is
    computed up front, handed to the background task, and returned to the
    caller so the client knows where the result will be written.

    Args:
        payload: Request payload. Its optional ``_training`` entry is read
            for ``exp_id``, ``input_id``, ``session_id`` and ``s3_bucket``.
        context: Invocation context, passed through unchanged to the
            background task.

    Returns:
        ``{"status": "processing"}``, extended with ``s3_bucket`` and
        ``result_key`` when a training config is present.
    """
    training_config = payload.get("_training")

    # Compute result_key upfront so we can return it to the client.
    # Layout: "<exp_id>/<input_id>_<session_id>.json"; missing fields
    # fall back to empty strings here.
    result_key = None
    if training_config:
        exp_id = training_config.get("exp_id", "")
        input_id = training_config.get("input_id", "")
        session_id = training_config.get("session_id", "")
        result_key = f"{exp_id}/{input_id}_{session_id}.json"

    # Start background task without waiting.
    background_task = asyncio.create_task(
        rollout_background_task(payload, context, result_key)
    )

    # The event loop only keeps weak references to tasks, so hold a strong
    # reference until the task finishes; otherwise it could be
    # garbage-collected mid-execution. The set lives on the wrapper itself
    # so no extra module- or instance-level state is required.
    pending_tasks = vars(rollout_entrypoint_wrapper).setdefault(
        "_pending_rollout_tasks", set()
    )
    pending_tasks.add(background_task)
    background_task.add_done_callback(pending_tasks.discard)

    # Return result location so client can poll S3 for completion.
    if training_config:
        return {
            "status": "processing",
            "s3_bucket": training_config.get("s3_bucket"),
            "result_key": result_key,
        }
    return {"status": "processing"}
250289
251290 # Remove __wrapped__ so inspect.signature() sees the wrapper's actual signature
0 commit comments