Skip to content

Commit b563a61

Browse files
authored
refactor(app): remove abstract method requirement from AgentCoreRLApp (#11)
1 parent 7459411 commit b563a61

2 files changed

Lines changed: 15 additions & 17 deletions

File tree

src/agentcore_rl_toolkit/app.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import json
33
import logging
44
import os
5-
from abc import ABC, abstractmethod
65
from dataclasses import dataclass
76
from datetime import datetime, timezone
87
from functools import wraps
@@ -36,25 +35,31 @@ def from_dict(cls, data: dict) -> "TrainingConfig":
3635
raise ValueError(f"Missing required training config field: {e}") from e
3736

3837

39-
class AgentCoreRLApp(BedrockAgentCoreApp, ABC):
38+
class AgentCoreRLApp(BedrockAgentCoreApp):
4039
def __init__(self):
4140
super().__init__()
4241
self.s3_client = boto3.client("s3")
4342
self.sqs_client = boto3.client("sqs")
4443

45-
@abstractmethod
4644
def create_openai_compatible_model(self, **kwargs):
4745
"""Create an OpenAI-compatible model for this framework.
4846
49-
Must be implemented by framework-specific subclasses.
47+
Optional: Override in framework-specific subclasses, or create model directly
48+
in your entrypoint (see examples/strands_migration_agent/dev_app.py).
5049
5150
Args:
5251
**kwargs: Framework-specific model parameters
5352
5453
Returns:
5554
Framework-specific model instance configured for vLLM server
55+
56+
Raises:
57+
NotImplementedError: If called without override. Create model directly instead.
5658
"""
57-
pass
59+
raise NotImplementedError(
60+
"create_openai_compatible_model() is optional. "
61+
"Either override in a subclass or create your model directly in the entrypoint."
62+
)
5863

5964
def _get_model_config(self):
6065
"""Get and validate model configuration from environment."""

tests/test_rollout_entrypoint.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,9 @@
77
from agentcore_rl_toolkit import AgentCoreRLApp
88

99

10-
class MockAgentCoreRLApp(AgentCoreRLApp):
11-
"""Minimal concrete implementation for testing."""
12-
13-
def create_openai_compatible_model(self, **kwargs):
14-
return None
15-
16-
1710
def test_wrapper_signature_has_context():
1811
"""Test that the wrapper's signature includes (payload, context) for BedrockAgentCoreApp."""
19-
app = MockAgentCoreRLApp()
12+
app = AgentCoreRLApp()
2013

2114
@app.rollout_entrypoint
2215
async def my_handler(payload: dict):
@@ -32,7 +25,7 @@ async def my_handler(payload: dict):
3225

3326
def test_wrapper_preserves_function_name():
3427
"""Test that @wraps preserves the original function name."""
35-
app = MockAgentCoreRLApp()
28+
app = AgentCoreRLApp()
3629

3730
@app.rollout_entrypoint
3831
async def my_custom_handler(payload: dict):
@@ -44,7 +37,7 @@ async def my_custom_handler(payload: dict):
4437

4538
def test_entrypoint_with_payload_only():
4639
"""Test that user function with signature (payload) works."""
47-
app = MockAgentCoreRLApp()
40+
app = AgentCoreRLApp()
4841

4942
@app.rollout_entrypoint
5043
async def handler(payload: dict):
@@ -59,7 +52,7 @@ async def handler(payload: dict):
5952

6053
def test_entrypoint_with_payload_and_context():
6154
"""Test that user function with signature (payload, context) works."""
62-
app = MockAgentCoreRLApp()
55+
app = AgentCoreRLApp()
6356

6457
@app.rollout_entrypoint
6558
async def handler(payload: dict, context):
@@ -78,7 +71,7 @@ async def handler(payload: dict, context):
7871

7972
def test_entrypoint_with_sync_handler():
8073
"""Test that sync user function works."""
81-
app = MockAgentCoreRLApp()
74+
app = AgentCoreRLApp()
8275

8376
@app.rollout_entrypoint
8477
def handler(payload: dict):

0 commit comments

Comments
 (0)