22import json
33import logging
44import os
5+ import uuid
56from dataclasses import dataclass
67from functools import wraps
78
89import boto3
910from bedrock_agentcore .runtime import BedrockAgentCoreApp
1011
11- _S3_CONFIG_FIELDS = ("exp_id" , "session_id" , " input_id" , "s3_bucket" )
12+ _S3_CONFIG_FIELDS = ("exp_id" , "input_id" , "s3_bucket" )
1213
1314
1415@dataclass
1516class RolloutConfig :
1617 """Rollout configuration for rollout collection and storage."""
1718
1819 exp_id : str
19- session_id : str
2020 input_id : str
2121 s3_bucket : str
2222
@@ -26,7 +26,6 @@ def from_dict(cls, data: dict) -> "RolloutConfig":
2626 try :
2727 return cls (
2828 exp_id = data ["exp_id" ],
29- session_id = data ["session_id" ],
3029 input_id = data ["input_id" ],
3130 s3_bucket = data ["s3_bucket" ],
3231 )
@@ -115,7 +114,7 @@ def _validate_and_normalize_rollout(self, rollout_dict: dict) -> dict:
115114 rollout_dict ["rewards" ] = rewards
116115 return rollout_dict
117116
118- def save_rollout (self , rollout_data : dict , rollout_config : dict , payload : dict = None , result_key : str = None ):
117+ def save_rollout (self , rollout_data : dict , rollout_config : dict , result_key : str , payload : dict = None ):
119118 """
120119 Save rollout data to S3.
121120
@@ -124,7 +123,6 @@ def save_rollout(self, rollout_data: dict, rollout_config: dict, payload: dict =
124123 rollout_config: Rollout configuration dict containing:
125124 - s3_bucket: S3 bucket name
126125 - exp_id: Experiment ID for organizing data
127- - session_id: Session id for the current task
128126 - input_id: id for discriminating different input data examples
129127 payload: Original request payload (included in saved result for debugging)
130128 result_key: S3 key for the result (computed externally for consistency)
@@ -136,10 +134,6 @@ def save_rollout(self, rollout_data: dict, rollout_config: dict, payload: dict =
136134 logging .error (f"Invalid rollout configuration: { e } " )
137135 raise
138136
139- # Use provided result_key or compute it
140- if result_key is None :
141- result_key = f"{ config .exp_id } /{ config .input_id } _{ config .session_id } .json"
142-
143137 if "status_code" not in rollout_data :
144138 rollout_data ["status_code" ] = 200
145139
@@ -249,7 +243,10 @@ async def rollout_entrypoint_wrapper(payload, context):
249243 rollout_config = None
250244 if rollout_dict is not None and any (f in rollout_dict for f in _S3_CONFIG_FIELDS ):
251245 rollout_config = RolloutConfig .from_dict (rollout_dict )
252- result_key = f"{ rollout_config .exp_id } /{ rollout_config .input_id } _{ rollout_config .session_id } .json"
246+ # session_id comes from ACR's HTTP header (set via runtimeSessionId),
247+ # fall back to UUID for local testing.
248+ session_id = context .session_id if context .session_id else str (uuid .uuid4 ())
249+ result_key = f"{ rollout_config .exp_id } /{ rollout_config .input_id } /{ session_id } .json"
253250
254251 # Start background task without waiting
255252 asyncio .create_task (rollout_background_task (payload , context , result_key ))
0 commit comments