Skip to content

Commit 55cbf25

Browse files
authored
feat(client): add input_id attribute to RolloutFuture (#23)
1 parent e137007 commit 55cbf25

2 files changed

Lines changed: 28 additions & 0 deletions

File tree

src/agentcore_rl_toolkit/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,15 @@ def __init__(
4545
max_interval: float = 30.0,
4646
backoff_factor: float = 1.5,
4747
session_id: str = None,
48+
input_id: str = None,
4849
agentcore_client=None,
4950
agent_runtime_arn: str = None,
5051
):
5152
self.s3_client = s3_client
5253
self.s3_bucket = s3_bucket
5354
self.result_key = result_key
5455
self.session_id = session_id
56+
self.input_id = input_id
5557
self.agentcore_client = agentcore_client
5658
self.agent_runtime_arn = agent_runtime_arn
5759
self._result = None
@@ -284,6 +286,7 @@ def _rate_limited_invoke(self, payload: dict, session_id: str, input_id: str) ->
284286
s3_bucket=data["s3_bucket"],
285287
result_key=data["result_key"],
286288
session_id=session_id,
289+
input_id=input_id,
287290
agentcore_client=self.agentcore_client,
288291
agent_runtime_arn=self.agent_runtime_arn,
289292
)

tests/test_client.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,31 @@ def test_invoke_future_has_session_id(self):
443443
assert future.agentcore_client is mock_acr
444444
assert future.agent_runtime_arn == "arn:aws:bedrock-agentcore:us-west-2:123:agent/test"
445445

446+
def test_invoke_future_has_input_id(self):
447+
"""Test invoke() returns future with input_id set."""
448+
with patch("agentcore_rl_toolkit.client.boto3") as mock_boto3:
449+
mock_acr = MagicMock()
450+
mock_s3 = MagicMock()
451+
mock_boto3.client.side_effect = lambda service, **kwargs: (
452+
mock_acr if service == "bedrock-agentcore" else mock_s3
453+
)
454+
455+
mock_acr.invoke_agent_runtime.return_value = {
456+
"response": mock_streaming_body(
457+
{"status": "processing", "s3_bucket": "test-bucket", "result_key": "exp/key.json"}
458+
)
459+
}
460+
461+
client = RolloutClient(
462+
agent_runtime_arn="arn:aws:bedrock-agentcore:us-west-2:123:agent/test",
463+
s3_bucket="test-bucket",
464+
exp_id="exp-001",
465+
)
466+
467+
future = client.invoke({"prompt": "test"}, input_id="my-input")
468+
469+
assert future.input_id == "my-input"
470+
446471
def test_run_batch_returns_batch_result(self):
447472
"""Test run_batch() returns a BatchResult."""
448473
with patch("agentcore_rl_toolkit.client.boto3"):

0 commit comments

Comments
 (0)