Skip to content

Commit dc2328e

Browse files
author
Weiliangl User
committed
Add CSV export for sa-bench rollup
1 parent 31c3e59 commit dc2328e

2 files changed

Lines changed: 482 additions & 10 deletions

File tree

src/srtctl/benchmarks/scripts/sa-bench/rollup.py

Lines changed: 316 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,47 @@
22
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44

5-
"""Generate benchmark-rollup.json from sa-bench results."""
5+
"""Generate benchmark-rollup.json and benchmark-rollup.csv from sa-bench results."""
66

7+
from __future__ import annotations
8+
9+
import csv
710
import json
11+
from collections import Counter
12+
import math
13+
import re
814
import sys
915
from datetime import datetime, timezone
1016
from pathlib import Path
17+
from typing import Any, NamedTuple
18+
19+
import yaml
20+
21+
22+
# Column order for benchmark-rollup.csv. _build_csv_row produces dicts keyed
# by exactly these names, and csv.DictWriter uses this list as the header row.
OUTPUT_FIELDS = [
    "Config",
    "Total GPU Count",
    "Decode GPU Count",
    "Concurrency",
    "Total Token Throughput",
    "Output Token Throughput",
    "Median TTFT",
    "Median TPOT",
    "Median ITL",
    "P90 Decode Running Requests",
    "Output Token Throughput per User",
    "Total Token Throughput per GPU",
]

# Matches "#running-req: <n>" samples in decode worker logs; the captured
# integer feeds the P90 running-requests computation.
RUNNING_REQ_PATTERN = re.compile(r"#running-req:\s*(\d+)")
38+
39+
40+
class RollupContext(NamedTuple):
    """Resolved runtime context shared by CSV enrichment helpers."""

    # Run/config name, taken from the runtime config's "name" or, failing
    # that, the job metadata's "job_name"; None when neither is available.
    config_name: str | None
    # Resource/topology settings (node and worker counts, GPUs per node, ...)
    # from job metadata or the runtime config; None when unavailable.
    resources: dict[str, Any] | None
    # Serving backend identifier (e.g. "sglang"); None when undetermined.
    backend_type: str | None
1146

1247

1348
def _get_percentile(percentiles: list, target: float) -> float | None:
@@ -20,24 +55,273 @@ def _get_percentile(percentiles: list, target: float) -> float | None:
2055
return None
2156

2257

58+
def _read_yaml_dict(path: Path) -> dict[str, Any] | None:
    """Parse *path* as YAML and return the top-level mapping.

    An empty document yields an empty dict. Returns None when reading or
    parsing fails (logged to stderr) or when the document root is not a
    mapping.
    """
    try:
        document = yaml.safe_load(path.read_text()) or {}
    except Exception as exc:  # any read/parse failure is non-fatal here
        print(f"Failed to parse {path}: {exc}", file=sys.stderr)
        return None
    if not isinstance(document, dict):
        return None
    return document
67+
68+
69+
def _read_json_dict(path: Path) -> dict[str, Any] | None:
70+
"""Read a JSON file into a dictionary."""
71+
try:
72+
data = json.loads(path.read_text())
73+
except Exception as exc:
74+
print(f"Failed to parse {path}: {exc}", file=sys.stderr)
75+
return None
76+
77+
return data if isinstance(data, dict) else None
78+
79+
80+
def _read_runtime_config(log_dir: Path) -> dict[str, Any] | None:
    """Locate and parse the resolved runtime config for a run.

    Override-expanded ``config_*.yaml`` files in the parent output directory
    take precedence (first parseable, non-empty one in sorted order wins);
    otherwise falls back to ``config.yaml`` inside *log_dir* itself.
    """
    parent_dir = log_dir.parent
    for candidate in sorted(parent_dir.glob("config_*.yaml")):
        parsed = _read_yaml_dict(candidate)
        if parsed:
            return parsed

    fallback = log_dir / "config.yaml"
    return _read_yaml_dict(fallback) if fallback.exists() else None
95+
96+
97+
def _read_job_metadata(log_dir: Path) -> dict[str, Any] | None:
    """Return submit metadata from the output directory when available.

    Scans ``*.json`` files in the parent of *log_dir* in sorted order and
    returns the first that parses to a non-empty JSON object, else None.
    """
    for metadata_file in sorted(log_dir.parent.glob("*.json")):
        parsed = _read_json_dict(metadata_file)
        if parsed:
            return parsed
    return None
105+
106+
107+
def _load_rollup_context(log_dir: Path) -> RollupContext:
    """Resolve config name, resources, and backend type once for downstream helpers."""

    def pick_str(mapping: dict[str, Any] | None, key: str) -> str | None:
        # Return mapping[key] only when it is a non-empty string.
        if not mapping:
            return None
        candidate = mapping.get(key)
        if isinstance(candidate, str) and candidate:
            return candidate
        return None

    def pick_dict(mapping: dict[str, Any] | None, key: str) -> dict[str, Any] | None:
        # Return mapping[key] only when it is a dict.
        if not mapping:
            return None
        candidate = mapping.get(key)
        return candidate if isinstance(candidate, dict) else None

    runtime_config = _read_runtime_config(log_dir)
    metadata = _read_job_metadata(log_dir)

    # Name: the resolved config's "name" wins over the metadata's "job_name".
    config_name = pick_str(runtime_config, "name") or pick_str(metadata, "job_name")

    # Resources: submit metadata wins over the runtime config.
    resources = pick_dict(metadata, "resources")
    if resources is None:
        resources = pick_dict(runtime_config, "resources")

    # Backend type: metadata first, then the runtime config's backend section;
    # a backend with an "sglang_config" key but no explicit type is sglang.
    backend_type = pick_str(metadata, "backend_type")
    if backend_type is None:
        backend = pick_dict(runtime_config, "backend")
        if backend is not None:
            backend_type = pick_str(backend, "type")
            if backend_type is None and "sglang_config" in backend:
                backend_type = "sglang"

    return RollupContext(
        config_name=config_name,
        resources=resources,
        backend_type=backend_type,
    )
151+
152+
153+
def _compute_total_gpu_count(resources: dict[str, Any]) -> int | None:
154+
"""Compute total GPU count from resources using the same topology semantics as the config."""
155+
gpus_per_node_raw = resources.get("gpus_per_node")
156+
if gpus_per_node_raw in (None, 0):
157+
return None
158+
gpus_per_node = int(gpus_per_node_raw)
159+
160+
prefill_nodes = int(resources.get("prefill_nodes", 0) or 0)
161+
decode_nodes = int(resources.get("decode_nodes", 0) or 0)
162+
if prefill_nodes or decode_nodes:
163+
return (prefill_nodes + decode_nodes) * gpus_per_node
164+
165+
agg_nodes = int(resources.get("agg_nodes", 0) or 0)
166+
if agg_nodes:
167+
return agg_nodes * gpus_per_node
168+
169+
return gpus_per_node
170+
171+
172+
def _compute_gpu_counts(resources: dict[str, Any]) -> tuple[int | None, int | None]:
    """Derive ``(total_gpu_count, decode_gpu_count)`` from resource settings.

    Either element may be None when the resource map does not carry enough
    topology information to pin it down.
    """
    total = _compute_total_gpu_count(resources)

    workers = int(resources.get("decode_workers", 0) or 0)
    raw_decode_nodes = resources.get("decode_nodes")
    decode_nodes = None if raw_decode_nodes in (None, "") else int(raw_decode_nodes)

    # 1) An explicit per-decode-worker GPU figure wins outright.
    explicit_decode = resources.get("gpus_per_decode")
    if explicit_decode not in (None, 0):
        per_decode = int(explicit_decode)
        if workers:
            return total, workers * per_decode
        return total, per_decode

    # 2) Otherwise gpus_per_node is needed; infer it from the total when
    #    absent and the total divides evenly over the node count.
    raw_per_node = resources.get("gpus_per_node")
    per_node = None if raw_per_node in (None, 0) else int(raw_per_node)
    if per_node is None and total not in (None, 0):
        node_count = int(resources.get("prefill_nodes", 0) or 0) + (decode_nodes or 0)
        if node_count > 0 and total % node_count == 0:
            per_node = total // node_count

    # 3) Dedicated decode nodes: split their GPUs across decode workers
    #    (floor division), or take all decode-node GPUs when worker count
    #    is unknown.
    if decode_nodes not in (None, 0) and per_node not in (None, 0):
        if workers:
            return total, workers * ((decode_nodes * per_node) // workers)
        return total, decode_nodes * per_node

    # 4) Colocated decode workers (decode_nodes explicitly 0): size each
    #    decode worker like a prefill worker.
    if decode_nodes == 0 and workers:
        explicit_prefill = resources.get("gpus_per_prefill")
        if explicit_prefill not in (None, 0):
            per_prefill = int(explicit_prefill)
        else:
            prefill_nodes = resources.get("prefill_nodes")
            prefill_workers = resources.get("prefill_workers")
            if (
                prefill_nodes not in (None, 0)
                and prefill_workers not in (None, 0)
                and per_node not in (None, 0)
            ):
                per_prefill = (int(prefill_nodes) * int(per_node)) // int(prefill_workers)
            else:
                per_prefill = per_node
        if per_prefill not in (None, 0):
            return total, workers * per_prefill

    # Not enough information to size the decode side.
    return total, None
216+
217+
218+
def _extract_p90_decode_running_requests(log_dir: Path, context: RollupContext) -> int | None:
    """Stream decode logs and compute the nearest-rank P90 of #running-req values.

    Only applies to disaggregated sglang runs: requires backend "sglang",
    positive prefill and decode node counts, and zero agg_workers. Returns
    None when those preconditions fail or no samples are found.
    """
    resources = context.resources
    if context.backend_type != "sglang" or not isinstance(resources, dict):
        return None
    if int(resources.get("prefill_nodes", 0) or 0) <= 0:
        return None
    if int(resources.get("decode_nodes", 0) or 0) <= 0:
        return None
    if int(resources.get("agg_workers", 0) or 0) != 0:
        return None

    # Histogram of observed running-request counts; keeps memory bounded
    # even for very large log files.
    histogram: Counter[int] = Counter()
    sample_count = 0

    for log_path in sorted(log_dir.glob("*decode*.out")):
        try:
            with log_path.open("r", errors="replace") as handle:
                for raw_line in handle:
                    found = RUNNING_REQ_PATTERN.search(raw_line)
                    if found is None:
                        continue
                    histogram[int(found.group(1))] += 1
                    sample_count += 1
        except OSError as exc:
            print(f"Failed to read {log_path}: {exc}", file=sys.stderr)

    if sample_count == 0:
        return None

    # Nearest-rank percentile: smallest value whose cumulative frequency
    # reaches ceil(0.9 * N).
    threshold = math.ceil(sample_count * 0.9)
    seen = 0
    for sample in sorted(histogram):
        seen += histogram[sample]
        if seen >= threshold:
            return sample

    return None
257+
258+
259+
def _safe_ratio(numerator: float | int | None, denominator: float | int | None) -> float | None:
260+
"""Return numerator / denominator when both values are valid and denominator != 0."""
261+
if numerator is None or denominator in (None, 0):
262+
return None
263+
return float(numerator) / float(denominator)
264+
265+
266+
def _format_csv_value(value: object) -> str:
267+
"""Format CSV values with at most three decimal places for numeric fields."""
268+
if value is None:
269+
return ""
270+
if isinstance(value, int):
271+
return str(value)
272+
if isinstance(value, float):
273+
return f"{value:.3f}".rstrip("0").rstrip(".")
274+
return str(value)
275+
276+
277+
def _build_csv_row(
    data: dict[str, object],
    config_name: str,
    gpu_num: int | None,
    decode_gpu_count: int | None,
    p90_decode_running_requests: int | None,
) -> dict[str, object]:
    """Assemble one formatted CSV row from a parsed sa-bench result dict."""
    throughput_total = data.get("total_token_throughput")
    tpot_median = data.get("median_tpot_ms")
    raw_row: dict[str, object] = {
        "Config": config_name,
        "Total GPU Count": gpu_num,
        "Decode GPU Count": decode_gpu_count,
        "Concurrency": data.get("max_concurrency"),
        "Total Token Throughput": throughput_total,
        "Output Token Throughput": data.get("output_throughput"),
        "Median TTFT": data.get("median_ttft_ms"),
        "Median TPOT": tpot_median,
        "Median ITL": data.get("median_itl_ms"),
        "P90 Decode Running Requests": p90_decode_running_requests,
        # One token per TPOT interval: 1000 ms / median TPOT (ms) = tokens/s/user.
        "Output Token Throughput per User": _safe_ratio(1000.0, tpot_median),
        "Total Token Throughput per GPU": _safe_ratio(throughput_total, gpu_num),
    }
    return {field: _format_csv_value(cell) for field, cell in raw_row.items()}
302+
303+
23304
def main(log_dir: Path) -> None:
24-
"""Generate benchmark-rollup.json from sa-bench result files."""
305+
"""Generate benchmark-rollup.json and benchmark-rollup.csv from sa-bench result files."""
25306
result_files = sorted(log_dir.glob("sa-bench_*/results_*.json"))
26307
if not result_files:
27308
print("No sa-bench results found", file=sys.stderr)
28309
return
29310

30311
runs = []
312+
csv_rows = []
31313
config = {}
314+
context = _load_rollup_context(log_dir)
315+
total_gpu_count, decode_gpu_count = _compute_gpu_counts(context.resources) if context.resources else (None, None)
316+
p90_decode_running_requests = _extract_p90_decode_running_requests(log_dir, context)
32317

33-
for f in result_files:
318+
for result_file in result_files:
34319
try:
35-
data = json.loads(f.read_text())
36-
except json.JSONDecodeError as e:
37-
print(f"Failed to parse {f}: {e}", file=sys.stderr)
320+
data = json.loads(result_file.read_text())
321+
except json.JSONDecodeError as exc:
322+
print(f"Failed to parse {result_file}: {exc}", file=sys.stderr)
38323
continue
39324

40-
# Extract config from first file
41325
if not config:
42326
config = {
43327
"model": data.get("model_id"),
@@ -61,16 +345,38 @@ def main(log_dir: Path) -> None:
61345
"total_output_tokens": data.get("total_output"),
62346
})
63347

348+
csv_rows.append(
349+
_build_csv_row(
350+
data=data,
351+
config_name=context.config_name or str(data.get("model_id") or "unknown"),
352+
gpu_num=total_gpu_count,
353+
decode_gpu_count=decode_gpu_count,
354+
p90_decode_running_requests=p90_decode_running_requests,
355+
)
356+
)
357+
358+
if not runs:
359+
print("No valid sa-bench results found", file=sys.stderr)
360+
return
361+
64362
rollup = {
65363
"benchmark_type": "sa-bench",
66364
"timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
67365
"config": config,
68366
"runs": runs,
69367
}
70368

71-
output_path = log_dir / "benchmark-rollup.json"
72-
output_path.write_text(json.dumps(rollup, indent=2))
73-
print(f"Wrote {output_path}")
369+
json_path = log_dir / "benchmark-rollup.json"
370+
json_path.write_text(json.dumps(rollup, indent=2))
371+
print(f"Wrote {json_path}")
372+
373+
csv_rows.sort(key=lambda row: int(row["Concurrency"]) if row["Concurrency"] else -1)
374+
csv_path = log_dir / "benchmark-rollup.csv"
375+
with csv_path.open("w", newline="") as csv_file:
376+
writer = csv.DictWriter(csv_file, fieldnames=OUTPUT_FIELDS)
377+
writer.writeheader()
378+
writer.writerows(csv_rows)
379+
print(f"Wrote {csv_path}")
74380

75381

76382
if __name__ == "__main__":

0 commit comments

Comments
 (0)