Skip to content

Commit 2dbd702

Browse files
committed
available metrics
1 parent 4e4f5c2 commit 2dbd702

1 file changed

Lines changed: 104 additions & 98 deletions

File tree

benchmarks_and_experiments/vllm_sharegpt_replay/run_sharegpt_vllm.py

Lines changed: 104 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from dataclasses import dataclass, field
3737
from pathlib import Path
3838
from typing import Dict, List, Optional
39+
import dataclasses
3940

4041
import numpy as np
4142

@@ -175,13 +176,21 @@ def load_sharegpt(
175176

176177
# ── vLLM runner ───────────────────────────────────────────────────
177178

179+
178180
class VLLMRunner:
179181
"""
180182
Uses AsyncLLMEngine + token streaming to measure true TTFT:
181183
wall-clock time from request submission to the first token chunk
182-
arriving in the async generator, before decode begins.
184+
arriving in the async generator.
183185
"""
184186

187+
_CACHED_TOKEN_ATTRS = (
188+
"num_cached_tokens",
189+
"num_prefix_cache_tokens",
190+
"cache_hit_tokens",
191+
"num_computed_tokens",
192+
)
193+
185194
def __init__(
186195
self,
187196
model_name: str,
@@ -190,18 +199,20 @@ def __init__(
190199
max_model_len: Optional[int] = None,
191200
tensor_parallel_size: int = 1,
192201
enable_prefix_caching: bool = True,
193-
):
194-
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
202+
) -> None:
203+
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
195204
from transformers import AutoTokenizer
196205

197-
self.SamplingParams = SamplingParams
206+
self.model_name = model_name
198207
self.max_new_tokens = max_new_tokens
199208
self.enable_prefix_caching = enable_prefix_caching
209+
self.SamplingParams = SamplingParams
200210

201211
log.info(
202-
f"Loading model: {model_name} "
203-
f"(prefix_caching={enable_prefix_caching}, "
204-
f"gpu_mem_util={gpu_memory_utilization}, async=True)"
212+
"Loading model=%s prefix_caching=%s gpu_mem_util=%.2f async=True",
213+
model_name,
214+
enable_prefix_caching,
215+
gpu_memory_utilization,
205216
)
206217

207218
engine_args = AsyncEngineArgs(
@@ -213,51 +224,21 @@ def __init__(
213224
max_model_len=max_model_len,
214225
)
215226
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
216-
# Silence vLLM's per-request INFO lines
227+
217228
logging.getLogger("vllm.engine.async_llm_engine").setLevel(logging.WARNING)
218229
logging.getLogger("vllm.core.scheduler").setLevel(logging.WARNING)
219-
self._tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
230+
231+
self._tokenizer = AutoTokenizer.from_pretrained(
232+
model_name, trust_remote_code=True
233+
)
234+
220235
self._request_counter = 0
221-
# Persistent event loop — reused for every run_turn() call so the
222-
# AsyncLLMEngine's internal tasks stay alive between requests.
223236
self._loop = asyncio.new_event_loop()
237+
asyncio.set_event_loop(self._loop)
224238

225-
# Discover the correct cached-token field name for this vLLM version.
226-
# We probe RequestMetrics at import time so we fail fast and loudly.
227-
self._cached_tokens_attr: Optional[str] = None
228-
try:
229-
from vllm.engine.metrics_types import RequestMetrics as _RM
230-
except ImportError:
231-
try:
232-
from vllm.outputs import RequestOutput as _RO
233-
# Fall back: inspect a dummy metrics object later
234-
_RM = None
235-
except ImportError:
236-
_RM = None
237-
238-
if _RM is not None:
239-
for attr in ("num_cached_tokens", "num_prefix_cache_tokens",
240-
"cache_hit_tokens", "num_computed_tokens"):
241-
if hasattr(_RM, attr) or (
242-
hasattr(_RM, "__dataclass_fields__") and
243-
attr in _RM.__dataclass_fields__
244-
):
245-
self._cached_tokens_attr = attr
246-
log.info(f" Cached-token field: RequestMetrics.{attr}")
247-
break
248-
if self._cached_tokens_attr is None:
249-
import dataclasses
250-
try:
251-
all_fields = [f.name for f in dataclasses.fields(_RM)]
252-
log.warning(
253-
f" No known cached-token field found on RequestMetrics. "
254-
f"Available fields: {all_fields}. "
255-
f"cache_hit_ratio will be 0 — update _CACHED_TOKEN_ATTRS."
256-
)
257-
except Exception:
258-
pass
259-
260-
log.info(" Async engine ready.")
239+
self._cached_tokens_attr = self._discover_cached_token_attr()
240+
241+
log.info("Async engine ready.")
261242

262243
@property
263244
def tokenizer(self):
@@ -267,82 +248,107 @@ def count_tokens(self, text: str) -> int:
267248
return len(self._tokenizer.encode(text, add_special_tokens=True))
268249

269250
def run_turn(self, prompt: str) -> dict:
    """Run one prompt synchronously and return its measurement dict.

    Drives ``_run_turn_async`` to completion on ``self._loop`` and returns
    whatever that coroutine produces.
    """
    turn = self._run_turn_async(prompt)
    return self._loop.run_until_complete(turn)
272252

273253
async def _run_turn_async(self, prompt: str) -> dict:
274-
"""
275-
Stream tokens from AsyncLLMEngine.
276-
TTFT = wall time from generate() call to the first non-empty output chunk.
277-
total_ms = wall time to the final chunk (prefill + full decode).
278-
"""
279-
from vllm import SamplingParams
280-
281254
self._request_counter += 1
282255
request_id = f"req-{self._request_counter}"
283256

284-
params = SamplingParams(max_tokens=self.max_new_tokens, temperature=0.0)
257+
params = self.SamplingParams(max_tokens=self.max_new_tokens, temperature=0.0)
285258

286-
t_submit = time.perf_counter()
259+
t0 = time.perf_counter()
287260
ttft_ms: Optional[float] = None
288-
output_text = ""
289261
final_output = None
262+
first_chunk_seen = False
290263

291264
async for output in self.engine.generate(prompt, params, request_id=request_id):
292-
if ttft_ms is None and output.outputs and output.outputs[0].text:
293-
ttft_ms = (time.perf_counter() - t_submit) * 1000
265+
if not first_chunk_seen:
266+
ttft_ms = (time.perf_counter() - t0) * 1000
267+
first_chunk_seen = True
294268
final_output = output
295269

296-
total_ms = (time.perf_counter() - t_submit) * 1000
297-
298-
# If the model returned no text at all, TTFT = total_ms (pathological)
270+
total_ms = (time.perf_counter() - t0) * 1000
299271
if ttft_ms is None:
300272
ttft_ms = total_ms
301273

302-
if final_output is not None and final_output.outputs:
303-
output_text = final_output.outputs[0].text
304-
305-
# Prompt token count
306-
if final_output is not None and final_output.prompt_token_ids:
307-
prompt_tokens = len(final_output.prompt_token_ids)
308-
else:
309-
prompt_tokens = self.count_tokens(prompt)
310-
311-
# Cached tokens — field name varies by vLLM version
312-
cached_tokens = 0
313-
if final_output is not None and final_output.metrics is not None:
314-
m = final_output.metrics
315-
if self._cached_tokens_attr is not None:
316-
cached_tokens = int(getattr(m, self._cached_tokens_attr, 0) or 0)
317-
else:
318-
# Runtime fallback: try all known names, log what's available once
319-
for attr in ("num_cached_tokens", "num_prefix_cache_tokens",
320-
"cache_hit_tokens", "num_computed_tokens"):
321-
val = getattr(m, attr, None)
322-
if val is not None:
323-
cached_tokens = int(val)
324-
self._cached_tokens_attr = attr
325-
log.info(f" Discovered cached-token field at runtime: {attr}")
326-
break
327-
else:
328-
if self._request_counter == 1:
329-
log.warning(
330-
f" cache_hit_ratio will be 0 — RequestMetrics has no "
331-
f"known cached-token field. Fields present: "
332-
f"{[a for a in dir(m) if not a.startswith('_')]}"
333-
)
334-
335-
cache_hit_ratio = cached_tokens / max(prompt_tokens, 1)
274+
output_text = self._get_output_text(final_output)
275+
prompt_tokens = self._get_prompt_tokens(final_output, prompt)
276+
cached_tokens = self._get_cached_tokens(final_output)
336277

337278
return {
338279
"ttft_ms": ttft_ms,
339280
"total_ms": total_ms,
340281
"output_text": output_text,
341282
"prompt_tokens": prompt_tokens,
342283
"cached_tokens": cached_tokens,
343-
"cache_hit_ratio": cache_hit_ratio,
284+
"cache_hit_ratio": cached_tokens / max(prompt_tokens, 1),
344285
}
345286

287+
def _discover_cached_token_attr(self) -> Optional[str]:
288+
try:
289+
from vllm.engine.metrics_types import RequestMetrics as RM
290+
except ImportError:
291+
return None
292+
293+
fields = set()
294+
if hasattr(RM, "__dataclass_fields__"):
295+
fields.update(RM.__dataclass_fields__.keys())
296+
297+
try:
298+
fields.update(f.name for f in dataclasses.fields(RM))
299+
except Exception:
300+
pass
301+
302+
for attr in self._CACHED_TOKEN_ATTRS:
303+
if attr in fields or hasattr(RM, attr):
304+
log.info("Cached-token field: RequestMetrics.%s", attr)
305+
return attr
306+
307+
log.warning(
308+
"No cached-token field found on RequestMetrics. Fields: %s",
309+
sorted(fields) if fields else "unknown",
310+
)
311+
return None
312+
313+
def _get_output_text(self, final_output) -> str:
314+
if final_output and getattr(final_output, "outputs", None):
315+
return final_output.outputs[0].text or ""
316+
return ""
317+
318+
def _get_prompt_tokens(self, final_output, prompt: str) -> int:
319+
if final_output and getattr(final_output, "prompt_token_ids", None):
320+
return len(final_output.prompt_token_ids)
321+
return self.count_tokens(prompt)
322+
323+
def _get_cached_tokens(self, final_output) -> int:
324+
if not final_output:
325+
return 0
326+
327+
metrics = getattr(final_output, "metrics", None)
328+
if metrics is None:
329+
return 0
330+
331+
if self._cached_tokens_attr:
332+
return int(getattr(metrics, self._cached_tokens_attr, 0) or 0)
333+
334+
for attr in self._CACHED_TOKEN_ATTRS:
335+
value = getattr(metrics, attr, None)
336+
if value is not None:
337+
self._cached_tokens_attr = attr
338+
log.info("Discovered cached-token field at runtime: %s", attr)
339+
return int(value)
340+
341+
return 0
342+
343+
def close(self) -> None:
    """Shut down the runner's private event loop.

    Safe to call more than once: a loop that is already closed is left
    alone. Before closing, tasks still scheduled on the loop are cancelled
    and awaited and live async generators are finalised, so the loop is
    not closed with work still pending on it.

    Raises:
        RuntimeError: if called while the loop is running.
    """
    # Local import keeps the method self-contained.
    import asyncio

    loop = self._loop
    if loop.is_closed():
        return

    try:
        pending = asyncio.all_tasks(loop)
        for task in pending:
            task.cancel()
        if pending:
            # return_exceptions=True collects each task's CancelledError
            # instead of re-raising it here.
            loop.run_until_complete(
                asyncio.gather(*pending, return_exceptions=True)
            )
        loop.run_until_complete(loop.shutdown_asyncgens())
    finally:
        # Close the loop even if draining failed.
        loop.close()
346+
347+
def __enter__(self):
348+
return self
349+
350+
def __exit__(self, exc_type, exc, tb):
351+
self.close()
346352

347353
# ── Checkpoint helpers ────────────────────────────────────────────
348354

0 commit comments

Comments
 (0)