|
25 | 25 | from pathlib import Path |
26 | 26 |
|
27 | 27 | from pydantic import BaseModel |
| 28 | +from tqdm import tqdm |
| 29 | +from tqdm.contrib.logging import logging_redirect_tqdm |
28 | 30 |
|
29 | 31 | from strands_env.core import Action, Environment, StepResult |
30 | 32 |
|
@@ -159,25 +161,23 @@ async def run(self, actions: Iterable[Action]) -> dict[str, list[EvalSample]]: |
159 | 161 |
|
160 | 162 | semaphore = asyncio.Semaphore(self.max_concurrency) |
161 | 163 | save_counter = 0 |
162 | | - completed = 0 |
163 | 164 | total = len(to_process) |
164 | 165 |
|
165 | | - async def process(prompt_id: str, sample_id: str, action: Action) -> None: |
166 | | - nonlocal save_counter, completed |
| 166 | + async def process(prompt_id: str, sample_id: str, action: Action, pbar: tqdm) -> None: |
| 167 | + nonlocal save_counter |
167 | 168 | async with semaphore: |
168 | 169 | sample = await self.evaluate_sample(action) |
169 | 170 | self.results[prompt_id].append(sample) |
170 | 171 | self.completed_ids.add(sample_id) |
171 | | - completed += 1 |
| 172 | + pbar.update(1) |
172 | 173 | save_counter += 1 |
173 | 174 | if save_counter >= self.save_interval: |
174 | 175 | self.save_results() |
175 | | - logger.info(f"Progress: {completed}/{total}") |
176 | 176 | save_counter = 0 |
177 | 177 |
|
178 | | - await asyncio.gather(*[process(pid, sid, a) for pid, sid, a in to_process]) |
179 | | - |
180 | | - logger.info(f"Completed: {completed}/{total}") |
| 178 | + with logging_redirect_tqdm(): |
| 179 | + with tqdm(total=total, desc=f"Evaluating {self.benchmark_name}", unit="sample", dynamic_ncols=True) as pbar: |
| 180 | + await asyncio.gather(*[process(pid, sid, a, pbar) for pid, sid, a in to_process]) |
181 | 181 | self.save_results() |
182 | 182 | return dict(self.results) |
183 | 183 |
|
|
0 commit comments