Skip to content

Commit 900d272

Browse files
author
Weiliangl User
committed
Add CSV export for sa-bench rollup
1 parent 31c3e59 commit 900d272

2 files changed

Lines changed: 493 additions & 10 deletions

File tree

src/srtctl/benchmarks/scripts/sa-bench/rollup.py

Lines changed: 327 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,39 @@
22
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44

5-
"""Generate benchmark-rollup.json from sa-bench results."""
5+
"""Generate benchmark-rollup.json and benchmark-rollup.csv from sa-bench results."""
66

7+
from __future__ import annotations
8+
9+
import csv
710
import json
11+
from collections import Counter
12+
import math
13+
import re
814
import sys
915
from datetime import datetime, timezone
1016
from pathlib import Path
17+
from typing import Any
18+
19+
import yaml
20+
21+
22+
# Column headers for benchmark-rollup.csv, in the exact output order.
OUTPUT_FIELDS = [
    "Config",
    "Total GPU Count",
    "Decode GPU Count",
    "Concurrency",
    "Total Token Throughput",
    "Output Token Throughput",
    "Median TTFT",
    "Median TPOT",
    "Median ITL",
    "P90 Decode Running Requests",
    "Output Token Throughput per User",
    "Total Token Throughput per GPU",
]

# Matches "#running-req: <count>" lines emitted in decode worker logs;
# the single capture group is the running-request count.
RUNNING_REQ_PATTERN = re.compile(r"#running-req:\s*(\d+)")
1138

1239

1340
def _get_percentile(percentiles: list, target: float) -> float | None:
@@ -20,24 +47,292 @@ def _get_percentile(percentiles: list, target: float) -> float | None:
2047
return None
2148

2249

50+
def _extract_gpu_num(path: Path) -> int | None:
51+
"""Extract total GPU count from the result filename."""
52+
match = re.search(r"_gpus_(\d+)", path.name)
53+
if not match:
54+
return None
55+
return int(match.group(1))
56+
57+
58+
def _read_resources_from_yaml(path: Path) -> dict[str, Any] | None:
    """Read the top-level ``resources`` mapping from a YAML file.

    Returns None (after logging to stderr) when the file cannot be read or
    parsed, and also when ``resources`` is missing or is not a mapping.
    """
    try:
        parsed = yaml.safe_load(path.read_text()) or {}
    except Exception as exc:  # best-effort: log and keep going
        print(f"Failed to parse {path}: {exc}", file=sys.stderr)
        return None

    section = parsed.get("resources")
    if isinstance(section, dict):
        return section
    return None
68+
69+
70+
def _read_runtime_resources(log_dir: Path) -> dict[str, Any] | None:
    """Read resolved runtime resources, preferring override-expanded configs.

    Looks first at config_*.yaml files in the output directory (the parent
    of *log_dir*), then falls back to logs/config.yaml when present.
    """
    parent_dir = log_dir.parent

    for candidate in sorted(parent_dir.glob("config_*.yaml")):
        found = _read_resources_from_yaml(candidate)
        if found:
            return found

    fallback = log_dir / "config.yaml"
    return _read_resources_from_yaml(fallback) if fallback.exists() else None
85+
86+
87+
def _compute_prefill_gpus_per_worker(resources: dict[str, Any], gpus_per_node: int | None) -> int | None:
88+
"""Compute prefill GPUs per worker using the same fallback order as ResourceConfig."""
89+
explicit = resources.get("gpus_per_prefill")
90+
if explicit not in (None, 0):
91+
return int(explicit)
92+
93+
prefill_nodes = resources.get("prefill_nodes")
94+
prefill_workers = resources.get("prefill_workers")
95+
if prefill_nodes not in (None, 0) and prefill_workers not in (None, 0) and gpus_per_node not in (None, 0):
96+
return (int(prefill_nodes) * int(gpus_per_node)) // int(prefill_workers)
97+
98+
return gpus_per_node
99+
100+
101+
def _compute_decode_gpu_count(resources: dict[str, Any], total_gpu_count: int | None) -> int | None:
102+
"""Compute total decode GPUs using ResourceConfig-compatible rules when possible."""
103+
decode_workers = int(resources.get("decode_workers", 0) or 0)
104+
decode_nodes_raw = resources.get("decode_nodes")
105+
decode_nodes = int(decode_nodes_raw) if decode_nodes_raw not in (None, "") else None
106+
107+
explicit = resources.get("gpus_per_decode")
108+
if explicit not in (None, 0):
109+
gpus_per_decode = int(explicit)
110+
return decode_workers * gpus_per_decode if decode_workers else gpus_per_decode
111+
112+
gpus_per_node_raw = resources.get("gpus_per_node")
113+
gpus_per_node = int(gpus_per_node_raw) if gpus_per_node_raw not in (None, 0) else None
114+
115+
if gpus_per_node is None and total_gpu_count not in (None, 0):
116+
prefill_nodes = int(resources.get("prefill_nodes", 0) or 0)
117+
total_nodes = prefill_nodes + (decode_nodes or 0)
118+
if total_nodes > 0 and total_gpu_count % total_nodes == 0:
119+
gpus_per_node = total_gpu_count // total_nodes
120+
121+
if decode_nodes not in (None, 0) and gpus_per_node not in (None, 0):
122+
if decode_workers:
123+
gpus_per_decode = (decode_nodes * gpus_per_node) // decode_workers
124+
return decode_workers * gpus_per_decode
125+
return decode_nodes * gpus_per_node
126+
127+
if decode_nodes == 0 and decode_workers:
128+
gpus_per_prefill = _compute_prefill_gpus_per_worker(resources, gpus_per_node)
129+
if gpus_per_prefill not in (None, 0):
130+
return decode_workers * gpus_per_prefill
131+
132+
return None
133+
134+
135+
def _extract_decode_gpu_count(log_dir: Path, total_gpu_count: int | None) -> int | None:
    """Extract decode GPU count from metadata or resolved runtime config.

    Submit metadata takes precedence; the resolved runtime YAML config is the
    fallback when metadata is missing or yields no count.
    """
    metadata = _read_job_metadata(log_dir) or {}
    meta_resources = metadata.get("resources")
    if isinstance(meta_resources, dict):
        from_metadata = _compute_decode_gpu_count(meta_resources, total_gpu_count)
        if from_metadata is not None:
            return from_metadata

    yaml_resources = _read_runtime_resources(log_dir)
    if not yaml_resources:
        return None
    return _compute_decode_gpu_count(yaml_resources, total_gpu_count)
150+
151+
152+
def _read_yaml_name(path: Path) -> str | None:
    """Read the top-level ``name`` field from a YAML file.

    Returns None (after logging to stderr) on parse failure, and when the
    field is absent or falsy.
    """
    try:
        parsed = yaml.safe_load(path.read_text()) or {}
    except Exception as exc:  # best-effort: log and keep going
        print(f"Failed to parse {path}: {exc}", file=sys.stderr)
        return None

    value = parsed.get("name")
    if not value:
        return None
    return str(value)
162+
163+
164+
def _read_job_name(path: Path) -> str | None:
165+
"""Read the job_name field from submit metadata JSON."""
166+
try:
167+
data: dict[str, Any] = json.loads(path.read_text())
168+
except Exception as exc:
169+
print(f"Failed to parse {path}: {exc}", file=sys.stderr)
170+
return None
171+
172+
job_name = data.get("job_name")
173+
return str(job_name) if job_name else None
174+
175+
176+
def _read_job_metadata(log_dir: Path) -> dict[str, Any] | None:
177+
"""Read submit metadata JSON from the output directory when available."""
178+
output_dir = log_dir.parent
179+
for metadata_path in sorted(output_dir.glob("*.json")):
180+
try:
181+
data = json.loads(metadata_path.read_text())
182+
except Exception as exc:
183+
print(f"Failed to parse {metadata_path}: {exc}", file=sys.stderr)
184+
continue
185+
if isinstance(data, dict):
186+
return data
187+
return None
188+
189+
190+
def _read_config_name(log_dir: Path) -> str | None:
191+
"""Read the effective config name, preferring resolved override configs."""
192+
output_dir = log_dir.parent
193+
194+
# For override jobs, submit.py saves the resolved runtime config as config_<suffix>.yaml.
195+
runtime_configs = sorted(output_dir.glob("config_*.yaml"))
196+
for config_path in runtime_configs:
197+
name = _read_yaml_name(config_path)
198+
if name:
199+
return name
200+
201+
# Job metadata also stores the final job/config name.
202+
metadata_files = sorted(output_dir.glob("*.json"))
203+
for metadata_path in metadata_files:
204+
job_name = _read_job_name(metadata_path)
205+
if job_name:
206+
return job_name
207+
208+
# Fall back to the copied source config in logs/.
209+
config_path = log_dir / "config.yaml"
210+
if config_path.exists():
211+
return _read_yaml_name(config_path)
212+
213+
return None
214+
215+
216+
def _is_sglang_disagg(log_dir: Path) -> bool:
    """Return whether the current run is an SGLang disaggregated deployment.

    True only when submit metadata reports the sglang backend with separate
    prefill and decode nodes and no aggregated workers.
    """
    metadata = _read_job_metadata(log_dir)
    if not metadata or metadata.get("backend_type") != "sglang":
        return False

    resources = metadata.get("resources")
    if not isinstance(resources, dict):
        return False

    def _count(key: str) -> int:
        # Treat missing / null / empty values as zero.
        return int(resources.get(key, 0) or 0)

    return _count("prefill_nodes") > 0 and _count("decode_nodes") > 0 and _count("agg_workers") == 0
233+
234+
235+
def _extract_p90_decode_running_requests(log_dir: Path) -> int | None:
    """Stream decode logs and compute the nearest-rank P90 of #running-req values.

    Only applies to SGLang disaggregated runs; returns None otherwise, and
    when no samples are found in any ``*decode*.out`` log. Unreadable logs
    are reported to stderr and skipped.
    """
    if not _is_sglang_disagg(log_dir):
        return None

    histogram: Counter[int] = Counter()
    samples = 0

    for log_path in sorted(log_dir.glob("*decode*.out")):
        try:
            with log_path.open("r", errors="replace") as handle:
                for text_line in handle:
                    hit = RUNNING_REQ_PATTERN.search(text_line)
                    if hit:
                        histogram[int(hit.group(1))] += 1
                        samples += 1
        except OSError as exc:
            print(f"Failed to read {log_path}: {exc}", file=sys.stderr)

    if not samples:
        return None

    # Nearest-rank P90: smallest value whose cumulative count reaches the rank.
    threshold = math.ceil(samples * 0.9)
    seen = 0
    for candidate in sorted(histogram):
        seen += histogram[candidate]
        if seen >= threshold:
            return candidate

    return None
267+
268+
269+
def _safe_ratio(numerator: float | int | None, denominator: float | int | None) -> float | None:
270+
"""Return numerator / denominator when both values are valid and denominator != 0."""
271+
if numerator is None or denominator in (None, 0):
272+
return None
273+
return float(numerator) / float(denominator)
274+
275+
276+
def _format_csv_value(value: object) -> str:
277+
"""Format CSV values with at most three decimal places for numeric fields."""
278+
if value is None:
279+
return ""
280+
if isinstance(value, int):
281+
return str(value)
282+
if isinstance(value, float):
283+
return f"{value:.3f}".rstrip("0").rstrip(".")
284+
return str(value)
285+
286+
287+
def _build_csv_row(
    data: dict[str, object],
    config_name: str,
    gpu_num: int | None,
    decode_gpu_count: int | None,
    p90_decode_running_requests: int | None,
) -> dict[str, object]:
    """Build one CSV row from a parsed sa-bench result.

    Derived columns: per-user throughput is 1000 / median TPOT (ms), and
    per-GPU throughput divides total throughput by the total GPU count; both
    fall back to an empty cell when inputs are missing or zero.
    """
    total_throughput = data.get("total_token_throughput")
    tpot_median = data.get("median_tpot_ms")

    raw_row: dict[str, object] = {
        "Config": config_name,
        "Total GPU Count": gpu_num,
        "Decode GPU Count": decode_gpu_count,
        "Concurrency": data.get("max_concurrency"),
        "Total Token Throughput": total_throughput,
        "Output Token Throughput": data.get("output_throughput"),
        "Median TTFT": data.get("median_ttft_ms"),
        "Median TPOT": tpot_median,
        "Median ITL": data.get("median_itl_ms"),
        "P90 Decode Running Requests": p90_decode_running_requests,
        "Output Token Throughput per User": _safe_ratio(1000.0, tpot_median),
        "Total Token Throughput per GPU": _safe_ratio(total_throughput, gpu_num),
    }

    return {column: _format_csv_value(cell) for column, cell in raw_row.items()}
312+
313+
23314
def main(log_dir: Path) -> None:
24-
"""Generate benchmark-rollup.json from sa-bench result files."""
315+
"""Generate benchmark-rollup.json and benchmark-rollup.csv from sa-bench result files."""
25316
result_files = sorted(log_dir.glob("sa-bench_*/results_*.json"))
26317
if not result_files:
27318
print("No sa-bench results found", file=sys.stderr)
28319
return
29320

30321
runs = []
322+
csv_rows = []
31323
config = {}
324+
config_name = _read_config_name(log_dir)
325+
first_gpu_num = _extract_gpu_num(result_files[0]) if result_files else None
326+
decode_gpu_count = _extract_decode_gpu_count(log_dir, first_gpu_num)
327+
p90_decode_running_requests = _extract_p90_decode_running_requests(log_dir)
32328

33-
for f in result_files:
329+
for result_file in result_files:
34330
try:
35-
data = json.loads(f.read_text())
36-
except json.JSONDecodeError as e:
37-
print(f"Failed to parse {f}: {e}", file=sys.stderr)
331+
data = json.loads(result_file.read_text())
332+
except json.JSONDecodeError as exc:
333+
print(f"Failed to parse {result_file}: {exc}", file=sys.stderr)
38334
continue
39335

40-
# Extract config from first file
41336
if not config:
42337
config = {
43338
"model": data.get("model_id"),
@@ -61,16 +356,38 @@ def main(log_dir: Path) -> None:
61356
"total_output_tokens": data.get("total_output"),
62357
})
63358

359+
csv_rows.append(
360+
_build_csv_row(
361+
data=data,
362+
config_name=config_name or str(data.get("model_id") or "unknown"),
363+
gpu_num=_extract_gpu_num(result_file),
364+
decode_gpu_count=decode_gpu_count,
365+
p90_decode_running_requests=p90_decode_running_requests,
366+
)
367+
)
368+
369+
if not runs:
370+
print("No valid sa-bench results found", file=sys.stderr)
371+
return
372+
64373
rollup = {
65374
"benchmark_type": "sa-bench",
66375
"timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
67376
"config": config,
68377
"runs": runs,
69378
}
70379

71-
output_path = log_dir / "benchmark-rollup.json"
72-
output_path.write_text(json.dumps(rollup, indent=2))
73-
print(f"Wrote {output_path}")
380+
json_path = log_dir / "benchmark-rollup.json"
381+
json_path.write_text(json.dumps(rollup, indent=2))
382+
print(f"Wrote {json_path}")
383+
384+
csv_rows.sort(key=lambda row: int(row["Concurrency"]) if row["Concurrency"] else -1)
385+
csv_path = log_dir / "benchmark-rollup.csv"
386+
with csv_path.open("w", newline="") as csv_file:
387+
writer = csv.DictWriter(csv_file, fieldnames=OUTPUT_FIELDS)
388+
writer.writeheader()
389+
writer.writerows(csv_rows)
390+
print(f"Wrote {csv_path}")
74391

75392

76393
if __name__ == "__main__":

0 commit comments

Comments
 (0)