|
| 1 | +# CLAUDE.md |
| 2 | + |
| 3 | +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. |
| 4 | + |
| 5 | +## Project Overview |
| 6 | + |
| 7 | +Strands-env is an RL environment abstraction for Strands agents — step, observe, reward. It provides a base `Environment` class that wraps a Strands `Agent` with token-level observation tracking (TITO), reward computation, and termination handling. Supports SGLang, Bedrock, and OpenAI model backends. |
| 8 | + |
| 9 | +## Commands |
| 10 | + |
| 11 | +### Setup |
| 12 | +```bash |
| 13 | +pip install -e ".[dev]" |
| 14 | +``` |
| 15 | + |
| 16 | +### Linting |
| 17 | +```bash |
| 18 | +ruff check src/ |
| 19 | +ruff format --check src/ |
| 20 | +``` |
| 21 | + |
| 22 | +### Testing |
| 23 | +```bash |
| 24 | +# Unit tests (no server needed) |
| 25 | +pytest tests/unit/ -v |
| 26 | + |
| 27 | +# Single test |
| 28 | +pytest tests/unit/test_environment.py::TestStep::test_successful_step -v |
| 29 | + |
| 30 | +# Unit tests with coverage |
| 31 | +pytest tests/unit/ -v --cov=src/strands_env --cov-report=html |
| 32 | + |
| 33 | +# Integration tests (requires running SGLang server; model ID auto-detected via /get_model_info) |
| 34 | +# Tests skip automatically if server is unreachable (/health check) |
| 35 | +pytest tests/integration/ -v --sglang-base-url=http://localhost:30000 |
| 36 | +# Or via env var: SGLANG_BASE_URL=http://localhost:30000 pytest tests/integration/ |
| 37 | +``` |
| 38 | + |
| 39 | +### Integration Tests with Remote GPU Server |
| 40 | + |
| 41 | +```bash |
| 42 | +# 1. Launch SGLang on the remote server in docker |
| 43 | +ssh <remote-host> "sudo docker run -d --gpus '\"device=0\"' --name sglang-test -p 30000:30000 --ipc=host lmsysorg/sglang:<tag> python3 -m sglang.launch_server --model-path <model-id> --host 0.0.0.0 --port 30000 --tp <num_gpus> --mem-fraction-static 0.7" |
| 44 | +# 2. Tunnel the port locally |
| 45 | +ssh -L 30000:localhost:30000 -N -f <remote-host> |
| 46 | +# 3. Run tests locally |
| 47 | +pytest tests/integration/ -v |
| 48 | +``` |
| 49 | + |
| 50 | +## Architecture |
| 51 | + |
| 52 | +The package lives in `src/strands_env/core/` with three modules: |
| 53 | + |
| 54 | +**types.py** — All data types. `Action` carries a user message + `TaskContext` (ground truth, conversation history, arbitrary metadata via `extra="allow"`). `Observation` holds messages, metrics, and optional `TokenObservation` for TITO training. `TerminationReason` maps agent exceptions to enum values via `from_error()` which walks exception cause chains. `StepResult` bundles observation + reward + termination reason. |
| 55 | + |
| 56 | +**models.py** — `ModelFactory = Callable[[], Model]` type and three factory functions (`sglang_model_factory`, `bedrock_model_factory`, `openai_model_factory`). Each returns a zero-arg lambda that creates a fresh Model instance per `step()` call for concurrent isolation. Bedrock and OpenAI remap `max_new_tokens` → `max_tokens` with a shallow dict copy to avoid mutating defaults. |
| 57 | + |
| 58 | +**environment.py** — Base `Environment` class. `step(action)` creates a fresh model via factory, attaches a `TokenManager`, builds an `Agent` with tools/hooks (always includes `ToolIterationLimiter`), runs `invoke_async`, then collects metrics and optional reward. Subclasses override `get_tools()` and `get_hooks()` to customize. Messages are sliced so only new messages from the current step appear in the observation. |
| 59 | + |
| 60 | +### Key Design Decisions |
| 61 | + |
| 62 | +- **Factory pattern**: `ModelFactory` returns lambdas (not Model instances) so each `step()` gets a fresh model with clean token tracking state. |
| 63 | +- **TITO token tracking**: `TokenManager` on SGLang models captures exact token IDs and logprobs during generation. `TokenObservation.from_token_manager()` extracts prompt/rollout split. Non-SGLang models get an empty `TokenManager` (returns `None` from `from_token_manager`). |
| 64 | +- **`list()` copies**: Tools, hooks, and messages are copied via `list()` before passing to Agent to prevent cross-step mutation. |
| 65 | +- **ToolIterationLimiter**: Always prepended to hooks list. Raises `MaxToolIterationsReachedError` which `TerminationReason.from_error()` maps to `MAX_TOOL_ITERATIONS_REACHED`. |
| 66 | + |
| 67 | +## Code Style |
| 68 | + |
| 69 | +- Ruff for linting and formatting (line-length 120, rules: E, F, I, N, W) |
| 70 | +- Conventional commits (feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert) |
| 71 | +- Python 3.10+ required |
| 72 | +- asyncio_mode = "auto" for pytest-asyncio |
| 73 | +- Async-first: all Environment methods that interact with Agent are async |
0 commit comments