
Commit b6587d7

Author: Weiliangl User (committed)
Add CSV export for sa-bench rollup
1 parent 31c3e59 commit b6587d7

2 files changed: 359 additions & 10 deletions

File tree

src/srtctl/benchmarks/scripts/sa-bench/rollup.py (190 additions & 10 deletions)
@@ -2,12 +2,37 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-"""Generate benchmark-rollup.json from sa-bench results."""
+"""Generate benchmark-rollup.json and benchmark-rollup.csv from sa-bench results."""
 
+from __future__ import annotations
+
+import csv
 import json
+from collections import Counter
+import math
+import re
 import sys
 from datetime import datetime, timezone
 from pathlib import Path
+from typing import Any
+
+
+OUTPUT_FIELDS = [
+    "Config",
+    "Total GPU Count",
+    "Decode GPU Count",
+    "Concurrency",
+    "Total Token Throughput",
+    "Output Token Throughput",
+    "Median TTFT",
+    "Median TPOT",
+    "Median ITL",
+    "P90 Decode Running Requests",
+    "Output Token Throughput per User",
+    "Total Token Throughput per GPU",
+]
+
+RUNNING_REQ_PATTERN = re.compile(r"#running-req:\s*(\d+)")
 
 
 def _get_percentile(percentiles: list, target: float) -> float | None:
@@ -20,24 +45,157 @@ def _get_percentile(percentiles: list, target: float) -> float | None:
     return None
 
 
+def _read_job_metadata(log_dir: Path) -> dict[str, Any] | None:
+    """Read submit metadata JSON from the output directory when available."""
+    output_dir = log_dir.parent
+    for metadata_path in sorted(output_dir.glob("*.json")):
+        try:
+            data = json.loads(metadata_path.read_text())
+        except Exception as exc:
+            print(f"Failed to parse {metadata_path}: {exc}", file=sys.stderr)
+            continue
+        if data:
+            return data
+    return None
+
+
+def _compute_gpu_counts(resources: dict[str, Any]) -> tuple[int | None, int | None]:
+    """Compute total and decode GPU counts from resource settings."""
+    gpus_per_node = int(resources.get("gpus_per_node", 0))
+    prefill_nodes = int(resources.get("prefill_nodes", 0))
+    decode_nodes = int(resources.get("decode_nodes", 0))
+    agg_nodes = int(resources.get("agg_nodes", 0))
+    if gpus_per_node <= 0:
+        return None, None
+
+    if prefill_nodes > 0 or decode_nodes > 0:
+        total_gpu_count = (prefill_nodes + decode_nodes) * gpus_per_node
+    elif agg_nodes > 0:
+        total_gpu_count = agg_nodes * gpus_per_node
+    else:
+        total_gpu_count = gpus_per_node
+
+    if decode_nodes > 0:
+        return total_gpu_count, decode_nodes * gpus_per_node
+
+    decode_workers = int(resources.get("decode_workers", 0))
+    gpus_per_decode = int(resources.get("gpus_per_decode", 0))
+    if decode_workers > 0 and gpus_per_decode > 0:
+        return total_gpu_count, decode_workers * gpus_per_decode
+
+    return total_gpu_count, None
+
+
+def _extract_p90_decode_running_requests(log_dir: Path, metadata: dict[str, Any] | None) -> int | None:
+    """Stream decode logs and compute the nearest-rank P90 of #running-req values."""
+    if not metadata or metadata.get("backend_type") != "sglang":
+        return None
+
+    resources = metadata.get("resources")
+    if resources is None:
+        return None
+    if not (int(resources.get("prefill_nodes", 0)) > 0 and int(resources.get("decode_nodes", 0)) > 0):
+        return None
+    if int(resources.get("agg_workers", 0)) > 0:
+        return None
+
+    counts: Counter[int] = Counter()
+    total = 0
+
+    for decode_log in sorted(log_dir.glob("*decode*.out")):
+        try:
+            with decode_log.open("r", errors="replace") as f:
+                for line in f:
+                    match = RUNNING_REQ_PATTERN.search(line)
+                    if not match:
+                        continue
+                    value = int(match.group(1))
+                    counts[value] += 1
+                    total += 1
+        except OSError as exc:
+            print(f"Failed to read {decode_log}: {exc}", file=sys.stderr)
+
+    if total == 0:
+        return None
+
+    rank = math.ceil(total * 0.9)
+    cumulative = 0
+    for value in sorted(counts):
+        cumulative += counts[value]
+        if cumulative >= rank:
+            return value
+
+    return None
+
+
+def _safe_ratio(numerator: float | int | None, denominator: float | int | None) -> float | None:
+    """Return numerator / denominator when both values are valid and denominator != 0."""
+    if numerator is None or denominator in (None, 0):
+        return None
+    return float(numerator) / float(denominator)
+
+
+def _format_csv_value(value: object) -> str:
+    """Format CSV values with at most three decimal places for numeric fields."""
+    if value is None:
+        return ""
+    if isinstance(value, int):
+        return str(value)
+    if isinstance(value, float):
+        return f"{value:.3f}".rstrip("0").rstrip(".")
+    return str(value)
+
+
+def _build_csv_row(
+    data: dict[str, object],
+    config_name: str,
+    gpu_num: int | None,
+    decode_gpu_count: int | None,
+    p90_decode_running_requests: int | None,
+) -> dict[str, object]:
+    """Build one CSV row from a parsed sa-bench result."""
+    total_token_throughput = data.get("total_token_throughput")
+    median_tpot = data.get("median_tpot_ms")
+    row = {
+        "Config": config_name,
+        "Total GPU Count": gpu_num,
+        "Decode GPU Count": decode_gpu_count,
+        "Concurrency": data.get("max_concurrency"),
+        "Total Token Throughput": total_token_throughput,
+        "Output Token Throughput": data.get("output_throughput"),
+        "Median TTFT": data.get("median_ttft_ms"),
+        "Median TPOT": median_tpot,
+        "Median ITL": data.get("median_itl_ms"),
+        "P90 Decode Running Requests": p90_decode_running_requests,
+        "Output Token Throughput per User": _safe_ratio(1000.0, median_tpot),
+        "Total Token Throughput per GPU": _safe_ratio(total_token_throughput, gpu_num),
+    }
+    return {key: _format_csv_value(value) for key, value in row.items()}
+
+
 def main(log_dir: Path) -> None:
-    """Generate benchmark-rollup.json from sa-bench result files."""
+    """Generate benchmark-rollup.json and benchmark-rollup.csv from sa-bench result files."""
     result_files = sorted(log_dir.glob("sa-bench_*/results_*.json"))
     if not result_files:
         print("No sa-bench results found", file=sys.stderr)
         return
 
     runs = []
+    csv_rows = []
     config = {}
+    metadata = _read_job_metadata(log_dir)
+    config_name = metadata.get("job_name") if metadata else None
+    resources = metadata.get("resources") if metadata else None
+    total_gpu_count, decode_gpu_count = _compute_gpu_counts(resources) if resources else (None, None)
+    p90_decode_running_requests = _extract_p90_decode_running_requests(log_dir, metadata)
 
-    for f in result_files:
+    for result_file in result_files:
         try:
-            data = json.loads(f.read_text())
-        except json.JSONDecodeError as e:
-            print(f"Failed to parse {f}: {e}", file=sys.stderr)
+            data = json.loads(result_file.read_text())
+        except json.JSONDecodeError as exc:
+            print(f"Failed to parse {result_file}: {exc}", file=sys.stderr)
             continue
 
-        # Extract config from first file
         if not config:
             config = {
                 "model": data.get("model_id"),
@@ -61,16 +219,38 @@ def main(log_dir: Path) -> None:
             "total_output_tokens": data.get("total_output"),
         })
 
+        csv_rows.append(
+            _build_csv_row(
+                data=data,
+                config_name=config_name or str(data.get("model_id") or "unknown"),
+                gpu_num=total_gpu_count,
+                decode_gpu_count=decode_gpu_count,
+                p90_decode_running_requests=p90_decode_running_requests,
+            )
+        )
+
+    if not runs:
+        print("No valid sa-bench results found", file=sys.stderr)
+        return
+
     rollup = {
         "benchmark_type": "sa-bench",
         "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "config": config,
         "runs": runs,
     }
 
-    output_path = log_dir / "benchmark-rollup.json"
-    output_path.write_text(json.dumps(rollup, indent=2))
-    print(f"Wrote {output_path}")
+    json_path = log_dir / "benchmark-rollup.json"
+    json_path.write_text(json.dumps(rollup, indent=2))
+    print(f"Wrote {json_path}")
+
+    csv_rows.sort(key=lambda row: int(row["Concurrency"]) if row["Concurrency"] else -1)
+    csv_path = log_dir / "benchmark-rollup.csv"
+    with csv_path.open("w", newline="") as csv_file:
+        writer = csv.DictWriter(csv_file, fieldnames=OUTPUT_FIELDS)
+        writer.writeheader()
+        writer.writerows(csv_rows)
+    print(f"Wrote {csv_path}")
 
 
 if __name__ == "__main__":