Skip to content

Commit 639d55a

Browse files
committed
feat(eval): add pbar for eval progress tracking
1 parent 3ea4c95 commit 639d55a

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies = [
2525
"strands-agents-tools",
2626
"math-verify>=0.8.0",
2727
"click>=8.0.0",
28+
"tqdm>=4.0.0",
2829
]
2930

3031
[project.optional-dependencies]

src/strands_env/eval/evaluator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from pathlib import Path
2626

2727
from pydantic import BaseModel
28+
from tqdm import tqdm
29+
from tqdm.contrib.logging import logging_redirect_tqdm
2830

2931
from strands_env.core import Action, Environment, StepResult
3032

@@ -159,25 +161,23 @@ async def run(self, actions: Iterable[Action]) -> dict[str, list[EvalSample]]:
159161

160162
semaphore = asyncio.Semaphore(self.max_concurrency)
161163
save_counter = 0
162-
completed = 0
163164
total = len(to_process)
164165

165-
async def process(prompt_id: str, sample_id: str, action: Action) -> None:
166-
nonlocal save_counter, completed
166+
async def process(prompt_id: str, sample_id: str, action: Action, pbar: tqdm) -> None:
167+
nonlocal save_counter
167168
async with semaphore:
168169
sample = await self.evaluate_sample(action)
169170
self.results[prompt_id].append(sample)
170171
self.completed_ids.add(sample_id)
171-
completed += 1
172+
pbar.update(1)
172173
save_counter += 1
173174
if save_counter >= self.save_interval:
174175
self.save_results()
175-
logger.info(f"Progress: {completed}/{total}")
176176
save_counter = 0
177177

178-
await asyncio.gather(*[process(pid, sid, a) for pid, sid, a in to_process])
179-
180-
logger.info(f"Completed: {completed}/{total}")
178+
with logging_redirect_tqdm():
179+
with tqdm(total=total, desc=f"Evaluating {self.benchmark_name}", unit="sample", dynamic_ncols=True) as pbar:
180+
await asyncio.gather(*[process(pid, sid, a, pbar) for pid, sid, a in to_process])
181181
self.save_results()
182182
return dict(self.results)
183183

0 commit comments

Comments
 (0)