Skip to content

Commit 1c5c28e

Browse files
authored
Minor fixes (#1)
1. Add tests for rollout entrypoint decorator 2. Add dependency constraints and resolutions
1 parent 5e98272 commit 1c5c28e

10 files changed

Lines changed: 1585 additions & 28 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ uv venv --python 3.13
6060
source .venv/bin/activate
6161

6262
# Install with development dependencies
63-
uv pip install -e ".[dev]"
63+
uv sync --frozen --extra dev
6464

6565
# Install pre-commit hooks
6666
pre-commit install

examples/strands_math_agent/basic_app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ async def invoke_agent(payload):
3131
print("User input:", user_input)
3232

3333
response = await agent.invoke_async(user_input)
34+
3435
return response.message["content"][0]["text"]
3536

3637

examples/strands_math_agent/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ readme = "README.md"
66
requires-python = ">=3.11"
77
dependencies = [
88
"bedrock-agentcore>=1.0.3",
9-
"bedrock-agentcore-starter-toolkit>=0.1.34",
9+
"bedrock-agentcore-starter-toolkit<=0.2.0",
1010
"boto3>=1.40.55",
1111
"python-dotenv>=1.0.0",
1212
# TODO: replace the above dependencies with agentcore-rl-toolkit>=0.1.0 after PyPI indexing

examples/strands_math_agent/reward.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,25 @@
44

55

66
class GSM8KReward(RewardFunction):
7+
def __call__(
8+
self,
9+
response_text="",
10+
ground_truth="",
11+
method="strict",
12+
format_score=0.0,
13+
score=1.0,
14+
**kwargs,
15+
):
16+
answer = self.extract_solution(solution_str=response_text, method=method)
17+
if answer is None:
18+
reward = 0
19+
else:
20+
if answer == ground_truth:
21+
reward = score
22+
else:
23+
reward = format_score
24+
return reward
25+
726
@staticmethod
827
def extract_solution(solution_str, method="strict"):
928
"""
@@ -40,22 +59,3 @@ def extract_solution(solution_str, method="strict"):
4059
if final_answer not in invalid_str:
4160
break
4261
return final_answer
43-
44-
def __call__(
45-
self,
46-
response_text="",
47-
ground_truth="",
48-
method="strict",
49-
format_score=0.0,
50-
score=1.0,
51-
**kwargs,
52-
):
53-
answer = self.extract_solution(solution_str=response_text, method=method)
54-
if answer is None:
55-
reward = 0
56-
else:
57-
if answer == ground_truth:
58-
reward = score
59-
else:
60-
reward = format_score
61-
return reward

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ readme = "README.md"
1010
requires-python = ">=3.11"
1111
dependencies = [
1212
"bedrock-agentcore>=1.0.3",
13-
"bedrock-agentcore-starter-toolkit>=0.1.34",
13+
"bedrock-agentcore-starter-toolkit<=0.2.0",
1414
"boto3>=1.40.55",
1515
"python-dotenv>=1.0.0",
1616
]

src/agentcore_rl_toolkit/app.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,5 +243,10 @@ async def rollout_entrypoint_wrapper(payload, context):
243243
asyncio.create_task(rollout_background_task(payload, context))
244244
return {"status": "processing"}
245245

246+
# Remove __wrapped__ so inspect.signature() sees the wrapper's actual signature
247+
# (payload, context) instead of the user function's signature. This ensures
248+
# BedrockAgentCoreApp._takes_context() correctly passes context to this wrapper.
249+
del rollout_entrypoint_wrapper.__wrapped__
250+
246251
# Register using existing BedrockAgentCoreApp entrypoint infrastructure
247252
return self.entrypoint(rollout_entrypoint_wrapper)

src/agentcore_rl_toolkit/frameworks/strands/app.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,29 @@
22

33

44
class StrandsAgentCoreRLApp(AgentCoreRLApp):
5-
def create_openai_compatible_model(self, **kwargs):
6-
"""Create Strands OpenAI-compatible model for vLLM/SGLang server."""
5+
def create_openai_compatible_model(self, provider_model_id=None, **kwargs):
6+
"""
7+
Create a Strands model that's compatible with the OpenAI format. When provider_model_id
8+
is provided, a LiteLLM model is used. Otherwise, an OpenAI-compatible model built from
9+
base_url and model_id is used.
10+
11+
:param provider_model_id: Provide this parameter when using cloud providers (bedrock,
12+
anthropic, openai, etc.) that do not use a base_url. Example: a LiteLLM-style ID of the form "<provider>/<model>". Otherwise, leave it as None.
13+
"""
714
try:
815
from strands.models.openai import OpenAIModel
916
except ImportError:
10-
raise ImportError("Strands not installed. Install with: uv pip install strands-agents[openai]") from None
17+
raise ImportError("Strands not installed. Install with: " "uv pip install strands-agents[openai]") from None
18+
19+
if not provider_model_id:
20+
base_url, model_id = self._get_model_config()
21+
return OpenAIModel(client_args={"api_key": "dummy", "base_url": base_url}, model_id=model_id, **kwargs)
1122

12-
base_url, model_id = self._get_model_config()
23+
try:
24+
from strands.models.litellm import LiteLLMModel
25+
except ImportError:
26+
raise ImportError(
27+
"Strands not installed. Install with: " "uv pip install strands-agents[litellm]"
28+
) from None
1329

14-
return OpenAIModel(client_args={"api_key": "dummy", "base_url": base_url}, model_id=model_id, **kwargs)
30+
return LiteLLMModel(model_id=provider_model_id, **kwargs)

src/agentcore_rl_toolkit/frameworks/strands/rollout_collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def register_hooks(self, registry):
1313
from strands.experimental.hooks import BeforeModelInvocationEvent
1414
from strands.hooks import AfterInvocationEvent
1515
except ImportError:
16-
raise ImportError("Strands not installed. Install with: uv pip install strands-agents[openai]") from None
16+
raise ImportError("Strands not installed. Install with: " "uv pip install strands-agents[openai]") from None
1717

1818
registry.add_callback(BeforeModelInvocationEvent, self.collect_messages)
1919
registry.add_callback(AfterInvocationEvent, self.prepare_rollout)

tests/test_rollout_entrypoint.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Tests for the @rollout_entrypoint decorator."""
2+
3+
import inspect
4+
5+
from starlette.testclient import TestClient
6+
7+
from agentcore_rl_toolkit import AgentCoreRLApp
8+
9+
10+
class MockAgentCoreRLApp(AgentCoreRLApp):
11+
"""Minimal concrete implementation for testing."""
12+
13+
def create_openai_compatible_model(self, **kwargs):
14+
return None
15+
16+
17+
def test_wrapper_signature_has_context():
18+
"""Test that the wrapper's signature includes (payload, context) for BedrockAgentCoreApp."""
19+
app = MockAgentCoreRLApp()
20+
21+
@app.rollout_entrypoint
22+
async def my_handler(payload: dict):
23+
return {"rollout_data": [], "rewards": [0]}
24+
25+
wrapper = app.handlers["main"]
26+
params = list(inspect.signature(wrapper).parameters.keys())
27+
28+
assert len(params) == 2
29+
assert params[0] == "payload"
30+
assert params[1] == "context"
31+
32+
33+
def test_wrapper_preserves_function_name():
34+
"""Test that @wraps preserves the original function name."""
35+
app = MockAgentCoreRLApp()
36+
37+
@app.rollout_entrypoint
38+
async def my_custom_handler(payload: dict):
39+
return {"rollout_data": [], "rewards": [0]}
40+
41+
wrapper = app.handlers["main"]
42+
assert wrapper.__name__ == "my_custom_handler"
43+
44+
45+
def test_entrypoint_with_payload_only():
46+
"""Test that user function with signature (payload) works."""
47+
app = MockAgentCoreRLApp()
48+
49+
@app.rollout_entrypoint
50+
async def handler(payload: dict):
51+
return {"rollout_data": [{"test": True}], "rewards": [1.0]}
52+
53+
client = TestClient(app)
54+
response = client.post("/invocations", json={"prompt": "test"})
55+
56+
assert response.status_code == 200
57+
assert response.json() == {"status": "processing"}
58+
59+
60+
def test_entrypoint_with_payload_and_context():
61+
"""Test that user function with signature (payload, context) works."""
62+
app = MockAgentCoreRLApp()
63+
64+
@app.rollout_entrypoint
65+
async def handler(payload: dict, context):
66+
return {"rollout_data": [{"session": context.session_id}], "rewards": [1.0]}
67+
68+
client = TestClient(app)
69+
response = client.post(
70+
"/invocations",
71+
json={"prompt": "test"},
72+
headers={"X-Amz-Bedrock-AgentCore-Session-Id": "session-123"},
73+
)
74+
75+
assert response.status_code == 200
76+
assert response.json() == {"status": "processing"}
77+
78+
79+
def test_entrypoint_with_sync_handler():
80+
"""Test that sync user function works."""
81+
app = MockAgentCoreRLApp()
82+
83+
@app.rollout_entrypoint
84+
def handler(payload: dict):
85+
return {"rollout_data": [{"sync": True}], "rewards": [1.0]}
86+
87+
client = TestClient(app)
88+
response = client.post("/invocations", json={"prompt": "test"})
89+
90+
assert response.status_code == 200
91+
assert response.json() == {"status": "processing"}

0 commit comments

Comments
 (0)