-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathrl_app.py
More file actions
70 lines (51 loc) · 1.93 KB
/
rl_app.py
File metadata and controls
70 lines (51 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import logging
from reward import GSM8KReward
from strands import Agent
from strands.models.openai import OpenAIModel
from strands_tools import calculator
from agentcore_rl_toolkit import AgentCoreRLApp
# Root logging is configured once, at import time, at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The app object supplies the ``rollout_entrypoint`` decorator that wraps
# ``invoke_agent`` below.
app = AgentCoreRLApp()

# Instructions given to every agent instance. The "####" marker is presumably
# what GSM8KReward looks for when extracting the final answer -- confirm
# against the reward implementation before changing it.
system_prompt = (
    "Your task is to solve the math problem. "
    "Use the calculator tool to compute all mathematical expressions. "
    "Let's think step by step and output the final answer after \"####\"."
)

# Single reward-function instance, built at import and reused by every rollout.
reward_fn = GSM8KReward()
def _extract_response_text(message: dict) -> str:
    """Return the text of all ``text`` content blocks in *message*, newline-joined.

    Content blocks without a ``text`` key (e.g. tool-use blocks) are skipped,
    so the result does not depend on the first block being text. A message
    with a single text block yields exactly that block's text; a message with
    no text blocks (or no ``content``) yields an empty string.
    """
    texts = [block["text"] for block in message.get("content", []) if "text" in block]
    return "\n".join(texts)


@app.rollout_entrypoint
def invoke_agent(payload: dict):
    """
    Invoke the math agent with a payload using the rollout_entrypoint decorator.

    For RL training, the following fields are expected:
    - prompt: question from gsm8k (required, non-empty)
    - answer: ground truth (str)
    - _rollout: rollout config with base_url and model_id, and an optional
      sampling_params dict forwarded to the model

    The @rollout_entrypoint decorator automatically:
    - Executes the function in the background for non-blocking processing
    - Saves results to S3 with a predictable key
    - Handles errors and saves error results for client awareness
    - Works with both sync and async functions

    Returns:
        dict: ``{"rewards": ...}`` holding whatever ``reward_fn`` returns for
        the agent's final response text.

    Raises:
        KeyError: if ``_rollout``, ``_rollout["base_url"]`` or
            ``_rollout["model_id"]`` is missing.
        ValueError: if ``prompt`` is missing or empty.
    """
    rollout_cfg = payload["_rollout"]
    base_url = rollout_cfg["base_url"]
    model_id = rollout_cfg["model_id"]
    # Optional per-rollout sampling parameters passed through to the model.
    params = rollout_cfg.get("sampling_params", {})

    user_input = payload.get("prompt")
    if not user_input:
        # Fail fast with a clear message instead of invoking the agent with no
        # prompt; the decorator records the error for the client.
        raise ValueError("payload is missing a non-empty 'prompt' field")
    answer = payload.get("answer")  # used for computing reward

    model = OpenAIModel(
        # "EMPTY" is a placeholder key; presumably the rollout endpoint does
        # not authenticate -- confirm if a real key is ever required.
        client_args={"api_key": "EMPTY", "base_url": base_url},
        model_id=model_id,
        params=params,
    )
    # A fresh agent per call: the model differs per payload, and no agent
    # object is shared across rollouts.
    agent = Agent(
        model=model,
        tools=[calculator],
        system_prompt=system_prompt,
    )

    logger.info("User input: %s", user_input)
    response = agent(user_input)

    # Compute rewards on the text of the final message, tolerating non-text
    # content blocks rather than assuming content[0] is text.
    response_text = _extract_response_text(response.message)
    if not response_text:
        logger.warning("Agent response contained no text content; scoring an empty string")
    rewards = reward_fn(response_text=response_text, ground_truth=answer)
    return {"rewards": rewards}
if __name__ == '__main__':
    # Run the AgentCore RL app only when executed as a script, not on import.
    app.run()