3636from dataclasses import dataclass , field
3737from pathlib import Path
3838from typing import Dict , List , Optional
39+ import dataclasses
3940
4041import numpy as np
4142
@@ -175,13 +176,21 @@ def load_sharegpt(
175176
176177# ── vLLM runner ───────────────────────────────────────────────────
177178
179+
178180class VLLMRunner :
179181 """
180182 Uses AsyncLLMEngine + token streaming to measure true TTFT:
181183 wall-clock time from request submission to the first token chunk
182- arriving in the async generator, before decode begins .
184+ arriving in the async generator.
183185 """
184186
187+ _CACHED_TOKEN_ATTRS = (
188+ "num_cached_tokens" ,
189+ "num_prefix_cache_tokens" ,
190+ "cache_hit_tokens" ,
191+ "num_computed_tokens" ,
192+ )
193+
185194 def __init__ (
186195 self ,
187196 model_name : str ,
@@ -190,18 +199,20 @@ def __init__(
190199 max_model_len : Optional [int ] = None ,
191200 tensor_parallel_size : int = 1 ,
192201 enable_prefix_caching : bool = True ,
193- ):
194- from vllm import AsyncLLMEngine , AsyncEngineArgs , SamplingParams
202+ ) -> None :
203+ from vllm import AsyncEngineArgs , AsyncLLMEngine , SamplingParams
195204 from transformers import AutoTokenizer
196205
197- self .SamplingParams = SamplingParams
206+ self .model_name = model_name
198207 self .max_new_tokens = max_new_tokens
199208 self .enable_prefix_caching = enable_prefix_caching
209+ self .SamplingParams = SamplingParams
200210
201211 log .info (
202- f"Loading model: { model_name } "
203- f"(prefix_caching={ enable_prefix_caching } , "
204- f"gpu_mem_util={ gpu_memory_utilization } , async=True)"
212+ "Loading model=%s prefix_caching=%s gpu_mem_util=%.2f async=True" ,
213+ model_name ,
214+ enable_prefix_caching ,
215+ gpu_memory_utilization ,
205216 )
206217
207218 engine_args = AsyncEngineArgs (
@@ -213,51 +224,21 @@ def __init__(
213224 max_model_len = max_model_len ,
214225 )
215226 self .engine = AsyncLLMEngine .from_engine_args (engine_args )
216- # Silence vLLM's per-request INFO lines
227+
217228 logging .getLogger ("vllm.engine.async_llm_engine" ).setLevel (logging .WARNING )
218229 logging .getLogger ("vllm.core.scheduler" ).setLevel (logging .WARNING )
219- self ._tokenizer = AutoTokenizer .from_pretrained (model_name , trust_remote_code = True )
230+
231+ self ._tokenizer = AutoTokenizer .from_pretrained (
232+ model_name , trust_remote_code = True
233+ )
234+
220235 self ._request_counter = 0
221- # Persistent event loop — reused for every run_turn() call so the
222- # AsyncLLMEngine's internal tasks stay alive between requests.
223236 self ._loop = asyncio .new_event_loop ()
237+ asyncio .set_event_loop (self ._loop )
224238
225- # Discover the correct cached-token field name for this vLLM version.
226- # We probe RequestMetrics at import time so we fail fast and loudly.
227- self ._cached_tokens_attr : Optional [str ] = None
228- try :
229- from vllm .engine .metrics_types import RequestMetrics as _RM
230- except ImportError :
231- try :
232- from vllm .outputs import RequestOutput as _RO
233- # Fall back: inspect a dummy metrics object later
234- _RM = None
235- except ImportError :
236- _RM = None
237-
238- if _RM is not None :
239- for attr in ("num_cached_tokens" , "num_prefix_cache_tokens" ,
240- "cache_hit_tokens" , "num_computed_tokens" ):
241- if hasattr (_RM , attr ) or (
242- hasattr (_RM , "__dataclass_fields__" ) and
243- attr in _RM .__dataclass_fields__
244- ):
245- self ._cached_tokens_attr = attr
246- log .info (f" Cached-token field: RequestMetrics.{ attr } " )
247- break
248- if self ._cached_tokens_attr is None :
249- import dataclasses
250- try :
251- all_fields = [f .name for f in dataclasses .fields (_RM )]
252- log .warning (
253- f" No known cached-token field found on RequestMetrics. "
254- f"Available fields: { all_fields } . "
255- f"cache_hit_ratio will be 0 — update _CACHED_TOKEN_ATTRS."
256- )
257- except Exception :
258- pass
259-
260- log .info (" Async engine ready." )
239+ self ._cached_tokens_attr = self ._discover_cached_token_attr ()
240+
241+ log .info ("Async engine ready." )
261242
262243 @property
263244 def tokenizer (self ):
@@ -267,82 +248,107 @@ def count_tokens(self, text: str) -> int:
267248 return len (self ._tokenizer .encode (text , add_special_tokens = True ))
268249
269250 def run_turn (self , prompt : str ) -> dict :
270- """Sync wrapper — runs the async coroutine on the persistent event loop."""
271251 return self ._loop .run_until_complete (self ._run_turn_async (prompt ))
272252
273253 async def _run_turn_async (self , prompt : str ) -> dict :
274- """
275- Stream tokens from AsyncLLMEngine.
276- TTFT = wall time from generate() call to the first non-empty output chunk.
277- total_ms = wall time to the final chunk (prefill + full decode).
278- """
279- from vllm import SamplingParams
280-
281254 self ._request_counter += 1
282255 request_id = f"req-{ self ._request_counter } "
283256
284- params = SamplingParams (max_tokens = self .max_new_tokens , temperature = 0.0 )
257+ params = self . SamplingParams (max_tokens = self .max_new_tokens , temperature = 0.0 )
285258
286- t_submit = time .perf_counter ()
259+ t0 = time .perf_counter ()
287260 ttft_ms : Optional [float ] = None
288- output_text = ""
289261 final_output = None
262+ first_chunk_seen = False
290263
291264 async for output in self .engine .generate (prompt , params , request_id = request_id ):
292- if ttft_ms is None and output .outputs and output .outputs [0 ].text :
293- ttft_ms = (time .perf_counter () - t_submit ) * 1000
265+ if not first_chunk_seen :
266+ ttft_ms = (time .perf_counter () - t0 ) * 1000
267+ first_chunk_seen = True
294268 final_output = output
295269
296- total_ms = (time .perf_counter () - t_submit ) * 1000
297-
298- # If the model returned no text at all, TTFT = total_ms (pathological)
270+ total_ms = (time .perf_counter () - t0 ) * 1000
299271 if ttft_ms is None :
300272 ttft_ms = total_ms
301273
302- if final_output is not None and final_output .outputs :
303- output_text = final_output .outputs [0 ].text
304-
305- # Prompt token count
306- if final_output is not None and final_output .prompt_token_ids :
307- prompt_tokens = len (final_output .prompt_token_ids )
308- else :
309- prompt_tokens = self .count_tokens (prompt )
310-
311- # Cached tokens — field name varies by vLLM version
312- cached_tokens = 0
313- if final_output is not None and final_output .metrics is not None :
314- m = final_output .metrics
315- if self ._cached_tokens_attr is not None :
316- cached_tokens = int (getattr (m , self ._cached_tokens_attr , 0 ) or 0 )
317- else :
318- # Runtime fallback: try all known names, log what's available once
319- for attr in ("num_cached_tokens" , "num_prefix_cache_tokens" ,
320- "cache_hit_tokens" , "num_computed_tokens" ):
321- val = getattr (m , attr , None )
322- if val is not None :
323- cached_tokens = int (val )
324- self ._cached_tokens_attr = attr
325- log .info (f" Discovered cached-token field at runtime: { attr } " )
326- break
327- else :
328- if self ._request_counter == 1 :
329- log .warning (
330- f" cache_hit_ratio will be 0 — RequestMetrics has no "
331- f"known cached-token field. Fields present: "
332- f"{ [a for a in dir (m ) if not a .startswith ('_' )]} "
333- )
334-
335- cache_hit_ratio = cached_tokens / max (prompt_tokens , 1 )
274+ output_text = self ._get_output_text (final_output )
275+ prompt_tokens = self ._get_prompt_tokens (final_output , prompt )
276+ cached_tokens = self ._get_cached_tokens (final_output )
336277
337278 return {
338279 "ttft_ms" : ttft_ms ,
339280 "total_ms" : total_ms ,
340281 "output_text" : output_text ,
341282 "prompt_tokens" : prompt_tokens ,
342283 "cached_tokens" : cached_tokens ,
343- "cache_hit_ratio" : cache_hit_ratio ,
284+ "cache_hit_ratio" : cached_tokens / max ( prompt_tokens , 1 ) ,
344285 }
345286
287+ def _discover_cached_token_attr (self ) -> Optional [str ]:
288+ try :
289+ from vllm .engine .metrics_types import RequestMetrics as RM
290+ except ImportError :
291+ return None
292+
293+ fields = set ()
294+ if hasattr (RM , "__dataclass_fields__" ):
295+ fields .update (RM .__dataclass_fields__ .keys ())
296+
297+ try :
298+ fields .update (f .name for f in dataclasses .fields (RM ))
299+ except Exception :
300+ pass
301+
302+ for attr in self ._CACHED_TOKEN_ATTRS :
303+ if attr in fields or hasattr (RM , attr ):
304+ log .info ("Cached-token field: RequestMetrics.%s" , attr )
305+ return attr
306+
307+ log .warning (
308+ "No cached-token field found on RequestMetrics. Fields: %s" ,
309+ sorted (fields ) if fields else "unknown" ,
310+ )
311+ return None
312+
313+ def _get_output_text (self , final_output ) -> str :
314+ if final_output and getattr (final_output , "outputs" , None ):
315+ return final_output .outputs [0 ].text or ""
316+ return ""
317+
318+ def _get_prompt_tokens (self , final_output , prompt : str ) -> int :
319+ if final_output and getattr (final_output , "prompt_token_ids" , None ):
320+ return len (final_output .prompt_token_ids )
321+ return self .count_tokens (prompt )
322+
323+ def _get_cached_tokens (self , final_output ) -> int :
324+ if not final_output :
325+ return 0
326+
327+ metrics = getattr (final_output , "metrics" , None )
328+ if metrics is None :
329+ return 0
330+
331+ if self ._cached_tokens_attr :
332+ return int (getattr (metrics , self ._cached_tokens_attr , 0 ) or 0 )
333+
334+ for attr in self ._CACHED_TOKEN_ATTRS :
335+ value = getattr (metrics , attr , None )
336+ if value is not None :
337+ self ._cached_tokens_attr = attr
338+ log .info ("Discovered cached-token field at runtime: %s" , attr )
339+ return int (value )
340+
341+ return 0
342+
343+ def close (self ) -> None :
344+ if not self ._loop .is_closed ():
345+ self ._loop .close ()
346+
347+ def __enter__ (self ):
348+ return self
349+
350+ def __exit__ (self , exc_type , exc , tb ):
351+ self .close ()
346352
347353# ── Checkpoint helpers ────────────────────────────────────────────
348354
0 commit comments