
Commit 22cb416

chore: update
1 parent 93a1d2f commit 22cb416

File tree: 1 file changed (+114, -21 lines)


README.md

Lines changed: 114 additions & 21 deletions
@@ -4,55 +4,148 @@
[![PyPI](https://img.shields.io/pypi/v/strands-env.svg)](https://pypi.org/project/strands-env/)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE)

-Standardizing environment infrastructure with [Strands Agents](https://github.com/strands-agents/sdk-python) — step, observe, reward.
+RL environment abstraction for [Strands Agents](https://github.com/strands-agents/sdk-python) — step, observe, reward.

-> `strands-agents` is designed for serving, not training. `strands-env` integrates [`strands-sglang`](https://github.com/horizon-rl/strands-sglang) to bridge this gap.
+## Features

-## Define an environment
+This package standardizes agent environments by treating each `env.step()` as a full agent loop (`prompt → (tool_call, tool_response)* → response`), not a single model call. Built on the [strands](https://github.com/strands-agents/sdk-python) agent loop, with [`strands-sglang`](https://github.com/horizon-rl/strands-sglang) for RL training.

-Subclass `Environment` and customize your tools:
+- **Define environments easily** — subclass `Environment` and implement tools as `@tool` functions
+- **Capture token-level observations** — TITO data for on-policy RL training (SGLang backend)
+- **Plug in reward functions** — evaluate agent outputs with custom `RewardFunction`
+- **Run benchmarks** — `Evaluator` with pass@k metrics, checkpointing, and resume
+
+## Install
+
+```bash
+pip install strands-env
+```
+
+For development:
+
+```bash
+git clone https://github.com/horizon-rl/strands-env.git && cd strands-env
+pip install -e ".[dev]"
+```
+
+## Usage
+
+### Define an Environment
+
+Subclass `Environment` and add tools as `@tool`-decorated functions:

```python
-from strands_tools import calculator
-from strands_env.core.environment import Environment
+from strands import tool
+from strands_env.core import Environment
+
+@tool
+def calculator(expression: str) -> str:
+    """Evaluate a math expression."""
+    return str(eval(expression))

class MathEnv(Environment):
    def get_tools(self):
        return [calculator]
```

-## Run it
+### Run It

```python
env = MathEnv(model_factory=factory, reward_fn=reward_fn)
result = await env.step(Action(message="What is 2^10?", task_context=TaskContext(ground_truth="1024")))

result.observation.final_response  # "1024"
-result.observation.tokens  # TokenObservation (SGLang only, for on-policy RL training)
+result.observation.tokens  # TokenObservation (SGLang only)
result.reward.reward  # 1.0
-result.termination_reason  # task_complete
+result.termination_reason  # TerminationReason.TASK_COMPLETE
```
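+
+The `reward_fn` above can be any custom `RewardFunction`. Below is a minimal exact-match sketch, not the package's documented interface: the import path and the `RewardResult` name are assumptions, while the async `compute(action, step_result)` signature and `.reward` field mirror the RL snippet later in this README:
+
+```python
+# Illustrative sketch; the import path and RewardResult name are assumed.
+from strands_env.core import RewardFunction, RewardResult
+
+class ExactMatchReward(RewardFunction):
+    async def compute(self, action, step_result):
+        # Compare the agent's final response to the task's ground truth.
+        answer = step_result.observation.final_response.strip()
+        correct = answer == action.task_context.ground_truth
+        return RewardResult(reward=1.0 if correct else 0.0)
+```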

-Each `step()` runs a full agent loop (reasoning + tool calls), not a single model call. Strands' hook-based design makes it easy to customize what happens within each step.
-
-## Install
+See [`examples/math_env.py`](examples/math_env.py) for a complete example:

```bash
-pip install strands-env
+python examples/math_env.py --backend sglang --sglang-base-url http://localhost:30000
```

-For development:
+## RL Training

-```bash
-git clone https://github.com/horizon-rl/strands-env.git && cd strands-env
-pip install -e ".[dev]"
+For RL training with [slime](https://github.com/THUDM/slime/), customize the `generate` and `reward_func` functions to replace single-shot generation with an agentic rollout:
+
+```python
+from strands_env.core import Action, TaskContext
+from strands_env.core.models import sglang_model_factory
+from strands_env.utils import get_cached_client_from_slime_args
+
+async def generate(args, sample, sampling_params):
+    # Build model factory with cached client
+    factory = sglang_model_factory(
+        model_id=args.hf_checkpoint,
+        tokenizer=tokenizer,
+        client=get_cached_client_from_slime_args(args),
+        sampling_params=sampling_params,
+    )
+
+    # Create environment and run step
+    env = YourEnv(model_factory=factory, reward_fn=None)
+    action = Action(message=sample.prompt, task_context=TaskContext(ground_truth=sample.label))
+    step_result = await env.step(action)
+
+    # Extract TITO data for training
+    token_obs = step_result.observation.tokens
+    sample.tokens = token_obs.token_ids
+    sample.loss_mask = token_obs.rollout_loss_mask
+    sample.rollout_log_probs = token_obs.rollout_logprobs
+    sample.response_length = len(token_obs.rollout_token_ids)
+
+    # Attach for reward computation
+    sample.action = action
+    sample.step_result = step_result
+    return sample
+
+async def reward_func(args, sample, **kwargs):
+    reward_fn = YourRewardFunction()
+    reward_result = await reward_fn.compute(action=sample.action, step_result=sample.step_result)
+    return reward_result.reward
```

-See [`examples/math_env.py`](examples/math_env.py) for a complete runnable example:
+Key points:
+- `get_cached_client_from_slime_args(args)` provides connection pooling across rollouts
+- `TokenObservation` carries the TITO (token-in/token-out) IDs, loss mask, and logprobs needed for on-policy training; see the sketch after this list
+- Reward is computed separately to allow async/batched reward computation
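+
+The loss mask is what keeps training on-policy. A toy sketch of how the TITO fields are assumed to line up (the values are made up, and the 0-means-tool-output convention is an assumption, not a documented guarantee):
+
+```python
+# Toy illustration; real fields come from step_result.observation.tokens.
+rollout_token_ids = [101, 7, 42, 9]    # tokens emitted during the rollout
+rollout_loss_mask = [1, 1, 0, 1]       # assumed: 0 marks tokens the policy did not sample (e.g., tool output)
+rollout_logprobs = [-0.1, -2.3, 0.0, -0.5]
+
+# Train only on tokens the policy actually sampled.
+on_policy = [
+    (tok, lp)
+    for tok, mask, lp in zip(rollout_token_ids, rollout_loss_mask, rollout_logprobs)
+    if mask == 1
+]
+```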
+
+## Evaluation
+
+The `Evaluator` orchestrates concurrent rollouts with checkpointing and pass@k metrics. It takes an async `env_factory` so each sample can get a freshly configured environment, and subclasses implement `load_dataset` for different benchmarks:
+
+```python
+...
+from strands_env.eval import Evaluator
+
+class YourEvaluator(Evaluator):
+    benchmark_name = "YourBenchmark"
+
+    def load_dataset(self) -> Iterable[Action]:
+        ...
+
+async def env_factory(action: Action) -> Environment:
+    ...
+
+evaluator = YourEvaluator(
+    env_factory=env_factory,
+    n_samples_per_prompt=8,
+    max_concurrency=30,
+    keep_tokens=False,  # set True to keep token-level trajectories (SGLang only)
+    metrics_fns=[...],  # additional metrics; pass@k is included by default
+)
+
+actions = evaluator.load_dataset()
+results = await evaluator.run(actions)
+metrics = evaluator.compute_metrics(results)  # {"pass@1": 0.75, "pass@8": 0.95}
+```
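+
+For illustration only, here is one way the placeholders might be filled in for the `MathEnv` defined above. The dataset, class name, and per-sample factory are invented; the imports mirror the earlier snippets, and `factory`/`reward_fn` are assumed to be defined as in the Run It section:
+
+```python
+# Hypothetical dataset: (question, answer) pairs turned into Actions.
+from typing import Iterable
+from strands_env.core import Action, Environment, TaskContext
+from strands_env.eval import Evaluator
+
+DATASET = [("What is 2^10?", "1024"), ("What is 3*7?", "21")]
+
+class MathEvaluator(Evaluator):
+    benchmark_name = "ToyMath"
+
+    def load_dataset(self) -> Iterable[Action]:
+        for question, answer in DATASET:
+            yield Action(message=question, task_context=TaskContext(ground_truth=answer))
+
+async def env_factory(action: Action) -> Environment:
+    # A fresh environment per sample keeps rollouts independent.
+    return MathEnv(model_factory=factory, reward_fn=reward_fn)
+```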
+
+See [`examples/aime_eval.py`](examples/aime_eval.py) for a complete example:

```bash
-python examples/math_env.py --backend sglang --sglang-base-url http://localhost:30000
-python examples/math_env.py --backend bedrock --model-id us.anthropic.claude-sonnet-4-20250514
+python examples/aime_eval.py --backend sglang --sglang-base-url http://localhost:30000
```

## Development
@@ -70,4 +163,4 @@ pytest tests/integration/ -v --sglang-base-url=http://localhost:30000

## License

-Apache License 2.0 - see [LICENSE](LICENSE).
+Apache License 2.0 — see [LICENSE](LICENSE).
