Skip to content

Commit bdd5e8c

Browse files
author
Weiliangl User
committed
Add CSV export for sa-bench rollup
1 parent 31c3e59 commit bdd5e8c

2 files changed

Lines changed: 508 additions & 10 deletions

File tree

src/srtctl/benchmarks/scripts/sa-bench/rollup.py

Lines changed: 342 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,47 @@
22
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44

5-
"""Generate benchmark-rollup.json from sa-bench results."""
5+
"""Generate benchmark-rollup.json and benchmark-rollup.csv from sa-bench results."""
66

7+
from __future__ import annotations
8+
9+
import csv
710
import json
11+
from collections import Counter
12+
import math
13+
import re
814
import sys
915
from datetime import datetime, timezone
1016
from pathlib import Path
17+
from typing import Any, NamedTuple
18+
19+
import yaml
20+
21+
22+
# Ordered CSV header for benchmark-rollup.csv; _build_csv_row emits exactly
# these keys, and csv.DictWriter uses this list as its fieldnames.
OUTPUT_FIELDS = [
    "Config",
    "Total GPU Count",
    "Decode GPU Count",
    "Concurrency",
    "Total Token Throughput",
    "Output Token Throughput",
    "Median TTFT",
    "Median TPOT",
    "Median ITL",
    "P90 Decode Running Requests",
    "Output Token Throughput per User",
    "Total Token Throughput per GPU",
]

# Matches decode-log lines containing "#running-req: <N>" and captures the count.
RUNNING_REQ_PATTERN = re.compile(r"#running-req:\s*(\d+)")
40+
class RollupContext(NamedTuple):
    """Resolved runtime context shared by CSV enrichment helpers."""

    # Effective config/job name for the CSV "Config" column; None when unresolved.
    config_name: str | None
    # Resource topology mapping (node/worker/GPU counts) from metadata or config.
    resources: dict[str, Any] | None
    # Inference backend identifier (e.g. "sglang"); None when undeterminable.
    backend_type: str | None
1146

1247

1348
def _get_percentile(percentiles: list, target: float) -> float | None:
@@ -20,24 +55,299 @@ def _get_percentile(percentiles: list, target: float) -> float | None:
2055
return None
2156

2257

58+
def _read_yaml_dict(path: Path) -> dict[str, Any] | None:
    """Load *path* as YAML and return its top-level mapping.

    Returns None when the file cannot be read or parsed (the error is
    reported on stderr) or when the document is not a mapping.
    """
    try:
        loaded = yaml.safe_load(path.read_text()) or {}
    except Exception as exc:  # best-effort: a bad file must not abort the rollup
        print(f"Failed to parse {path}: {exc}", file=sys.stderr)
        return None
    if isinstance(loaded, dict):
        return loaded
    return None
68+
69+
def _read_json_dict(path: Path) -> dict[str, Any] | None:
70+
"""Read a JSON file into a dictionary."""
71+
try:
72+
data = json.loads(path.read_text())
73+
except Exception as exc:
74+
print(f"Failed to parse {path}: {exc}", file=sys.stderr)
75+
return None
76+
77+
return data if isinstance(data, dict) else None
78+
79+
80+
def _read_runtime_config(log_dir: Path) -> dict[str, Any] | None:
    """Locate and read the resolved runtime config.

    Override-expanded ``config_*.yaml`` files in the parent (output)
    directory take precedence — the first non-empty one in sorted order
    wins — with ``config.yaml`` inside the log directory as the fallback.
    """
    for candidate in sorted(log_dir.parent.glob("config_*.yaml")):
        parsed = _read_yaml_dict(candidate)
        if parsed:
            return parsed

    fallback = log_dir / "config.yaml"
    return _read_yaml_dict(fallback) if fallback.exists() else None
96+
97+
def _read_job_metadata(log_dir: Path) -> dict[str, Any] | None:
    """Return the first non-empty JSON object found in the output directory.

    Submit metadata is written as ``*.json`` next to the log directory;
    files are probed in sorted order and unparsable ones are skipped.
    """
    for candidate in sorted(log_dir.parent.glob("*.json")):
        parsed = _read_json_dict(candidate)
        if parsed:
            return parsed
    return None
106+
107+
def _extract_backend_type(runtime_config: dict[str, Any] | None, metadata: dict[str, Any] | None) -> str | None:
108+
"""Extract backend type from metadata first, then runtime config if possible."""
109+
if metadata:
110+
backend_type = metadata.get("backend_type")
111+
if isinstance(backend_type, str) and backend_type:
112+
return backend_type
113+
114+
if runtime_config:
115+
backend = runtime_config.get("backend")
116+
if isinstance(backend, dict):
117+
backend_type = backend.get("type")
118+
if isinstance(backend_type, str) and backend_type:
119+
return backend_type
120+
if "sglang_config" in backend:
121+
return "sglang"
122+
123+
return None
124+
125+
126+
def _load_rollup_context(log_dir: Path) -> RollupContext:
    """Resolve config name, resources, and backend type once for all CSV helpers.

    Name resolution prefers the runtime config's ``name`` over the metadata
    ``job_name``; resource resolution prefers metadata ``resources`` over
    the runtime config's. Either may remain None when unavailable.
    """
    runtime_config = _read_runtime_config(log_dir)
    metadata = _read_job_metadata(log_dir)

    def _nonempty_str(source: dict[str, Any] | None, key: str) -> str | None:
        # Only a non-empty string counts as a usable value.
        value = (source or {}).get(key)
        return value if isinstance(value, str) and value else None

    def _dict_field(source: dict[str, Any] | None, key: str) -> dict[str, Any] | None:
        value = (source or {}).get(key)
        return value if isinstance(value, dict) else None

    config_name = _nonempty_str(runtime_config, "name") or _nonempty_str(metadata, "job_name")

    resources = _dict_field(metadata, "resources")
    if resources is None:
        resources = _dict_field(runtime_config, "resources")

    return RollupContext(
        config_name=config_name,
        resources=resources,
        backend_type=_extract_backend_type(runtime_config, metadata),
    )
157+
158+
def _compute_total_gpu_count(resources: dict[str, Any]) -> int | None:
159+
"""Compute total GPU count from resources using the same topology semantics as the config."""
160+
gpus_per_node_raw = resources.get("gpus_per_node")
161+
if gpus_per_node_raw in (None, 0):
162+
return None
163+
gpus_per_node = int(gpus_per_node_raw)
164+
165+
prefill_nodes = int(resources.get("prefill_nodes", 0) or 0)
166+
decode_nodes = int(resources.get("decode_nodes", 0) or 0)
167+
if prefill_nodes or decode_nodes:
168+
return (prefill_nodes + decode_nodes) * gpus_per_node
169+
170+
agg_nodes = int(resources.get("agg_nodes", 0) or 0)
171+
if agg_nodes:
172+
return agg_nodes * gpus_per_node
173+
174+
return gpus_per_node
175+
176+
177+
def _compute_prefill_gpus_per_worker(resources: dict[str, Any], gpus_per_node: int | None) -> int | None:
178+
"""Compute prefill GPUs per worker using the same fallback order as ResourceConfig."""
179+
explicit = resources.get("gpus_per_prefill")
180+
if explicit not in (None, 0):
181+
return int(explicit)
182+
183+
prefill_nodes = resources.get("prefill_nodes")
184+
prefill_workers = resources.get("prefill_workers")
185+
if prefill_nodes not in (None, 0) and prefill_workers not in (None, 0) and gpus_per_node not in (None, 0):
186+
return (int(prefill_nodes) * int(gpus_per_node)) // int(prefill_workers)
187+
188+
return gpus_per_node
189+
190+
191+
def _compute_decode_gpu_count(resources: dict[str, Any], total_gpu_count: int | None) -> int | None:
    """Compute total decode GPUs using ResourceConfig-compatible rules when possible.

    Resolution order:
      1. explicit ``gpus_per_decode`` (scaled by ``decode_workers`` when set);
      2. ``decode_nodes`` x ``gpus_per_node``, with ``gpus_per_node`` inferred
         from ``total_gpu_count`` when absent, re-split across
         ``decode_workers`` when workers are set;
      3. for co-located decode (``decode_nodes == 0`` with workers), decode
         workers are assumed to use the prefill per-worker GPU size.
    Returns None when no rule applies.
    """
    decode_workers = int(resources.get("decode_workers", 0) or 0)
    decode_nodes_raw = resources.get("decode_nodes")
    # 0 is meaningful here (co-located decode), so only None/"" map to None.
    decode_nodes = int(decode_nodes_raw) if decode_nodes_raw not in (None, "") else None

    explicit = resources.get("gpus_per_decode")
    if explicit not in (None, 0):
        gpus_per_decode = int(explicit)
        return decode_workers * gpus_per_decode if decode_workers else gpus_per_decode

    gpus_per_node_raw = resources.get("gpus_per_node")
    gpus_per_node = int(gpus_per_node_raw) if gpus_per_node_raw not in (None, 0) else None

    if gpus_per_node is None and total_gpu_count not in (None, 0):
        # Infer a uniform per-node GPU count from the totals when it divides evenly.
        prefill_nodes = int(resources.get("prefill_nodes", 0) or 0)
        total_nodes = prefill_nodes + (decode_nodes or 0)
        if total_nodes > 0 and total_gpu_count % total_nodes == 0:
            gpus_per_node = total_gpu_count // total_nodes

    if decode_nodes not in (None, 0) and gpus_per_node not in (None, 0):
        if decode_workers:
            # Floor-divide so uneven splits mirror per-worker sizing, then
            # rescale by workers (drops any remainder GPUs).
            gpus_per_decode = (decode_nodes * gpus_per_node) // decode_workers
            return decode_workers * gpus_per_decode
        return decode_nodes * gpus_per_node

    if decode_nodes == 0 and decode_workers:
        gpus_per_prefill = _compute_prefill_gpus_per_worker(resources, gpus_per_node)
        if gpus_per_prefill not in (None, 0):
            return decode_workers * gpus_per_prefill

    return None
224+
225+
def _extract_gpu_counts(context: RollupContext) -> tuple[int | None, int | None]:
    """Return (total, decode) GPU counts derived from the context's resources.

    Yields (None, None) when resources are absent or when neither count can
    be computed from them.
    """
    resources = context.resources
    if not resources:
        return None, None

    total = _compute_total_gpu_count(resources)
    decode = _compute_decode_gpu_count(resources, total)
    if total is None and decode is None:
        return None, None
    return total, decode
236+
237+
def _is_sglang_disagg(context: RollupContext) -> bool:
238+
"""Return whether the current run is an SGLang disaggregated deployment."""
239+
if context.backend_type != "sglang":
240+
return False
241+
242+
if not isinstance(context.resources, dict):
243+
return False
244+
245+
prefill_nodes = int(context.resources.get("prefill_nodes", 0) or 0)
246+
decode_nodes = int(context.resources.get("decode_nodes", 0) or 0)
247+
agg_workers = int(context.resources.get("agg_workers", 0) or 0)
248+
return prefill_nodes > 0 and decode_nodes > 0 and agg_workers == 0
249+
250+
251+
def _extract_p90_decode_running_requests(log_dir: Path, context: RollupContext) -> int | None:
    """Nearest-rank P90 of ``#running-req`` samples across decode worker logs.

    Only meaningful for SGLang disaggregated runs; returns None otherwise,
    or when no samples are found. Logs are streamed line by line into a
    value histogram so large files are never held in memory.
    """
    if not _is_sglang_disagg(context):
        return None

    histogram: Counter[int] = Counter()
    sample_count = 0

    for log_path in sorted(log_dir.glob("*decode*.out")):
        try:
            with log_path.open("r", errors="replace") as handle:
                for raw_line in handle:
                    found = RUNNING_REQ_PATTERN.search(raw_line)
                    if found:
                        histogram[int(found.group(1))] += 1
                        sample_count += 1
        except OSError as exc:
            # Skip unreadable logs; a partial sample is still useful.
            print(f"Failed to read {log_path}: {exc}", file=sys.stderr)

    if not sample_count:
        return None

    # Nearest-rank percentile over the histogram (avoids sorting raw samples).
    target_rank = math.ceil(sample_count * 0.9)
    seen = 0
    for value in sorted(histogram):
        seen += histogram[value]
        if seen >= target_rank:
            return value
    return None
284+
285+
def _safe_ratio(numerator: float | int | None, denominator: float | int | None) -> float | None:
286+
"""Return numerator / denominator when both values are valid and denominator != 0."""
287+
if numerator is None or denominator in (None, 0):
288+
return None
289+
return float(numerator) / float(denominator)
290+
291+
292+
def _format_csv_value(value: object) -> str:
293+
"""Format CSV values with at most three decimal places for numeric fields."""
294+
if value is None:
295+
return ""
296+
if isinstance(value, int):
297+
return str(value)
298+
if isinstance(value, float):
299+
return f"{value:.3f}".rstrip("0").rstrip(".")
300+
return str(value)
301+
302+
303+
def _build_csv_row(
    data: dict[str, object],
    config_name: str,
    gpu_num: int | None,
    decode_gpu_count: int | None,
    p90_decode_running_requests: int | None,
) -> dict[str, object]:
    """Assemble one formatted CSV row (keyed by OUTPUT_FIELDS) from a result dict.

    Derived columns: per-user output throughput is 1000 / median TPOT (ms),
    and per-GPU throughput divides total token throughput by the total GPU
    count; both render blank when their inputs are missing.
    """
    throughput_total = data.get("total_token_throughput")
    tpot_median = data.get("median_tpot_ms")

    raw_row = {
        "Config": config_name,
        "Total GPU Count": gpu_num,
        "Decode GPU Count": decode_gpu_count,
        "Concurrency": data.get("max_concurrency"),
        "Total Token Throughput": throughput_total,
        "Output Token Throughput": data.get("output_throughput"),
        "Median TTFT": data.get("median_ttft_ms"),
        "Median TPOT": tpot_median,
        "Median ITL": data.get("median_itl_ms"),
        "P90 Decode Running Requests": p90_decode_running_requests,
        "Output Token Throughput per User": _safe_ratio(1000.0, tpot_median),
        "Total Token Throughput per GPU": _safe_ratio(throughput_total, gpu_num),
    }
    return {column: _format_csv_value(cell) for column, cell in raw_row.items()}
329+
23330
def main(log_dir: Path) -> None:
24-
"""Generate benchmark-rollup.json from sa-bench result files."""
331+
"""Generate benchmark-rollup.json and benchmark-rollup.csv from sa-bench result files."""
25332
result_files = sorted(log_dir.glob("sa-bench_*/results_*.json"))
26333
if not result_files:
27334
print("No sa-bench results found", file=sys.stderr)
28335
return
29336

30337
runs = []
338+
csv_rows = []
31339
config = {}
340+
context = _load_rollup_context(log_dir)
341+
total_gpu_count, decode_gpu_count = _extract_gpu_counts(context)
342+
p90_decode_running_requests = _extract_p90_decode_running_requests(log_dir, context)
32343

33-
for f in result_files:
344+
for result_file in result_files:
34345
try:
35-
data = json.loads(f.read_text())
36-
except json.JSONDecodeError as e:
37-
print(f"Failed to parse {f}: {e}", file=sys.stderr)
346+
data = json.loads(result_file.read_text())
347+
except json.JSONDecodeError as exc:
348+
print(f"Failed to parse {result_file}: {exc}", file=sys.stderr)
38349
continue
39350

40-
# Extract config from first file
41351
if not config:
42352
config = {
43353
"model": data.get("model_id"),
@@ -61,16 +371,38 @@ def main(log_dir: Path) -> None:
61371
"total_output_tokens": data.get("total_output"),
62372
})
63373

374+
csv_rows.append(
375+
_build_csv_row(
376+
data=data,
377+
config_name=context.config_name or str(data.get("model_id") or "unknown"),
378+
gpu_num=total_gpu_count,
379+
decode_gpu_count=decode_gpu_count,
380+
p90_decode_running_requests=p90_decode_running_requests,
381+
)
382+
)
383+
384+
if not runs:
385+
print("No valid sa-bench results found", file=sys.stderr)
386+
return
387+
64388
rollup = {
65389
"benchmark_type": "sa-bench",
66390
"timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
67391
"config": config,
68392
"runs": runs,
69393
}
70394

71-
output_path = log_dir / "benchmark-rollup.json"
72-
output_path.write_text(json.dumps(rollup, indent=2))
73-
print(f"Wrote {output_path}")
395+
json_path = log_dir / "benchmark-rollup.json"
396+
json_path.write_text(json.dumps(rollup, indent=2))
397+
print(f"Wrote {json_path}")
398+
399+
csv_rows.sort(key=lambda row: int(row["Concurrency"]) if row["Concurrency"] else -1)
400+
csv_path = log_dir / "benchmark-rollup.csv"
401+
with csv_path.open("w", newline="") as csv_file:
402+
writer = csv.DictWriter(csv_file, fieldnames=OUTPUT_FIELDS)
403+
writer.writeheader()
404+
writer.writerows(csv_rows)
405+
print(f"Wrote {csv_path}")
74406

75407

76408
if __name__ == "__main__":

0 commit comments

Comments
 (0)