|
| 1 | +"""Tests for the @rollout_entrypoint decorator.""" |
| 2 | + |
| 3 | +import inspect |
| 4 | + |
| 5 | +from starlette.testclient import TestClient |
| 6 | + |
| 7 | +from agentcore_rl_toolkit import AgentCoreRLApp |
| 8 | + |
| 9 | + |
| 10 | +class MockAgentCoreRLApp(AgentCoreRLApp): |
| 11 | + """Minimal concrete implementation for testing.""" |
| 12 | + |
| 13 | + def create_openai_compatible_model(self, **kwargs): |
| 14 | + return None |
| 15 | + |
| 16 | + |
| 17 | +def test_wrapper_signature_has_context(): |
| 18 | + """Test that the wrapper's signature includes (payload, context) for BedrockAgentCoreApp.""" |
| 19 | + app = MockAgentCoreRLApp() |
| 20 | + |
| 21 | + @app.rollout_entrypoint |
| 22 | + async def my_handler(payload: dict): |
| 23 | + return {"rollout_data": [], "rewards": [0]} |
| 24 | + |
| 25 | + wrapper = app.handlers["main"] |
| 26 | + params = list(inspect.signature(wrapper).parameters.keys()) |
| 27 | + |
| 28 | + assert len(params) == 2 |
| 29 | + assert params[0] == "payload" |
| 30 | + assert params[1] == "context" |
| 31 | + |
| 32 | + |
| 33 | +def test_wrapper_preserves_function_name(): |
| 34 | + """Test that @wraps preserves the original function name.""" |
| 35 | + app = MockAgentCoreRLApp() |
| 36 | + |
| 37 | + @app.rollout_entrypoint |
| 38 | + async def my_custom_handler(payload: dict): |
| 39 | + return {"rollout_data": [], "rewards": [0]} |
| 40 | + |
| 41 | + wrapper = app.handlers["main"] |
| 42 | + assert wrapper.__name__ == "my_custom_handler" |
| 43 | + |
| 44 | + |
| 45 | +def test_entrypoint_with_payload_only(): |
| 46 | + """Test that user function with signature (payload) works.""" |
| 47 | + app = MockAgentCoreRLApp() |
| 48 | + |
| 49 | + @app.rollout_entrypoint |
| 50 | + async def handler(payload: dict): |
| 51 | + return {"rollout_data": [{"test": True}], "rewards": [1.0]} |
| 52 | + |
| 53 | + client = TestClient(app) |
| 54 | + response = client.post("/invocations", json={"prompt": "test"}) |
| 55 | + |
| 56 | + assert response.status_code == 200 |
| 57 | + assert response.json() == {"status": "processing"} |
| 58 | + |
| 59 | + |
| 60 | +def test_entrypoint_with_payload_and_context(): |
| 61 | + """Test that user function with signature (payload, context) works.""" |
| 62 | + app = MockAgentCoreRLApp() |
| 63 | + |
| 64 | + @app.rollout_entrypoint |
| 65 | + async def handler(payload: dict, context): |
| 66 | + return {"rollout_data": [{"session": context.session_id}], "rewards": [1.0]} |
| 67 | + |
| 68 | + client = TestClient(app) |
| 69 | + response = client.post( |
| 70 | + "/invocations", |
| 71 | + json={"prompt": "test"}, |
| 72 | + headers={"X-Amz-Bedrock-AgentCore-Session-Id": "session-123"}, |
| 73 | + ) |
| 74 | + |
| 75 | + assert response.status_code == 200 |
| 76 | + assert response.json() == {"status": "processing"} |
| 77 | + |
| 78 | + |
| 79 | +def test_entrypoint_with_sync_handler(): |
| 80 | + """Test that sync user function works.""" |
| 81 | + app = MockAgentCoreRLApp() |
| 82 | + |
| 83 | + @app.rollout_entrypoint |
| 84 | + def handler(payload: dict): |
| 85 | + return {"rollout_data": [{"sync": True}], "rewards": [1.0]} |
| 86 | + |
| 87 | + client = TestClient(app) |
| 88 | + response = client.post("/invocations", json={"prompt": "test"}) |
| 89 | + |
| 90 | + assert response.status_code == 200 |
| 91 | + assert response.json() == {"status": "processing"} |
0 commit comments