|
| 1 | +# Copyright 2025 Horizon RL Contributors |
| 2 | + |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | + |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | + |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
1 | 15 | """Math environment example — run math problems through a Strands agent with a calculator tool. |
2 | 16 |
|
3 | 17 | Usage: |
|
15 | 29 | import logging |
16 | 30 |
|
17 | 31 | import httpx |
18 | | -from strands_tools import calculator |
19 | 32 |
|
20 | | -from strands_env.core.environment import Environment |
21 | 33 | from strands_env.core.models import ModelFactory, bedrock_model_factory, sglang_model_factory |
22 | 34 | from strands_env.core.types import Action, TaskContext |
| 35 | +from strands_env.environments.simple_math_env import SimpleMathEnv |
23 | 36 | from strands_env.rewards.math_reward import MathRewardFunction |
24 | 37 |
|
25 | 38 | logger = logging.getLogger(__name__) |
|
31 | 44 | ] |
32 | 45 |
|
33 | 46 |
|
34 | | -# --------------------------------------------------------------------------- |
35 | | -# Environment |
36 | | -# --------------------------------------------------------------------------- |
37 | | - |
38 | | - |
39 | | -class MathEnvironment(Environment): |
40 | | - """Environment that gives the agent a calculator tool to solve math problems.""" |
41 | | - |
42 | | - def get_tools(self) -> list: |
43 | | - return [calculator] |
44 | | - |
45 | | - |
46 | 47 | # --------------------------------------------------------------------------- |
47 | 48 | # Model factory helpers |
48 | 49 | # --------------------------------------------------------------------------- |
@@ -96,7 +97,7 @@ async def main() -> None: |
96 | 97 | args = parser.parse_args() |
97 | 98 |
|
98 | 99 | model_factory = create_model_factory(args) |
99 | | - env = MathEnvironment(model_factory=model_factory, reward_fn=MathRewardFunction(), verbose=args.verbose) |
| 100 | + env = SimpleMathEnv(model_factory=model_factory, reward_fn=MathRewardFunction(), verbose=args.verbose) |
100 | 101 |
|
101 | 102 | for question, ground_truth in MATH_PROBLEMS: |
102 | 103 | print(f"\n{'=' * 60}") |
|
0 commit comments