Skip to content

Commit 152ebec

Browse files
authored
[Fix] vLLM Metrics Scraper throughput calculation (NovaSky-AI#1794)
For more context, check out: [NovaSky-AI#1636 comment 1](https://github.com/NovaSky-AI/SkyRL/pull/1636/changes#r3411074771) [NovaSky-AI#1636 comment 2](https://github.com/NovaSky-AI/SkyRL/pull/1636/changes#r3411079924) The current vLLM Metrics Scraper calculates rollout throughput using the full step time, and doesn't account for variance introduced during eval (different sampling params, eval batch sizes). This PR addresses those issues by updating the timing logic and separating the metrics for training and eval rollouts (for synchronous RL; in async RL the rollouts are combined).
1 parent fedc0b7 commit 152ebec

4 files changed

Lines changed: 443 additions & 35 deletions

File tree

skyrl/train/evaluate.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import time
12
from collections import defaultdict
23
from pathlib import Path
3-
from typing import Any, Dict, List
4+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
45

56
import torch
67
from loguru import logger
@@ -29,6 +30,9 @@
2930
validate_generator_output,
3031
)
3132

33+
if TYPE_CHECKING:
34+
from skyrl.train.utils.vllm_metrics_scraper import VLLMMetricsScraper
35+
3236

3337
@torch.no_grad()
3438
async def evaluate(
@@ -37,6 +41,7 @@ async def evaluate(
3741
cfg: SkyRLTrainConfig,
3842
global_step: int | None,
3943
tokenizer: AutoTokenizer,
44+
vllm_metrics_scraper: Optional["VLLMMetricsScraper"] = None,
4045
) -> Dict[str, float]:
4146
"""Runs generation and evaluation of trajectories.
4247
@@ -47,6 +52,9 @@ async def evaluate(
4752
global_step (int | None): current global step, or
4853
`None` to indicate a non-training context (e.g., eval-only)
4954
tokenizer (AutoTokenizer): tokenizer to use
55+
vllm_metrics_scraper: when set, the open ``vllm/eval`` window is resumed
56+
around each generation and paused after, so only generation time
57+
counts toward eval throughput.
5058
5159
Returns:
5260
Dict[str, float]: evaluation metrics
@@ -58,6 +66,7 @@ async def evaluate(
5866
concat_env_extras: List[Dict[str, Any]] = []
5967
concat_uids: List[str] = []
6068
sampling_params = cfg.generator.eval_sampling_params
69+
eval_generate_time = 0.0
6170
pbar = tqdm(total=len(eval_dataloader), initial=0, desc="Evaluation Progress")
6271
for _, prompts in enumerate(eval_dataloader):
6372
pbar.update(1)
@@ -69,7 +78,13 @@ async def evaluate(
6978
"eval",
7079
global_step,
7180
)
81+
gen_start = time.monotonic()
82+
if vllm_metrics_scraper is not None:
83+
vllm_metrics_scraper.resume()
7284
generator_output: GeneratorOutput = await generator.generate(generator_input)
85+
if vllm_metrics_scraper is not None:
86+
vllm_metrics_scraper.pause()
87+
eval_generate_time += time.monotonic() - gen_start
7388
validate_generator_output(len(generator_input["prompts"]), generator_output)
7489
generator_outputs.append(generator_output)
7590
concat_all_envs.extend(generator_input["env_classes"])
@@ -127,6 +142,7 @@ async def evaluate(
127142
eval_metrics,
128143
)
129144

145+
eval_metrics["timing/eval_generate"] = eval_generate_time
130146
return eval_metrics
131147

132148

@@ -137,6 +153,7 @@ async def evaluate_step_wise(
137153
cfg: SkyRLTrainConfig,
138154
global_step: int | None,
139155
tokenizer: AutoTokenizer,
156+
vllm_metrics_scraper: Optional["VLLMMetricsScraper"] = None,
140157
) -> Dict[str, float]:
141158
"""Runs generation and evaluation of trajectories for step-wise training.
142159
@@ -149,6 +166,9 @@ async def evaluate_step_wise(
149166
global_step (int | None): current global step, or
150167
`None` to indicate a non-training context (e.g., eval-only)
151168
tokenizer (AutoTokenizer): tokenizer to use
169+
vllm_metrics_scraper: when set, the open ``vllm/eval`` window is resumed
170+
around each generation and paused after, so only generation time
171+
counts toward eval throughput.
152172
153173
Returns:
154174
Dict[str, float]: evaluation metrics
@@ -160,6 +180,7 @@ async def evaluate_step_wise(
160180
concat_env_extras: List[Dict[str, Any]] = []
161181
concat_uids: List[str] = []
162182
sampling_params = cfg.generator.eval_sampling_params
183+
eval_generate_time = 0.0
163184
pbar = tqdm(total=len(eval_dataloader), initial=0, desc="Evaluation Progress")
164185
for _, prompts in enumerate(eval_dataloader):
165186
pbar.update(1)
@@ -171,7 +192,13 @@ async def evaluate_step_wise(
171192
"eval",
172193
global_step,
173194
)
195+
gen_start = time.monotonic()
196+
if vllm_metrics_scraper is not None:
197+
vllm_metrics_scraper.resume()
174198
generator_output: GeneratorOutput = await generator.generate(generator_input)
199+
if vllm_metrics_scraper is not None:
200+
vllm_metrics_scraper.pause()
201+
eval_generate_time += time.monotonic() - gen_start
175202
traj_id_to_input = {
176203
traj_id.instance_id: {"env_class": env_class, "env_extras": env_extra}
177204
for traj_id, env_class, env_extra in zip(
@@ -244,4 +271,5 @@ async def evaluate_step_wise(
244271
eval_metrics,
245272
)
246273

274+
eval_metrics["timing/eval_generate"] = eval_generate_time
247275
return eval_metrics

skyrl/train/trainer.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,18 @@ def _build_train_dataloader_and_compute_training_steps(self):
204204
self.total_training_steps = min(self.total_training_steps, self.cfg.trainer.max_training_steps)
205205

206206
@torch.no_grad()
207-
async def eval(self) -> Dict[str, float]:
207+
async def eval(self, vllm_metrics_scraper: Optional[VLLMMetricsScraper] = None) -> Dict[str, float]:
208208
"""
209209
Run generation and scoring on the evaluation dataset.
210210
211211
The eval metrics are recorded after having finished training `self.global_step` steps.
212212
Metrics recorded in global_step 0 corresponds to evaluations before training.
213213
214+
Args:
215+
vllm_metrics_scraper: when provided, the eval loop calls
216+
``resume()``/``pause()`` around each generation so the scraper
217+
attributes only generation time to the open ``vllm/eval`` window.
218+
214219
Returns:
215220
A dictionary of evaluation metrics.
216221
"""
@@ -221,6 +226,7 @@ async def eval(self) -> Dict[str, float]:
221226
cfg=self.cfg,
222227
global_step=self.global_step,
223228
tokenizer=self.tokenizer,
229+
vllm_metrics_scraper=vllm_metrics_scraper,
224230
)
225231
else:
226232
eval_metrics = await evaluate(
@@ -229,6 +235,7 @@ async def eval(self) -> Dict[str, float]:
229235
cfg=self.cfg,
230236
global_step=self.global_step,
231237
tokenizer=self.tokenizer,
238+
vllm_metrics_scraper=vllm_metrics_scraper,
232239
)
233240
return eval_metrics
234241

@@ -294,6 +301,13 @@ async def train(self):
294301
if not step_started:
295302
self._fire("on_step_start")
296303
step_started = True
304+
# Open the train-rollout metrics window once per logical
305+
# step; paused so only the generation spans count toward the
306+
# throughput denominator (dynamic sampling may generate more
307+
# than once before the step completes).
308+
if self._vllm_metrics_scraper is not None:
309+
await self._vllm_metrics_scraper.start("vllm/train")
310+
self._vllm_metrics_scraper.pause()
297311
with Timer("step", self.all_timings):
298312
# for colocate_all=true, inference engine is always on GPU when starting the training step
299313

@@ -311,8 +325,12 @@ async def train(self):
311325
)
312326

313327
# 1.1. generation phase
328+
if self._vllm_metrics_scraper is not None:
329+
self._vllm_metrics_scraper.resume()
314330
with Timer("generate", self.all_timings):
315331
generator_output: GeneratorOutput = await self.generate(generator_input)
332+
if self._vllm_metrics_scraper is not None:
333+
self._vllm_metrics_scraper.pause()
316334

317335
if self.cfg.generator.step_wise_trajectories:
318336
# NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids`
@@ -331,6 +349,13 @@ async def train(self):
331349
# if we are not continuing sampling, we sleep the inference engine
332350
await self.inference_engine_client.sleep()
333351

352+
# The train rollout for this step is done generating; close
353+
# its metrics window. ``vllm/eval/*`` is collected separately
354+
# around eval below.
355+
vllm_metrics: Dict[str, float] = {}
356+
if self._vllm_metrics_scraper is not None:
357+
vllm_metrics = await self._vllm_metrics_scraper.stop()
358+
334359
# 1.2 postprocess rewards (and merge step-wise turns if enabled)
335360
with Timer("postprocess_generator_output", self.all_timings):
336361
generator_output, uids = self.postprocess_generator_output(generator_output, uids)
@@ -434,18 +459,26 @@ async def train(self):
434459
or self.global_step == self.total_training_steps
435460
)
436461
if force_eval or interval_eval:
462+
# Open the eval-rollout window; the scraper itself measures
463+
# the generation spans via resume()/pause() inside eval().
464+
if self._vllm_metrics_scraper is not None:
465+
await self._vllm_metrics_scraper.start("vllm/eval")
466+
self._vllm_metrics_scraper.pause()
437467
self._fire("on_eval_start")
438468
with Timer("eval", self.all_timings):
439-
eval_metrics = await self.eval()
469+
eval_metrics = await self.eval(vllm_metrics_scraper=self._vllm_metrics_scraper)
440470
self.all_metrics.update(eval_metrics)
441471
self._fire("on_eval_end", metrics=eval_metrics)
472+
if self._vllm_metrics_scraper is not None:
473+
vllm_metrics.update(await self._vllm_metrics_scraper.stop())
442474

443475
log_payload = {
444476
**self.all_metrics,
445477
**{f"timing/{k}": v for k, v in self.all_timings.items()},
478+
# vllm/train/* = train rollout, vllm/eval/* = eval rollout,
479+
# each over its own generation time (owned by the scraper).
480+
**vllm_metrics,
446481
}
447-
if self._vllm_metrics_scraper is not None:
448-
log_payload.update(await self._vllm_metrics_scraper.sample())
449482

450483
if self._ray_gpu_monitor is not None:
451484
log_payload.update(self._ray_gpu_monitor.flush())

0 commit comments

Comments
 (0)