|
| 1 | +""" |
| 2 | +Generate benchmark figures from experiment JSON. |
| 3 | +Usage: python plot_benchmarks.py [path/to/experiment.json] |
| 4 | +Saves figures to docs/figures/. |
| 5 | +""" |
| 6 | + |
| 7 | +import json |
| 8 | +import sys |
| 9 | +import os |
| 10 | +from pathlib import Path |
| 11 | + |
| 12 | +import numpy as np |
| 13 | +import matplotlib |
| 14 | +matplotlib.use("Agg") |
| 15 | +import matplotlib.pyplot as plt |
| 16 | +import matplotlib.patches as mpatches |
| 17 | +from matplotlib.ticker import MultipleLocator |
| 18 | + |
| 19 | +# ── paths ────────────────────────────────────────────────────────────────────── |
# Repository root: two directory levels above the directory containing this
# script (the script's resolved path, third ancestor).
REPO_ROOT = Path(__file__).resolve().parents[2]
# Directory next to this script that is searched for experiment_*.json files
# when no explicit path is supplied.
RESULTS_DIR = Path(__file__).parent / "results"
# Destination for every generated figure. NOTE: the directory is created as a
# side effect of importing/running this module.
FIGURES_DIR = REPO_ROOT / "docs" / "figures"
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

# Hex colors keyed by backend name ("kvboost", "vllm_prefixcache", "baseline")
# and by series name ("cold", "warm", "neutral").
COLORS = {
    "kvboost": "#4C72B0",
    "vllm_prefixcache": "#DD8452",
    "baseline": "#55A868",
    "cold": "#5A9ECC",
    "warm": "#E87041",
    "neutral": "#888888",
}
# Display names used for legends and tick labels, keyed by backend name.
LABELS = {
    "kvboost": "KVBoost",
    "vllm_prefixcache": "vLLM (prefix cache)",
    "baseline": "Baseline (HF)",
}

# Keyword arguments passed to Figure.savefig for every saved figure.
STYLE = dict(dpi=150, facecolor="white")
| 40 | + |
| 41 | + |
| 42 | +# ── helpers ──────────────────────────────────────────────────────────────────── |
| 43 | + |
def _load_experiment(path=None):
    """Load an experiment result file and return the parsed JSON.

    Args:
        path: Optional path to an experiment JSON file. When falsy, the
            lexicographically last ``experiment_*.json`` file found in
            ``RESULTS_DIR`` is loaded instead.

    Returns:
        The object produced by ``json.load`` for the selected file.

    Raises:
        FileNotFoundError: If no path is given and ``RESULTS_DIR`` contains
            no ``experiment_*.json`` file.
    """
    if not path:
        candidates = sorted(RESULTS_DIR.glob("experiment_*.json"))
        if not candidates:
            raise FileNotFoundError("No experiment JSON found in results/")
        path = candidates[-1]
    # Use a context manager so the handle is closed deterministically, and
    # decode as UTF-8 rather than the platform's locale encoding.
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)
| 51 | + |
| 52 | + |
def _save(fig, name):
    """Write *fig* into ``FIGURES_DIR`` under *name*, close it, and return the path.

    The figure is saved with the shared ``STYLE`` options and a tight bounding
    box. The saved location is printed relative to ``REPO_ROOT``.
    """
    target = FIGURES_DIR / name
    fig.savefig(target, bbox_inches="tight", **STYLE)
    # Close the figure after saving so it is not kept open by pyplot.
    plt.close(fig)
    shown = target.relative_to(REPO_ROOT)
    print(f" saved {shown}")
    return target
| 59 | + |
| 60 | + |
| 61 | +# ── Figure 1: COLD vs WARM TTFT bar chart ───────────────────────────────────── |
| 62 | + |
def fig_cold_warm_ttft(data):
    """Draw grouped bars of mean COLD vs WARM time-to-first-token per backend.

    Args:
        data: Parsed experiment mapping. For each backend in ``baseline``,
            ``vllm_prefixcache`` and ``kvboost`` this reads
            ``data["results"][backend]["latency_stats"]`` and, from it, the
            ``"mean"`` of ``ttft_ms_cold`` and ``ttft_ms_warm``.

    Returns:
        The path returned by ``_save`` for ``cold_warm_ttft.png``.
    """
    backends = ["baseline", "vllm_prefixcache", "kvboost"]
    cold_ms = [data["results"][b]["latency_stats"]["ttft_ms_cold"]["mean"] for b in backends]
    warm_ms = [data["results"][b]["latency_stats"]["ttft_ms_warm"]["mean"] for b in backends]

    x = np.arange(len(backends))
    w = 0.35

    fig, ax = plt.subplots(figsize=(8, 4.5))
    bars_cold = ax.bar(x - w/2, cold_ms, w, label="COLD (no cache)",
                       color=COLORS["cold"], edgecolor="white", linewidth=0.5)
    bars_warm = ax.bar(x + w/2, warm_ms, w, label="WARM (cached)",
                       color=COLORS["warm"], edgecolor="white", linewidth=0.5)

    # Value label just above every bar.
    for bar in list(bars_cold) + list(bars_warm):
        h = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2, h + 8, f"{h:.0f}",
                ha="center", va="bottom", fontsize=9)

    ax.set_xticks(x)
    ax.set_xticklabels([LABELS[b] for b in backends], fontsize=11)
    ax.set_ylabel("Time to First Token (ms)", fontsize=11)
    ax.set_title("COLD vs WARM Time to First Token\n(lower is better)", fontsize=12, fontweight="bold")
    ax.legend(fontsize=10)
    # Size the axis from the tallest bar of either series: a warm mean may
    # exceed every cold mean (e.g. a backend without caching), and would
    # otherwise be clipped.
    ax.set_ylim(0, max(cold_ms + warm_ms) * 1.20)
    ax.yaxis.set_minor_locator(MultipleLocator(50))
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    ax.spines[["top", "right"]].set_visible(False)

    # Annotate each non-baseline warm bar with its speedup over the baseline
    # warm mean (backends[0] is the baseline).
    baseline_warm = warm_ms[0]
    for i, wm in enumerate(warm_ms[1:], start=1):
        speedup = baseline_warm / wm
        ax.annotate(f"{speedup:.1f}×\nvs baseline",
                    xy=(x[i] + w/2, wm),
                    xytext=(x[i] + w/2 + 0.05, wm + 80),
                    fontsize=8, color="#333333",
                    arrowprops=dict(arrowstyle="->", color="#333333", lw=0.8))

    fig.tight_layout()
    return _save(fig, "cold_warm_ttft.png")
| 102 | + |
| 103 | + |
| 104 | +# ── Figure 2: TTFT distribution (CDF) ───────────────────────────────────────── |
| 105 | + |
def fig_ttft_cdf(data):
    """Plot the empirical CDF of TTFT per backend, one panel per query type.

    Args:
        data: Parsed experiment mapping. For each backend this reads
            ``data["results"][backend]["latency_samples"]``, a sequence of
            records with ``ttft_ms`` and ``query_type`` (``"COLD"`` or
            ``"WARM"``) entries.

    Returns:
        The path returned by ``_save`` for ``ttft_cdf.png``.
    """
    backends = ["baseline", "vllm_prefixcache", "kvboost"]
    fig, axes = plt.subplots(1, 2, figsize=(11, 4.5), sharey=False)

    for ax, qtype in zip(axes, ["COLD", "WARM"]):
        for b in backends:
            samples = data["results"][b]["latency_samples"]
            vals = sorted(s["ttft_ms"] for s in samples if s["query_type"] == qtype)
            n = len(vals)
            # The i-th smallest sample sits at percentile i/n (as a fraction).
            cdf = np.arange(1, n + 1) / n
            ax.plot(vals, cdf * 100, label=LABELS[b], color=COLORS[b], linewidth=1.8)

        ax.set_xlabel("TTFT (ms)", fontsize=10)
        ax.set_ylabel("Percentile", fontsize=10)
        ax.set_title(f"{qtype} queries — TTFT CDF", fontsize=11, fontweight="bold")
        ax.legend(fontsize=9)
        ax.set_ylim(0, 101)
        ax.grid(linestyle="--", alpha=0.4)
        ax.spines[["top", "right"]].set_visible(False)

    fig.suptitle("Cumulative Distribution of TTFT by Query Type", fontsize=12, fontweight="bold", y=1.02)
    fig.tight_layout()
    return _save(fig, "ttft_cdf.png")
| 129 | + |
| 130 | + |
| 131 | +# ── Figure 3: TTFT by context-length bucket ──────────────────────────────────── |
| 132 | + |
def fig_ttft_by_bucket(data):
    """Plot mean TTFT per context-length bucket for each backend.

    One panel is drawn for COLD queries and one for WARM queries. A bar is
    omitted when a backend has no sample in a given bucket.

    Args:
        data: Parsed experiment mapping. For each backend this reads
            ``data["results"][backend]["latency_samples"]``, a sequence of
            records with ``ttft_ms``, ``context_length`` and ``query_type``
            entries.

    Returns:
        The path returned by ``_save`` for ``ttft_by_bucket.png``.
    """
    # Half-open [lo, hi) token ranges. The last bucket has no upper bound so
    # that arbitrarily long contexts are still counted under "4K+".
    BUCKETS = [(0, 512), (512, 1024), (1024, 2048), (2048, 4096), (4096, float("inf"))]
    BUCKET_LABELS = ["0–512", "512–1K", "1K–2K", "2K–4K", "4K+"]
    backends = ["baseline", "vllm_prefixcache", "kvboost"]

    def mean_ttft(samples, lo, hi, qtype):
        """Return the mean TTFT of matching samples, or None when there are none."""
        vals = [s["ttft_ms"] for s in samples
                if lo <= s["context_length"] < hi and s["query_type"] == qtype]
        return np.mean(vals) if vals else None

    fig, axes = plt.subplots(1, 2, figsize=(13, 5), sharey=False)

    x = np.arange(len(BUCKETS))
    width = 0.25
    # One horizontal offset per backend, spread symmetrically around each tick.
    offsets = np.linspace(-width, width, len(backends))

    for ax, qtype in zip(axes, ["COLD", "WARM"]):
        for offset, b in zip(offsets, backends):
            samples = data["results"][b]["latency_samples"]
            means = [mean_ttft(samples, lo, hi, qtype) for lo, hi in BUCKETS]
            # Keep only the buckets that actually contain samples.
            filled = [(xi + offset, m) for xi, m in zip(x, means) if m is not None]
            bar_x = [pos for pos, _ in filled]
            bar_h = [m for _, m in filled]
            ax.bar(bar_x, bar_h, width * 0.85, label=LABELS[b],
                   color=COLORS[b], edgecolor="white", linewidth=0.5)

        ax.set_xticks(x)
        ax.set_xticklabels(BUCKET_LABELS, fontsize=10)
        ax.set_xlabel("Context length (tokens)", fontsize=10)
        ax.set_ylabel("Mean TTFT (ms)", fontsize=10)
        ax.set_title(f"{qtype} — TTFT by context length", fontsize=11, fontweight="bold")
        ax.legend(fontsize=9)
        ax.grid(axis="y", linestyle="--", alpha=0.4)
        ax.spines[["top", "right"]].set_visible(False)

    fig.suptitle("TTFT by Context-Length Bucket", fontsize=12, fontweight="bold", y=1.02)
    fig.tight_layout()
    return _save(fig, "ttft_by_bucket.png")
| 171 | + |
| 172 | + |
| 173 | +# ── Figure 4: KV reuse distribution ─────────────────────────────────────────── |
| 174 | + |
def fig_kv_reuse(data):
    """Plot the distribution of KV-cache reuse over KVBoost's warm queries.

    Args:
        data: Parsed experiment mapping. Reads
            ``data["results"]["kvboost"]["latency_samples"]``, a sequence of
            records with ``query_type`` and ``cache_reuse_ratio`` entries.
            The ratio is multiplied by 100 to get a percentage.

    Returns:
        The path returned by ``_save`` for ``kv_reuse_distribution.png``.

    Raises:
        ValueError: If there is no WARM sample. The shares below divide by
            the number of warm samples, so an empty set cannot be plotted.
    """
    samples = data["results"]["kvboost"]["latency_samples"]
    warm_reuse = [s["cache_reuse_ratio"] * 100
                  for s in samples if s["query_type"] == "WARM"]
    if not warm_reuse:
        raise ValueError("No WARM latency samples for kvboost; cannot plot KV reuse distribution")

    BUCKET_EDGES = [0, 20, 40, 60, 80, 100]
    # np.histogram treats the last bin as closed, so a reuse of exactly 100%
    # lands in the 80–100% bucket.
    counts, _ = np.histogram(warm_reuse, bins=BUCKET_EDGES)
    pcts = counts / len(warm_reuse) * 100
    labels = ["0–20%", "20–40%", "40–60%", "60–80%", "80–100%"]

    fig, ax = plt.subplots(figsize=(7, 4.5))
    bars = ax.bar(labels, pcts, color=COLORS["kvboost"], edgecolor="white", linewidth=0.5, zorder=3)

    for bar, pct in zip(bars, pcts):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                f"{pct:.0f}%", ha="center", va="bottom", fontsize=10, fontweight="bold")

    avg = np.mean(warm_reuse)
    # x=3.5 places the box between the two right-most categorical bars.
    ax.annotate(f"Mean reuse: {avg:.1f}%",
                xy=(3.5, max(pcts) * 0.85), fontsize=11, color="#333333",
                ha="center",
                bbox=dict(boxstyle="round,pad=0.3", facecolor="#EEF3FA", edgecolor="#4C72B0", linewidth=1))

    ax.set_xlabel("KV Cache Reuse (%)", fontsize=11)
    ax.set_ylabel("Share of warm queries (%)", fontsize=11)
    ax.set_title("KV Cache Reuse Distribution\n(KVBoost, warm queries)", fontsize=12, fontweight="bold")
    ax.set_ylim(0, max(pcts) * 1.25)
    ax.grid(axis="y", linestyle="--", alpha=0.4, zorder=0)
    ax.spines[["top", "right"]].set_visible(False)

    fig.tight_layout()
    return _save(fig, "kv_reuse_distribution.png")
| 207 | + |
| 208 | + |
| 209 | +# ── Figure 5: Speedup summary ────────────────────────────────────────────────── |
| 210 | + |
def fig_speedup_summary(data):
    """Plot TTFT speedup of KVBoost and vLLM over the baseline.

    Speedup is the baseline mean divided by the backend mean, computed for
    the overall, COLD and WARM TTFT statistics.

    Args:
        data: Parsed experiment mapping. For ``baseline``, ``kvboost`` and
            ``vllm_prefixcache`` this reads
            ``data["results"][backend]["latency_stats"]`` and, from it, the
            ``"mean"`` of ``ttft_ms_overall``, ``ttft_ms_cold`` and
            ``ttft_ms_warm``.

    Returns:
        The path returned by ``_save`` for ``speedup_summary.png``.
    """
    compared = ["kvboost", "vllm_prefixcache"]
    stat_keys = ["ttft_ms_overall", "ttft_ms_cold", "ttft_ms_warm"]
    categories = ["Overall", "COLD", "WARM"]

    base_stats = data["results"]["baseline"]["latency_stats"]
    base_means = [base_stats[k]["mean"] for k in stat_keys]

    # Per-backend speedups, ordered like `categories`.
    speedups = {}
    for b in compared:
        ls = data["results"][b]["latency_stats"]
        speedups[b] = [base / ls[k]["mean"] for base, k in zip(base_means, stat_keys)]

    x = np.arange(len(categories))
    w = 0.35
    offsets = [-w/2, w/2]

    fig, ax = plt.subplots(figsize=(8, 4.5))
    for b, off in zip(compared, offsets):
        vals = speedups[b]
        bars = ax.bar(x + off, vals, w, label=LABELS[b], color=COLORS[b],
                      edgecolor="white", linewidth=0.5)
        for bar, v in zip(bars, vals):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.05,
                    f"{v:.1f}×", ha="center", va="bottom", fontsize=10, fontweight="bold")

    ax.axhline(1.0, color="gray", linestyle="--", linewidth=1, label="Baseline (1×)")
    ax.set_xticks(x)
    ax.set_xticklabels(categories, fontsize=12)
    ax.set_ylabel("Speedup vs Baseline", fontsize=11)
    ax.set_title("TTFT Speedup vs HuggingFace Baseline\n(higher is better)", fontsize=12, fontweight="bold")
    ax.legend(fontsize=10)
    # Scale the axis to the tallest bar in any category (not just WARM), and
    # never below the 1× reference line, so nothing drawn is clipped.
    tallest = max(1.0, max(max(vals) for vals in speedups.values()))
    ax.set_ylim(0, tallest * 1.25)
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    ax.spines[["top", "right"]].set_visible(False)

    fig.tight_layout()
    return _save(fig, "speedup_summary.png")
| 252 | + |
| 253 | + |
# ── Figure 6: Accuracy by reuse bucket (bar chart) ────────────────────────────
| 255 | + |
def fig_accuracy_vs_reuse(data):
    """Plot KVBoost warm-query accuracy grouped by KV-cache reuse bucket.

    Each warm accuracy sample is assigned the reuse percentage recorded for
    its ``pair_group`` among the warm latency samples. Accuracy samples whose
    group has no warm latency sample are ignored. Buckets with no samples are
    not drawn.

    Args:
        data: Parsed experiment mapping. Reads, under
            ``data["results"]["kvboost"]``, ``accuracy_samples`` (records
            with ``query_type``, ``pair_group``, ``correct``) and
            ``latency_samples`` (records with ``query_type``, ``pair_group``,
            ``cache_reuse_ratio``). Also reads
            ``accuracy_stats["accuracy_warm"]`` of the baseline and
            ``accuracy_stats["accuracy_cold"]`` of KVBoost for the two
            reference lines; both are multiplied by 100.

    Returns:
        The path returned by ``_save`` for ``accuracy_vs_reuse.png``.
    """
    acc_samples = data["results"]["kvboost"]["accuracy_samples"]
    lat_samples = data["results"]["kvboost"]["latency_samples"]
    baseline_warm_acc = data["results"]["baseline"]["accuracy_stats"]["accuracy_warm"] * 100
    kvboost_cold_acc = data["results"]["kvboost"]["accuracy_stats"]["accuracy_cold"] * 100

    # pair_group -> reuse percentage; a later warm sample for the same group
    # overwrites an earlier one.
    reuse_map = {}
    for s in lat_samples:
        if s["query_type"] == "WARM":
            reuse_map[s["pair_group"]] = s["cache_reuse_ratio"] * 100

    # The top edge is nudged above 100 so that a reuse of exactly 100% falls
    # inside the half-open last bucket.
    BUCKET_EDGES = [0, 20, 40, 60, 80, 100.001]
    BUCKET_LABELS = ["0–20%", "20–40%", "40–60%", "60–80%", "80–100%"]

    correct_by_bucket = [[] for _ in BUCKET_LABELS]
    for s in acc_samples:
        if s["query_type"] == "WARM" and s["pair_group"] in reuse_map:
            r = reuse_map[s["pair_group"]]
            for i, (lo, hi) in enumerate(zip(BUCKET_EDGES, BUCKET_EDGES[1:])):
                if lo <= r < hi:
                    correct_by_bucket[i].append(s["correct"])
                    break

    accs = [100 * np.mean(c) if c else None for c in correct_by_bucket]
    counts = [len(c) for c in correct_by_bucket]

    fig, ax = plt.subplots(figsize=(8, 4.5))
    valid = [(l, a, c) for l, a, c in zip(BUCKET_LABELS, accs, counts) if a is not None]
    labels_v, accs_v, counts_v = zip(*valid) if valid else ([], [], [])

    bars = ax.bar(labels_v, accs_v, color=COLORS["kvboost"], edgecolor="white", linewidth=0.5, zorder=3)
    for bar, acc, cnt in zip(bars, accs_v, counts_v):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.05,
                f"{acc:.1f}%\n(n={cnt})", ha="center", va="bottom", fontsize=9)

    ax.axhline(baseline_warm_acc,
               color=COLORS["baseline"], linestyle="--", linewidth=1.5, label="Baseline warm acc.")
    ax.axhline(kvboost_cold_acc,
               color=COLORS["kvboost"], linestyle=":", linewidth=1.5, label="KVBoost cold acc.")

    ax.set_xlabel("KV Cache Reuse (%)", fontsize=11)
    ax.set_ylabel("Accuracy (%)", fontsize=11)
    ax.set_title("KVBoost Accuracy by Reuse Level\n(warm queries)", fontsize=12, fontweight="bold")
    ax.legend(fontsize=9)
    # Zoom into the 90–102% band by default, but lower the floor when any bar
    # or reference line would otherwise fall below the visible range.
    lowest = min([*accs_v, baseline_warm_acc, kvboost_cold_acc])
    y_floor = min(90.0, max(0.0, lowest - 2.0))
    ax.set_ylim(y_floor, 102)
    ax.grid(axis="y", linestyle="--", alpha=0.4, zorder=0)
    ax.spines[["top", "right"]].set_visible(False)

    fig.tight_layout()
    return _save(fig, "accuracy_vs_reuse.png")
| 307 | + |
| 308 | + |
| 309 | +# ── main ─────────────────────────────────────────────────────────────────────── |
| 310 | + |
def main():
    """Load one experiment and render every benchmark figure.

    The experiment path is taken from the first command-line argument when
    present; otherwise ``_load_experiment`` picks its own default.
    """
    cli_args = sys.argv[1:]
    data = _load_experiment(cli_args[0] if cli_args else None)

    model = data.get("model", "")
    n = data.get("n_samples", "?")
    print(f"Loaded experiment: model={model}, n_samples={n}")
    print(f"Saving figures to {FIGURES_DIR.relative_to(REPO_ROOT)}/")

    # Figures are produced in a fixed order, one output file each.
    figure_builders = (
        fig_cold_warm_ttft,
        fig_ttft_cdf,
        fig_ttft_by_bucket,
        fig_kv_reuse,
        fig_speedup_summary,
        fig_accuracy_vs_reuse,
    )
    for build in figure_builders:
        build(data)

    print("Done.")


if __name__ == "__main__":
    main()
0 commit comments