@@ -204,13 +204,18 @@ def _build_train_dataloader_and_compute_training_steps(self):
204204 self .total_training_steps = min (self .total_training_steps , self .cfg .trainer .max_training_steps )
205205
206206 @torch .no_grad ()
207- async def eval (self ) -> Dict [str , float ]:
207+ async def eval (self , vllm_metrics_scraper : Optional [ VLLMMetricsScraper ] = None ) -> Dict [str , float ]:
208208 """
209209 Run generation and scoring on the evaluation dataset.
210210
211211 The eval metrics are recorded after having finished training `self.global_step` steps.
212212 Metrics recorded in global_step 0 correspond to evaluations before training.
213213
214+ Args:
215+ vllm_metrics_scraper: when provided, the eval loop calls
216+ ``resume()``/``pause()`` around each generation so the scraper
217+ attributes only generation time to the open ``vllm/eval`` window.
218+
214219 Returns:
215220 A dictionary of evaluation metrics.
216221 """
@@ -221,6 +226,7 @@ async def eval(self) -> Dict[str, float]:
221226 cfg = self .cfg ,
222227 global_step = self .global_step ,
223228 tokenizer = self .tokenizer ,
229+ vllm_metrics_scraper = vllm_metrics_scraper ,
224230 )
225231 else :
226232 eval_metrics = await evaluate (
@@ -229,6 +235,7 @@ async def eval(self) -> Dict[str, float]:
229235 cfg = self .cfg ,
230236 global_step = self .global_step ,
231237 tokenizer = self .tokenizer ,
238+ vllm_metrics_scraper = vllm_metrics_scraper ,
232239 )
233240 return eval_metrics
234241
@@ -294,6 +301,13 @@ async def train(self):
294301 if not step_started :
295302 self ._fire ("on_step_start" )
296303 step_started = True
304+ # Open the train-rollout metrics window once per logical
305+ # step; paused so only the generation spans count toward the
306+ # throughput denominator (dynamic sampling may generate more
307+ # than once before the step completes).
308+ if self ._vllm_metrics_scraper is not None :
309+ await self ._vllm_metrics_scraper .start ("vllm/train" )
310+ self ._vllm_metrics_scraper .pause ()
297311 with Timer ("step" , self .all_timings ):
298312 # for colocate_all=true, inference engine is always on GPU when starting the training step
299313
@@ -311,8 +325,12 @@ async def train(self):
311325 )
312326
313327 # 1.1. generation phase
328+ if self ._vllm_metrics_scraper is not None :
329+ self ._vllm_metrics_scraper .resume ()
314330 with Timer ("generate" , self .all_timings ):
315331 generator_output : GeneratorOutput = await self .generate (generator_input )
332+ if self ._vllm_metrics_scraper is not None :
333+ self ._vllm_metrics_scraper .pause ()
316334
317335 if self .cfg .generator .step_wise_trajectories :
318336 # NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids`
@@ -331,6 +349,13 @@ async def train(self):
331349 # if we are not continuing sampling, we sleep the inference engine
332350 await self .inference_engine_client .sleep ()
333351
352+ # The train rollout for this step is done generating; close
353+ # its metrics window. ``vllm/eval/*`` is collected separately
354+ # around eval below.
355+ vllm_metrics : Dict [str , float ] = {}
356+ if self ._vllm_metrics_scraper is not None :
357+ vllm_metrics = await self ._vllm_metrics_scraper .stop ()
358+
334359 # 1.2 postprocess rewards (and merge step-wise turns if enabled)
335360 with Timer ("postprocess_generator_output" , self .all_timings ):
336361 generator_output , uids = self .postprocess_generator_output (generator_output , uids )
@@ -434,18 +459,26 @@ async def train(self):
434459 or self .global_step == self .total_training_steps
435460 )
436461 if force_eval or interval_eval :
462+ # Open the eval-rollout window; the scraper itself measures
463+ # the generation spans via resume()/pause() inside eval().
464+ if self ._vllm_metrics_scraper is not None :
465+ await self ._vllm_metrics_scraper .start ("vllm/eval" )
466+ self ._vllm_metrics_scraper .pause ()
437467 self ._fire ("on_eval_start" )
438468 with Timer ("eval" , self .all_timings ):
439- eval_metrics = await self .eval ()
469+ eval_metrics = await self .eval (vllm_metrics_scraper = self . _vllm_metrics_scraper )
440470 self .all_metrics .update (eval_metrics )
441471 self ._fire ("on_eval_end" , metrics = eval_metrics )
472+ if self ._vllm_metrics_scraper is not None :
473+ vllm_metrics .update (await self ._vllm_metrics_scraper .stop ())
442474
443475 log_payload = {
444476 ** self .all_metrics ,
445477 ** {f"timing/{ k } " : v for k , v in self .all_timings .items ()},
478+ # vllm/train/* = train rollout, vllm/eval/* = eval rollout,
479+ # each over its own generation time (owned by the scraper).
480+ ** vllm_metrics ,
446481 }
447- if self ._vllm_metrics_scraper is not None :
448- log_payload .update (await self ._vllm_metrics_scraper .sample ())
449482
450483 if self ._ray_gpu_monitor is not None :
451484 log_payload .update (self ._ray_gpu_monitor .flush ())
0 commit comments