Skip to content

Commit 0de7139

Browse files
committed
chore: update
1 parent acd95a9 commit 0de7139

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

examples/math_env.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2025 Horizon RL Contributors
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
"""Math environment example — run math problems through a Strands agent with a calculator tool.
216
317
Usage:
@@ -15,11 +29,10 @@
1529
import logging
1630

1731
import httpx
18-
from strands_tools import calculator
1932

20-
from strands_env.core.environment import Environment
2133
from strands_env.core.models import ModelFactory, bedrock_model_factory, sglang_model_factory
2234
from strands_env.core.types import Action, TaskContext
35+
from strands_env.environments.simple_math_env import SimpleMathEnv
2336
from strands_env.rewards.math_reward import MathRewardFunction
2437

2538
logger = logging.getLogger(__name__)
@@ -31,18 +44,6 @@
3144
]
3245

3346

34-
# ---------------------------------------------------------------------------
35-
# Environment
36-
# ---------------------------------------------------------------------------
37-
38-
39-
class MathEnvironment(Environment):
40-
"""Environment that gives the agent a calculator tool to solve math problems."""
41-
42-
def get_tools(self) -> list:
43-
return [calculator]
44-
45-
4647
# ---------------------------------------------------------------------------
4748
# Model factory helpers
4849
# ---------------------------------------------------------------------------
@@ -96,7 +97,7 @@ async def main() -> None:
9697
args = parser.parse_args()
9798

9899
model_factory = create_model_factory(args)
99-
env = MathEnvironment(model_factory=model_factory, reward_fn=MathRewardFunction(), verbose=args.verbose)
100+
env = SimpleMathEnv(model_factory=model_factory, reward_fn=MathRewardFunction(), verbose=args.verbose)
100101

101102
for question, ground_truth in MATH_PROBLEMS:
102103
print(f"\n{'=' * 60}")

0 commit comments

Comments
 (0)