Commit 5334752

Lawhy and claude committed
feat(eval): add evaluator for running agentic benchmarks
Add Evaluator class with async batching, pass@k computation, and checkpoint/resume support via JSONL output. Returns dict[str, list[EvalSample]] mapping problem_id to rollout samples.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent dabee47 commit 5334752

File tree: 3 files changed, +534 -0 lines changed


src/strands_env/eval/__init__.py

Lines changed: 20 additions & 0 deletions
```python
# Copyright 2025 Horizon RL Contributors

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .evaluator import EvalSample, Evaluator

__all__ = [
    "Evaluator",
    "EvalSample",
]
```
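Since `__init__.py` re-exports both names, downstream code can import them from the eval package directly; a minimal sketch:

```python
# The eval package re-exports its public API, so callers do not need to
# reach into strands_env.eval.evaluator directly.
from strands_env.eval import EvalSample, Evaluator
```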

src/strands_env/eval/evaluator.py

Lines changed: 227 additions & 0 deletions
```python
# Copyright 2025 Horizon RL Contributors

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluator for running agentic benchmarks with `strands-env` environments."""

from __future__ import annotations

import asyncio
import json
import logging
import math
from collections import defaultdict
from collections.abc import Awaitable, Callable, Iterable
from pathlib import Path

from pydantic import BaseModel

from strands_env.core import Action, Environment, StepResult

logger = logging.getLogger(__name__)

#: Type alias for environment factory function (async).
AsyncEnvFactory = Callable[[Action], Awaitable[Environment]]


class EvalSample(BaseModel):
    """Evaluation sample result."""

    action: Action
    """The action (task) that was evaluated."""

    step_result: StepResult
    """The result of the step (observation, reward, termination reason)."""


class Evaluator:
    """Evaluator for running concurrent environment evaluations."""

    def __init__(
        self,
        env_factory: AsyncEnvFactory,
        *,
        max_concurrency: int = 10,
        n_rollouts: int = 1,
        output_path: Path | str = Path.cwd() / "results.jsonl",
        save_interval: int = 10,
        keep_tokens: bool = False,
    ):
        """Initialize the evaluator.

        Args:
            env_factory: Async factory function that creates a fresh Environment per sample.
            max_concurrency: Maximum concurrent evaluate_sample() calls.
            n_rollouts: Number of rollouts per problem (for pass@k, set to max(k_values)).
            output_path: Path to JSONL file for saving results. Enables resume.
            save_interval: Flush results to disk every N completed samples.
            keep_tokens: Keep token-level observation in results (only valid for `SGLangModel` backends).
        """
        self.env_factory: AsyncEnvFactory = env_factory

        # Configuration
        self.max_concurrency = max_concurrency
        self.n_rollouts = n_rollouts
        self.output_path = Path(output_path)
        self.save_interval = save_interval
        self.keep_tokens = keep_tokens

        # Runtime state: {problem_id: [samples]}
        self.results: dict[str, list[EvalSample]] = defaultdict(list)
        self.completed_ids: set[str] = set()  # Tracks individual sample IDs for checkpoint

    def load_dataset(self, dataset_path: Path | str) -> Iterable[Action]:
        """Load dataset from file. Override to implement custom dataset loading logic."""
        logger.info(f"Loading dataset from: {dataset_path}")
        raise NotImplementedError("Evaluator subclasses must implement load_dataset()")

    def load_results(self) -> None:
        """Load completed samples from results file."""
        if not self.output_path.exists():
            return

        self.results = defaultdict(list)
        self.completed_ids = set()

        with open(self.output_path) as f:
            for line in f:
                data = json.loads(line)
                problem_id = data.pop("problem_id")
                sample = EvalSample.model_validate(data)
                self.results[problem_id].append(sample)
                self.completed_ids.add(sample.action.task_context.id)

        total = sum(len(samples) for samples in self.results.values())
        logger.info(f"Loaded {total} completed samples from: {self.output_path}")

    def save_results(self) -> None:
        """Write all samples to results file."""
        with open(self.output_path, "w") as f:
            for problem_id, samples in self.results.items():
                for sample in samples:
                    data = sample.model_dump()
                    data["problem_id"] = problem_id
                    f.write(json.dumps(data) + "\n")

        total = sum(len(samples) for samples in self.results.values())
        logger.info(f"Saved {total} samples to: {self.output_path}")

    async def evaluate_sample(self, action: Action) -> EvalSample:
        """Evaluate a single sample."""
        env = await self.env_factory(action)
        await env.reset()
        step_result = await env.step(action)
        if not self.keep_tokens:
            # Token trajectory is usually not needed for evaluation.
            step_result.observation.tokens = None
        await env.cleanup()
        return EvalSample(action=action, step_result=step_result)

    async def run(self, actions: Iterable[Action]) -> dict[str, list[EvalSample]]:
        """Run evaluation on a collection of actions.

        Each action is duplicated `n_rollouts` times for pass@k computation.
        Completed samples are saved incrementally and can be resumed via output_path.

        Args:
            actions: `Iterable` of `Action`s to evaluate.

        Returns:
            Dict mapping problem_id to list of `EvalSample` rollouts.
        """
        self.load_results()

        # Build list of (problem_id, sample_id, action) for processing
        to_process: list[tuple[str, str, Action]] = []
        for action in actions:
            problem_id = action.task_context.id
            for i in range(self.n_rollouts):
                sample_id = f"{problem_id}_{i}"
                if sample_id not in self.completed_ids:
                    expanded = action.model_copy(deep=True)
                    expanded.task_context.id = sample_id
                    to_process.append((problem_id, sample_id, expanded))

        semaphore = asyncio.Semaphore(self.max_concurrency)
        save_counter = 0

        async def process(problem_id: str, sample_id: str, action: Action) -> None:
            nonlocal save_counter
            async with semaphore:
                sample = await self.evaluate_sample(action)
                self.results[problem_id].append(sample)
                self.completed_ids.add(sample_id)
                save_counter += 1
                if save_counter >= self.save_interval:
                    self.save_results()
                    save_counter = 0

        tasks = [process(pid, sid, action) for pid, sid, action in to_process]
        await asyncio.gather(*tasks)

        self.save_results()
        return dict(self.results)

    @staticmethod
    def _pass_at_k_single(n: int, c: int, k: int) -> float:
        """Compute pass@k for a single problem using unbiased estimator.

        pass@k = 1 - C(n-c, k) / C(n, k)

        Uses log-space for numerical stability with large factorials.
        """
        if n - c < k:
            return 1.0
        if c == 0:
            return 0.0

        log_ratio = 0.0
        for i in range(k):
            log_ratio += math.log(n - c - i) - math.log(n - i)
        return 1.0 - math.exp(log_ratio)

    @staticmethod
    def compute_pass_at_k(
        results: dict[str, list[EvalSample]],
        k_values: list[int] = [1],
        reward_threshold: float = 1.0,
    ) -> dict[int, float]:
        """Compute pass@k metric using unbiased estimator.

        Args:
            results: Dict mapping problem_id to list of sample rollouts.
            k_values: List of k values for pass@k computation.
            reward_threshold: Reward threshold for considering a sample "passed" (default: 1.0).

        Returns:
            Dictionary mapping k to average pass@k score.
        """
        if not results:
            return {k: 0.0 for k in k_values}

        def is_correct(s: EvalSample) -> bool:
            reward = s.step_result.reward
            return reward is not None and reward.reward >= reward_threshold

        # Compute pass@k for each k value
        pass_at_k = {}
        for k in k_values:
            scores = []
            for samples in results.values():
                n = len(samples)
                c = sum(1 for s in samples if is_correct(s))
                if k <= n:
                    scores.append(Evaluator._pass_at_k_single(n, c, k))
            pass_at_k[k] = sum(scores) / len(scores) if scores else 0.0

        return pass_at_k
```
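
The log-space loop in `_pass_at_k_single` implements the standard unbiased estimator pass@k = 1 - C(n-c, k) / C(n, k). As a quick sanity check (this snippet is standalone and illustrative, not part of the commit), the loop agrees with the closed form computed via `math.comb`:

```python
import math


def pass_at_k(n: int, c: int, k: int) -> float:
    # Mirrors the log-space recurrence in Evaluator._pass_at_k_single.
    if n - c < k:
        return 1.0
    if c == 0:
        return 0.0
    log_ratio = sum(math.log(n - c - i) - math.log(n - i) for i in range(k))
    return 1.0 - math.exp(log_ratio)


# With n=5 rollouts, c=2 correct, k=2:
# pass@2 = 1 - C(3, 2) / C(5, 2) = 1 - 3/10 = 0.7
closed_form = 1.0 - math.comb(5 - 2, 2) / math.comb(5, 2)
assert math.isclose(pass_at_k(5, 2, 2), closed_form)  # both evaluate to 0.7
```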

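Putting the pieces together, here is a minimal usage sketch based on the constructor and the `run()` / `compute_pass_at_k()` signatures above. `make_env` and `load_my_actions` are hypothetical placeholders; how an `Action` and a concrete `Environment` are constructed depends on `strands_env.core` APIs not shown in this commit.

```python
# Hypothetical usage sketch: make_env and load_my_actions are placeholders,
# not part of this commit; fill them in with your own environment/dataset code.
import asyncio

from strands_env.core import Action, Environment
from strands_env.eval import Evaluator


async def make_env(action: Action) -> Environment:
    """Assumption: build a fresh concrete Environment for one sample."""
    raise NotImplementedError


def load_my_actions() -> list[Action]:
    """Assumption: actions are parsed from a benchmark file elsewhere."""
    raise NotImplementedError


async def main() -> None:
    evaluator = Evaluator(
        env_factory=make_env,
        max_concurrency=8,
        n_rollouts=4,                 # 4 rollouts per problem so pass@4 is meaningful
        output_path="results.jsonl",  # completed samples are checkpointed here
    )
    # Re-running with the same output_path skips sample IDs already in results.jsonl.
    results = await evaluator.run(load_my_actions())
    print(Evaluator.compute_pass_at_k(results, k_values=[1, 4]))


if __name__ == "__main__":
    asyncio.run(main())
```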