Skip to content

Commit 7b4b95e

Browse files
author
Weiliangl User
committed
Add CSV export for sa-bench rollup
1 parent 31c3e59 commit 7b4b95e

2 files changed

Lines changed: 471 additions & 10 deletions

File tree

src/srtctl/benchmarks/scripts/sa-bench/rollup.py

Lines changed: 305 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,47 @@
22
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44

5-
"""Generate benchmark-rollup.json from sa-bench results."""
5+
"""Generate benchmark-rollup.json and benchmark-rollup.csv from sa-bench results."""
66

7+
from __future__ import annotations
8+
9+
import csv
710
import json
11+
from collections import Counter
12+
import math
13+
import re
814
import sys
915
from datetime import datetime, timezone
1016
from pathlib import Path
17+
from typing import Any, NamedTuple
18+
19+
import yaml
20+
21+
22+
# Column order for benchmark-rollup.csv; must match the keys produced by
# _build_csv_row.
OUTPUT_FIELDS = [
    "Config",
    "Total GPU Count",
    "Decode GPU Count",
    "Concurrency",
    "Total Token Throughput",
    "Output Token Throughput",
    "Median TTFT",
    "Median TPOT",
    "Median ITL",
    "P90 Decode Running Requests",
    "Output Token Throughput per User",
    "Total Token Throughput per GPU",
]

# Matches "#running-req: <count>" entries in decode worker logs; captures the
# integer count (consumed by _extract_p90_decode_running_requests).
RUNNING_REQ_PATTERN = re.compile(r"#running-req:\s*(\d+)")
38+
39+
40+
class RollupContext(NamedTuple):
    """Resolved runtime context shared by CSV enrichment helpers."""

    # Display name: runtime config "name", falling back to metadata "job_name".
    config_name: str | None
    # Resource settings (node/worker/GPU counts) from job metadata or runtime config.
    resources: dict[str, Any] | None
    # Inference backend type (e.g. "sglang") when it can be determined.
    backend_type: str | None
1146

1247

1348
def _get_percentile(percentiles: list, target: float) -> float | None:
@@ -20,24 +55,262 @@ def _get_percentile(percentiles: list, target: float) -> float | None:
2055
return None
2156

2257

58+
def _read_yaml_dict(path: Path) -> dict[str, Any] | None:
    """Read a YAML file into a dictionary.

    Returns ``None`` when the file cannot be read/parsed or when the parsed
    document is not a mapping. An empty document counts as an empty dict.
    """
    try:
        loaded = yaml.safe_load(path.read_text()) or {}
    except Exception as exc:  # best-effort: report and skip unreadable configs
        print(f"Failed to parse {path}: {exc}", file=sys.stderr)
        return None
    if isinstance(loaded, dict):
        return loaded
    return None
67+
68+
69+
def _read_json_dict(path: Path) -> dict[str, Any] | None:
70+
"""Read a JSON file into a dictionary."""
71+
try:
72+
data = json.loads(path.read_text())
73+
except Exception as exc:
74+
print(f"Failed to parse {path}: {exc}", file=sys.stderr)
75+
return None
76+
77+
return data if isinstance(data, dict) else None
78+
79+
80+
def _read_runtime_config(log_dir: Path) -> dict[str, Any] | None:
    """Read resolved runtime config, preferring override-expanded configs.

    Looks for ``config_*.yaml`` files in the parent of *log_dir* first
    (override-expanded configs), then falls back to ``log_dir/config.yaml``.
    """
    output_dir = log_dir.parent

    # First non-empty override-expanded config wins.
    for candidate in sorted(output_dir.glob("config_*.yaml")):
        parsed = _read_yaml_dict(candidate)
        if parsed:
            return parsed

    fallback = log_dir / "config.yaml"
    if fallback.exists():
        return _read_yaml_dict(fallback)
    return None
95+
96+
97+
def _read_job_metadata(log_dir: Path) -> dict[str, Any] | None:
    """Read submit metadata JSON from the output directory when available.

    Scans ``*.json`` files in the parent of *log_dir* (sorted for
    determinism) and returns the first one that parses to a non-empty dict.
    """
    for candidate in sorted(log_dir.parent.glob("*.json")):
        parsed = _read_json_dict(candidate)
        if parsed:
            return parsed
    return None
105+
106+
107+
def _load_rollup_context(log_dir: Path) -> RollupContext:
    """Load config name, resources, and backend type once for downstream helpers.

    Each field is resolved with its own preference order:
    - config_name: runtime config "name", then metadata "job_name".
    - resources: metadata "resources", then runtime config "resources".
    - backend_type: metadata "backend_type", then backend.type in the runtime
      config, with "sglang" inferred from a "sglang_config" key.
    """
    runtime_config = _read_runtime_config(log_dir)
    metadata = _read_job_metadata(log_dir)

    def _nonempty_str(value: object) -> str | None:
        # Accept only non-empty strings; everything else is treated as unset.
        return value if isinstance(value, str) and value else None

    config_name = _nonempty_str(runtime_config.get("name")) if runtime_config else None
    if config_name is None and metadata:
        config_name = _nonempty_str(metadata.get("job_name"))

    resources = None
    if metadata and isinstance(metadata.get("resources"), dict):
        resources = metadata["resources"]
    if resources is None and runtime_config and isinstance(runtime_config.get("resources"), dict):
        resources = runtime_config["resources"]

    backend_type = _nonempty_str(metadata.get("backend_type")) if metadata else None
    if backend_type is None and runtime_config:
        backend = runtime_config.get("backend")
        if isinstance(backend, dict):
            backend_type = _nonempty_str(backend.get("type"))
            if backend_type is None and "sglang_config" in backend:
                backend_type = "sglang"

    return RollupContext(
        config_name=config_name,
        resources=resources,
        backend_type=backend_type,
    )
151+
152+
153+
def _compute_gpu_counts(resources: dict[str, Any]) -> tuple[int | None, int | None]:
    """Compute total and decode GPU counts from resource settings.

    Returns a ``(total_gpu_count, decode_gpu_count)`` pair; either element is
    ``None`` when it cannot be derived from the available resource keys.
    """
    decode_workers = int(resources.get("decode_workers", 0) or 0)
    # Explicit per-decode-worker GPU count, if the config provides one.
    explicit = resources.get("gpus_per_decode")

    gpus_per_node_raw = resources.get("gpus_per_node")
    gpus_per_node = int(gpus_per_node_raw) if gpus_per_node_raw not in (None, 0) else None
    prefill_nodes = int(resources.get("prefill_nodes", 0) or 0)
    decode_nodes_raw = resources.get("decode_nodes")
    # Preserve the None/""-vs-0 distinction: 0 means "explicitly no decode nodes".
    decode_nodes = int(decode_nodes_raw) if decode_nodes_raw not in (None, "") else None
    agg_nodes = int(resources.get("agg_nodes", 0) or 0)

    # Total GPUs: disaggregated (prefill + decode) node counts win over
    # aggregated node counts; a bare gpus_per_node implies a single node.
    total_gpu_count = None
    if gpus_per_node is not None:
        if prefill_nodes or decode_nodes:
            total_gpu_count = (prefill_nodes + (decode_nodes or 0)) * gpus_per_node
        elif agg_nodes:
            total_gpu_count = agg_nodes * gpus_per_node
        else:
            total_gpu_count = gpus_per_node

    # An explicit gpus_per_decode short-circuits all inference below.
    if explicit not in (None, 0):
        gpus_per_decode = int(explicit)
        decode_gpu_count = decode_workers * gpus_per_decode if decode_workers else gpus_per_decode
        return total_gpu_count, decode_gpu_count

    # NOTE(review): this branch looks unreachable — total_gpu_count is only
    # assigned above when gpus_per_node is not None. Confirm before relying on it.
    if gpus_per_node is None and total_gpu_count not in (None, 0):
        total_nodes = prefill_nodes + (decode_nodes or 0)
        if total_nodes > 0 and total_gpu_count % total_nodes == 0:
            gpus_per_node = total_gpu_count // total_nodes

    # Decode GPUs from the decode node count; when several decode workers share
    # the decode nodes, floor-divide GPUs across workers first.
    if decode_nodes not in (None, 0) and gpus_per_node not in (None, 0):
        if decode_workers:
            gpus_per_decode = (decode_nodes * gpus_per_node) // decode_workers
            return total_gpu_count, decode_workers * gpus_per_decode
        return total_gpu_count, decode_nodes * gpus_per_node

    # decode_nodes explicitly 0 with decode workers present: size decode
    # workers like prefill workers (explicit gpus_per_prefill, derived
    # prefill split, or whole-node fallback).
    if decode_nodes == 0 and decode_workers:
        explicit_prefill = resources.get("gpus_per_prefill")
        if explicit_prefill not in (None, 0):
            gpus_per_prefill = int(explicit_prefill)
        else:
            prefill_nodes = resources.get("prefill_nodes")
            prefill_workers = resources.get("prefill_workers")
            if prefill_nodes not in (None, 0) and prefill_workers not in (None, 0) and gpus_per_node not in (None, 0):
                gpus_per_prefill = (int(prefill_nodes) * int(gpus_per_node)) // int(prefill_workers)
            else:
                # May still be None when gpus_per_node is unknown; guarded below.
                gpus_per_prefill = gpus_per_node
        if gpus_per_prefill not in (None, 0):
            return total_gpu_count, decode_workers * gpus_per_prefill

    return total_gpu_count, None
205+
206+
207+
def _extract_p90_decode_running_requests(log_dir: Path, context: RollupContext) -> int | None:
    """Stream decode logs and compute the nearest-rank P90 of #running-req values.

    Only applies to disaggregated sglang runs (prefill and decode nodes both
    present, no aggregated workers); returns ``None`` otherwise, or when no
    samples are found.
    """
    resources = context.resources
    if context.backend_type != "sglang" or not isinstance(resources, dict):
        return None
    for node_key in ("prefill_nodes", "decode_nodes"):
        if int(resources.get(node_key, 0) or 0) <= 0:
            return None
    if int(resources.get("agg_workers", 0) or 0) != 0:
        return None

    # Histogram of observed counts keeps memory bounded while streaming.
    histogram: Counter[int] = Counter()
    sample_count = 0
    for log_path in sorted(log_dir.glob("*decode*.out")):
        try:
            with log_path.open("r", errors="replace") as handle:
                for raw_line in handle:
                    hit = RUNNING_REQ_PATTERN.search(raw_line)
                    if hit is None:
                        continue
                    histogram[int(hit.group(1))] += 1
                    sample_count += 1
        except OSError as exc:
            print(f"Failed to read {log_path}: {exc}", file=sys.stderr)

    if sample_count == 0:
        return None

    # Nearest-rank percentile: walk the histogram in ascending value order
    # until the cumulative count reaches ceil(0.9 * N).
    target_rank = math.ceil(sample_count * 0.9)
    seen = 0
    for running in sorted(histogram):
        seen += histogram[running]
        if seen >= target_rank:
            return running
    return None
246+
247+
248+
def _safe_ratio(numerator: float | int | None, denominator: float | int | None) -> float | None:
249+
"""Return numerator / denominator when both values are valid and denominator != 0."""
250+
if numerator is None or denominator in (None, 0):
251+
return None
252+
return float(numerator) / float(denominator)
253+
254+
255+
def _format_csv_value(value: object) -> str:
256+
"""Format CSV values with at most three decimal places for numeric fields."""
257+
if value is None:
258+
return ""
259+
if isinstance(value, int):
260+
return str(value)
261+
if isinstance(value, float):
262+
return f"{value:.3f}".rstrip("0").rstrip(".")
263+
return str(value)
264+
265+
266+
def _build_csv_row(
    data: dict[str, object],
    config_name: str,
    gpu_num: int | None,
    decode_gpu_count: int | None,
    p90_decode_running_requests: int | None,
) -> dict[str, object]:
    """Build one CSV row from a parsed sa-bench result.

    Derived columns: per-user throughput is 1000 / median TPOT (ms), and
    per-GPU throughput divides total throughput by the total GPU count. All
    cells are passed through _format_csv_value.
    """
    total_tput = data.get("total_token_throughput")
    tpot_ms = data.get("median_tpot_ms")
    cells = {
        "Config": config_name,
        "Total GPU Count": gpu_num,
        "Decode GPU Count": decode_gpu_count,
        "Concurrency": data.get("max_concurrency"),
        "Total Token Throughput": total_tput,
        "Output Token Throughput": data.get("output_throughput"),
        "Median TTFT": data.get("median_ttft_ms"),
        "Median TPOT": tpot_ms,
        "Median ITL": data.get("median_itl_ms"),
        "P90 Decode Running Requests": p90_decode_running_requests,
        "Output Token Throughput per User": _safe_ratio(1000.0, tpot_ms),
        "Total Token Throughput per GPU": _safe_ratio(total_tput, gpu_num),
    }
    return {column: _format_csv_value(cell) for column, cell in cells.items()}
291+
292+
23293
def main(log_dir: Path) -> None:
24-
"""Generate benchmark-rollup.json from sa-bench result files."""
294+
"""Generate benchmark-rollup.json and benchmark-rollup.csv from sa-bench result files."""
25295
result_files = sorted(log_dir.glob("sa-bench_*/results_*.json"))
26296
if not result_files:
27297
print("No sa-bench results found", file=sys.stderr)
28298
return
29299

30300
runs = []
301+
csv_rows = []
31302
config = {}
303+
context = _load_rollup_context(log_dir)
304+
total_gpu_count, decode_gpu_count = _compute_gpu_counts(context.resources) if context.resources else (None, None)
305+
p90_decode_running_requests = _extract_p90_decode_running_requests(log_dir, context)
32306

33-
for f in result_files:
307+
for result_file in result_files:
34308
try:
35-
data = json.loads(f.read_text())
36-
except json.JSONDecodeError as e:
37-
print(f"Failed to parse {f}: {e}", file=sys.stderr)
309+
data = json.loads(result_file.read_text())
310+
except json.JSONDecodeError as exc:
311+
print(f"Failed to parse {result_file}: {exc}", file=sys.stderr)
38312
continue
39313

40-
# Extract config from first file
41314
if not config:
42315
config = {
43316
"model": data.get("model_id"),
@@ -61,16 +334,38 @@ def main(log_dir: Path) -> None:
61334
"total_output_tokens": data.get("total_output"),
62335
})
63336

337+
csv_rows.append(
338+
_build_csv_row(
339+
data=data,
340+
config_name=context.config_name or str(data.get("model_id") or "unknown"),
341+
gpu_num=total_gpu_count,
342+
decode_gpu_count=decode_gpu_count,
343+
p90_decode_running_requests=p90_decode_running_requests,
344+
)
345+
)
346+
347+
if not runs:
348+
print("No valid sa-bench results found", file=sys.stderr)
349+
return
350+
64351
rollup = {
65352
"benchmark_type": "sa-bench",
66353
"timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
67354
"config": config,
68355
"runs": runs,
69356
}
70357

71-
output_path = log_dir / "benchmark-rollup.json"
72-
output_path.write_text(json.dumps(rollup, indent=2))
73-
print(f"Wrote {output_path}")
358+
json_path = log_dir / "benchmark-rollup.json"
359+
json_path.write_text(json.dumps(rollup, indent=2))
360+
print(f"Wrote {json_path}")
361+
362+
csv_rows.sort(key=lambda row: int(row["Concurrency"]) if row["Concurrency"] else -1)
363+
csv_path = log_dir / "benchmark-rollup.csv"
364+
with csv_path.open("w", newline="") as csv_file:
365+
writer = csv.DictWriter(csv_file, fieldnames=OUTPUT_FIELDS)
366+
writer.writeheader()
367+
writer.writerows(csv_rows)
368+
print(f"Wrote {csv_path}")
74369

75370

76371
if __name__ == "__main__":

0 commit comments

Comments
 (0)