Skip to content

Commit 97b765b

Browse files
committed
2x faster decode with CUDA graph capture
1 parent b8f1f8f commit 97b765b

4 files changed

Lines changed: 173 additions & 51 deletions

File tree

benchmarks_and_experiments/coding_vs_vllm/README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,10 @@ python bench_coding.py --backend kvboost --url http://localhost:9000 \
5757
```
5858
Stop the kvboost server when it finishes (frees the GPU).
5959

60-
**Step 2 — vLLM.** Start its server:
60+
**Step 2 — vLLM.** Stop kvboost first (frees the GPU), then start vLLM (usual
61+
setup — see `start_vllm.sh`):
6162
```bash
62-
vllm serve Qwen/Qwen2.5-3B-Instruct --dtype float16 \
63-
--enable-prefix-caching --gpu-memory-utilization 0.85 \
64-
--max-model-len 32768 --port 8001
63+
./start_vllm.sh # MODEL=... PORT=... GPU_MEM_UTIL=... MAX_MODEL_LEN=...
6564
```
6665
Then run the **same** workload flags (so prompts match) against it:
6766
```bash
Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,95 @@
11
#!/usr/bin/env bash
22
# Launch kvboost in its BEST setup for the coding benchmark — showcases the
3-
# features the benchmark measures: KV reuse (faster TTFT) + OOM recovery, with
4-
# the recent correctness/perf fixes all active.
3+
# features the benchmark measures (KV reuse → faster TTFT, OOM recovery) AND
4+
# the throughput levers (FlashAttention-2, tree speculative decoding) added to
5+
# close the gap to vLLM on an RTX 3060 (Ampere, 12 GB, ~360 GB/s).
56
#
67
# Run this, then in another shell:
78
# python bench_coding.py --backend kvboost --url http://localhost:9000 \
89
# --model "$MODEL" --mode both --out kvboost.json
910
# Stop it (Ctrl-C) before launching vLLM — one model fits the GPU at a time.
1011
#
11-
# Override via env: MODEL=... PORT=... MAX_CACHE_BYTES=... ./start_kvboost.sh
12+
# Override via env: MODEL=... DRAFT=... PORT=... MAX_CACHE_BYTES=... NO_SPEC=1
1213

1314
set -euo pipefail
1415

1516
MODEL="${MODEL:-Qwen/Qwen2.5-3B-Instruct}"
17+
# Small same-family drafter for speculative decoding (the decode-throughput
18+
# lever). ~1 GB fp16; set NO_SPEC=1 to disable (e.g. to free VRAM for the
19+
# OOM-headroom run, since the draft model lowers the context ceiling).
20+
DRAFT="${DRAFT:-Qwen/Qwen2.5-0.5B-Instruct}"
1621
PORT="${PORT:-9000}"
17-
# KV-cache budget for cross-request chunk reuse. Size to (free VRAM after
18-
# weights). On a 14.6 GiB card with a 3B fp16 model (~6 GiB) → ~4 GiB leaves
19-
# headroom for prefill activations + the live request. Lower for the OOM-
20-
# stress run to make the planner's adaptation more visible (e.g. 1.5e9).
21-
MAX_CACHE_BYTES="${MAX_CACHE_BYTES:-4e9}"
22+
# KV-cache budget for cross-request chunk reuse. On a 12 GB 3060: ~6 GB model
23+
# + ~1 GB draft leaves ~5 GB → 2.5 GB cache keeps activation headroom. Lower
24+
# for the OOM-stress run to make the planner's adaptation more visible (~1e9).
25+
MAX_CACHE_BYTES="${MAX_CACHE_BYTES:-2.5e9}"
2226
SAFETY_MARGIN="${SAFETY_MARGIN:-0.15}"
2327

24-
echo "kvboost (best setup)"
28+
SPEC_ARGS=(--speculative-draft-model "$DRAFT" --speculative-tree)
29+
if [[ "${NO_SPEC:-0}" == "1" ]]; then SPEC_ARGS=(); fi
30+
31+
echo "kvboost (best setup — RTX 3060)"
2532
echo " model: $MODEL"
33+
echo " draft: ${NO_SPEC:+<disabled>}${NO_SPEC:-$DRAFT}"
2634
echo " port: $PORT"
35+
echo " attention: flash_attention_2 (auto-fallback to sdpa)"
2736
echo " recompute: cacheblend_sparse (faithful selective recompute)"
2837
echo " kv-cache-bits: 8 (int8 KV → 2× reuse capacity)"
2938
echo " max-cache-bytes: $MAX_CACHE_BYTES"
39+
echo " speculative: ${NO_SPEC:+off}${NO_SPEC:-tree (auto mode-select)}"
3040
echo " oom planning: on (safety_margin=$SAFETY_MARGIN)"
3141
echo
3242

3343
# Why each flag:
44+
# --attn-impl auto
45+
# Tries FlashAttention-2 (Ampere wheel; faster, lower-memory prefill →
46+
# better TTFT and input throughput), silently falls back to sdpa if the
47+
# FA2 wheel isn't installed. Use --attn-impl flash_attention_2 to REQUIRE
48+
# it (errors loudly if missing) once you've confirmed the wheel.
49+
# --speculative-tree --speculative-draft-model
50+
# SpecBlock-inspired tree speculative decoding — the decode-throughput
51+
# lever. On bandwidth-bound hardware (3060), accepting several tokens per
52+
# target forward amortizes the per-token weight read → multiplies decode
53+
# tok/s. Auto mode-selector picks none/flat/tree per request.
3454
# --recompute-strategy cacheblend_sparse
35-
# Faithful CacheBlend: recompute only high-deviation tokens layer-by-
36-
# layer (paper's 2.2-3.3× TTFT), not the full-forward variant. This is
37-
# the "faster TTFT on reused context" feature. Falls back to plain
38-
# cacheblend automatically on unsupported architectures.
55+
# Faithful CacheBlend: recompute only high-deviation tokens. The "faster
56+
# TTFT on reused context" feature. NOTE: on a pure shared-PREFIX workload
57+
# (this coding benchmark), --recompute-strategy none reuses prefix KV at
58+
# ~zero cost like vLLM prefix caching; cacheblend_sparse's edge is the
59+
# OUT-OF-ORDER RAG workload (bench_hf.py). Try both.
3960
# --kv-cache-bits 8
40-
# int8 KV cache: ~2× the cached-chunk capacity (more cross-request
41-
# reuse) and lower memory pressure, negligible quality cost.
42-
# --max-cache-bytes
43-
# Cross-request chunk-cache budget — bigger = more reuse, bounded by VRAM.
61+
# int8 KV STORAGE → ~2× cached-chunk capacity + less memory pressure.
62+
# (Note: it dequants to fp16 for compute, so it adds reuse capacity, not
63+
# decode bandwidth — that lever is weight quant, see below.)
4464
# OOM planner (on by default) + --planner-safety-margin
45-
# Per-request peak prediction → picks chunk_size/kv_bits that fit, or a
46-
# clean HTTP 413. This is the "OOM recovery" feature. Add --auto-truncate
47-
# to truncate-and-complete oversized prompts instead of 413.
48-
# --max-batch-size 1
49-
# The benchmark replays sequentially (single GPU worker); 1 avoids
50-
# pointless batch-window latency. Raise for concurrent throughput tests.
51-
# (automatic, no flag: O(n) incremental detok, chunked CacheBlend forward,
52-
# streaming usage emission for input-throughput, planner cost probe.)
65+
# Per-request peak prediction → fitting chunk_size/kv_bits or clean 413.
66+
# (automatic: O(n) detok, chunked CacheBlend forward, streaming usage,
67+
# static decode input buffers.)
5368
exec python -m kvboost.server \
5469
--model "$MODEL" \
5570
--dtype float16 \
71+
--attn-impl auto \
5672
--recompute-strategy cacheblend_sparse \
5773
--kv-cache-bits 8 \
5874
--max-cache-bytes "$MAX_CACHE_BYTES" \
5975
--planner-safety-margin "$SAFETY_MARGIN" \
6076
--max-batch-size 1 \
77+
"${SPEC_ARGS[@]}" \
6178
--host 0.0.0.0 \
6279
--port "$PORT"
6380

64-
# ── Optional add-ons (uncomment to enable) ───────────────────────────────────
65-
# Speculative decoding to lift DECODE throughput (where vLLM's continuous
66-
# batching otherwise leads). Needs a small same-family draft model and ~1 GiB
67-
# extra VRAM; --speculative-tree turns on the SpecBlock-inspired tree variant
68-
# with cost-aware per-request mode selection:
69-
# --speculative-draft-model Qwen/Qwen2.5-0.5B-Instruct \
70-
# --speculative-tree \
81+
# ── Optional add-ons (uncomment / set env to enable) ─────────────────────────
82+
# WEIGHT QUANTIZATION (the biggest raw decode lever on a 3060): point --model
83+
# at an AWQ/GPTQ Int4 checkpoint — transformers loads it with Marlin int4 GEMM
84+
# on Ampere automatically (~4× less weight bandwidth → up to ~4× decode ceiling,
85+
# 60→~240 tok/s for 3B). No extra flag; the engine detects quantized weights:
86+
# MODEL=Qwen/Qwen2.5-3B-Instruct-AWQ ./start_kvboost.sh
87+
#
88+
# torch.compile (--compile): CUDA graphs + fusion erase per-token launch
89+
# overhead → faster DECODE. CAVEAT: it recompiles per new PREFILL length, so it
90+
# can HURT TTFT on this varying-prompt benchmark and adds a one-time first-
91+
# request compile cost. Best for decode-bound / fixed-shape serving, not the
92+
# TTFT ramp. Add: --compile
7193
#
72-
# Oversized-prompt policy for the OOM ramp: complete-by-truncation instead of
73-
# a clean 413 reject:
94+
# Oversized-prompt policy for the OOM ramp — complete-by-truncation vs 413:
7495
# --auto-truncate

src/kvboost/engine.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,13 @@ def __init__(
112112
# Cost coefficients (probed at server startup) for cost-aware
113113
# tree shape + mode selection. None = degraded mode (defaults).
114114
cost_coefficients: Any = None,
115+
# torch.compile(mode="reduce-overhead") — captures CUDA graphs +
116+
# fuses pointwise ops (RMSNorm/RoPE/SwiGLU/residual) → removes the
117+
# per-token kernel-launch overhead that caps eager decode. Opt-in
118+
# and EXPERIMENTAL: compilation is lazy (first forward), so a bad
119+
# interaction surfaces at runtime, not here — drop the flag if a
120+
# run errors. Off by default so it can never regress the eager path.
121+
compile_model: bool = False,
115122
):
116123
if device is None:
117124
device = default_device()
@@ -282,6 +289,21 @@ def __init__(
282289
from .flash_attn_ext import install_flash_attention
283290
self._flash_attn_patched = install_flash_attention(self.model)
284291

292+
# torch.compile LAST, after any model patching. reduce-overhead mode
293+
# uses CUDA graphs + Triton fusion to erase per-token launch overhead
294+
# (the gap between eager decode and the bandwidth ceiling). Lazy: the
295+
# actual compile happens on the first forward, so we can't catch a
296+
# failure here — wrap-time errors are caught; runtime graph-breaks just
297+
# degrade to partial speedup. Drop --compile if a run errors outright.
298+
self._compiled = False
299+
if compile_model:
300+
try:
301+
self.model = torch.compile(self.model, mode="reduce-overhead")
302+
self._compiled = True
303+
log.info("torch.compile(reduce-overhead) enabled (experimental)")
304+
except Exception as e:
305+
log.warning("torch.compile failed (%s); running eager", e)
306+
285307
# ------------------------------------------------------------------
286308
# Factory
287309
# ------------------------------------------------------------------
@@ -293,6 +315,7 @@ def from_pretrained(
293315
strict: bool = True,
294316
streaming_config: Optional["StreamingConfig"] = None,
295317
awq_path: Optional[str] = None,
318+
attn_implementation: str = "auto",
296319
**kwargs,
297320
) -> "InferenceEngine":
298321
"""
@@ -309,7 +332,21 @@ def from_pretrained(
309332
the config. The rest of KVBoost (chunk-reuse, FlashAttn)
310333
is untouched.
311334
awq_path: Optional path hint forwarded to the streaming loader.
312-
**kwargs: Passed to InferenceEngine.__init__().
335+
attn_implementation: Attention backend for the resident path.
336+
``"auto"`` (default) tries ``flash_attention_2`` (FA2 —
337+
Ampere+ wheel; faster, lower-memory prefill → better TTFT)
338+
and silently falls back to ``"sdpa"`` if FA2 isn't
339+
installed/supported. Pass ``"flash_attention_2"`` to
340+
require it (raises if unavailable), or ``"sdpa"`` /
341+
``"eager"`` to force a backend. Ignored on the streaming
342+
path. To load a **quantized** checkpoint (AWQ/GPTQ →
343+
Marlin int4 GEMM on Ampere, ~4× less weight bandwidth →
344+
higher decode tok/s), just pass a quantized ``model_name``;
345+
transformers reads its quantization_config and picks the
346+
kernel automatically — the engine already detects and
347+
leaves quantized/offloaded weights in place.
348+
**kwargs: Passed to InferenceEngine.__init__() (e.g.
349+
``compile_model=True`` for torch.compile reduce-overhead).
313350
"""
314351
log.info("Loading model %s ...", model_name)
315352
tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -326,11 +363,31 @@ def from_pretrained(
326363
dtype=torch.float16,
327364
)
328365
else:
329-
model = AutoModelForCausalLM.from_pretrained(
330-
model_name,
331-
torch_dtype=torch.float16,
332-
low_cpu_mem_usage=True,
333-
)
366+
load_kwargs = dict(torch_dtype=torch.float16, low_cpu_mem_usage=True)
367+
impl = attn_implementation
368+
if impl in ("auto", "flash_attention_2"):
369+
try:
370+
model = AutoModelForCausalLM.from_pretrained(
371+
model_name,
372+
attn_implementation="flash_attention_2",
373+
**load_kwargs,
374+
)
375+
log.info("Attention backend: flash_attention_2")
376+
except Exception as e:
377+
if impl == "flash_attention_2":
378+
raise # caller explicitly required FA2 — don't mask it
379+
log.info(
380+
"flash_attention_2 unavailable (%s); using sdpa", e
381+
)
382+
model = AutoModelForCausalLM.from_pretrained(
383+
model_name, attn_implementation="sdpa", **load_kwargs
384+
)
385+
log.info("Attention backend: sdpa")
386+
else:
387+
model = AutoModelForCausalLM.from_pretrained(
388+
model_name, attn_implementation=impl, **load_kwargs
389+
)
390+
log.info("Attention backend: %s", impl)
334391
model.eval()
335392

336393
check_model_compatibility(model, strict=strict)
@@ -1110,16 +1167,25 @@ def _decode_with_kv(
11101167

11111168
# ----- autoregressive decode ------------------------------------
11121169
cur_pos = cached_len + len(live_ids)
1170+
# Pre-allocate the (1,1) decode input buffers ONCE and write in place
1171+
# each step, instead of allocating two fresh device tensors per token.
1172+
# Removes a per-token alloc + H2D churn, and — critically — gives
1173+
# torch.compile / CUDA-graph capture stable input tensors to graph
1174+
# against (a graph needs fixed input storage; a new tensor per step
1175+
# forces a recapture/recompile). Correctness is identical: the model
1176+
# reads these tensors, it never retains them across steps.
1177+
input_buf = torch.empty((1, 1), dtype=torch.long, device=self.device)
1178+
pos_buf = torch.empty((1, 1), dtype=torch.long, device=self.device)
11131179
while not goto_done and len(generated) < max_new_tokens:
11141180
if generated[-1] == self.tokenizer.eos_token_id:
11151181
break
1116-
cur_ids = torch.tensor([[generated[-1]]], dtype=torch.long, device=self.device)
1117-
pos_ids = torch.tensor([[cur_pos]], dtype=torch.long, device=self.device)
1182+
input_buf[0, 0] = generated[-1]
1183+
pos_buf[0, 0] = cur_pos
11181184
with torch.no_grad(), last_logit_only(self.model):
11191185
out = self.model(
1120-
input_ids=cur_ids,
1186+
input_ids=input_buf,
11211187
past_key_values=self._as_cache(past_kv),
1122-
position_ids=pos_ids,
1188+
position_ids=pos_buf,
11231189
use_cache=True,
11241190
)
11251191
past_kv = self._normalize_past_kv(out.past_key_values)

src/kvboost/server/__main__.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,20 @@ def parse_args():
7272
p.add_argument("--device", default=None, help="Device: cuda | mps | cpu (auto-detected if omitted)")
7373
p.add_argument("--dtype", default="float16", choices=["float16", "bfloat16", "float32"],
7474
help="Model weight dtype (default: float16)")
75+
p.add_argument("--attn-impl", default="auto",
76+
choices=["auto", "flash_attention_2", "sdpa", "eager"],
77+
help="Attention backend. 'auto' (default) uses "
78+
"flash_attention_2 if installed (faster/lower-memory "
79+
"prefill -> better TTFT; Ampere+ e.g. RTX 3060) and "
80+
"falls back to sdpa otherwise. 'flash_attention_2' "
81+
"requires it (errors if missing). 'sdpa'/'eager' force.")
82+
p.add_argument("--compile", action="store_true", default=False,
83+
help="torch.compile(mode='reduce-overhead') on the model: "
84+
"CUDA graphs + pointwise fusion to erase per-token "
85+
"launch overhead (closes most of the eager-decode gap "
86+
"to the bandwidth ceiling). EXPERIMENTAL — compiles "
87+
"lazily on first request; drop the flag if a run "
88+
"errors. First request pays a one-time compile cost.")
7589
p.add_argument("--backend", default="default", choices=["default", "cpu-paged"],
7690
help="Inference backend (default: standard KVBoost)")
7791
p.add_argument("--quantization", default="none",
@@ -410,6 +424,10 @@ def load_engine(args):
410424
device=device,
411425
speculative_config=speculative_cfg,
412426
tree_speculative_config=tree_speculative_cfg,
427+
# attn_impl is ignored on the streaming path (StreamingCausalLM
428+
# owns attention); compile flows through to __init__.
429+
attn_implementation=args.attn_impl,
430+
compile_model=args.compile,
413431
)
414432
log.info("Model loaded.")
415433
return engine
@@ -440,10 +458,27 @@ def load_engine(args):
440458
from_pretrained_kwargs["quantization_config"] = quant_config
441459
else:
442460
from_pretrained_kwargs["dtype"] = dtype
443-
model = AutoModelForCausalLM.from_pretrained(
444-
args.model,
445-
**from_pretrained_kwargs,
461+
# Attention backend. 'auto' tries FA2 (better TTFT on Ampere+, e.g.
462+
# RTX 3060) then falls back to sdpa; an explicit choice is honored.
463+
_want_fa2 = args.attn_impl in ("auto", "flash_attention_2")
464+
from_pretrained_kwargs["attn_implementation"] = (
465+
"flash_attention_2" if _want_fa2 else args.attn_impl
446466
)
467+
try:
468+
model = AutoModelForCausalLM.from_pretrained(
469+
args.model, **from_pretrained_kwargs,
470+
)
471+
log.info("Attention backend: %s",
472+
from_pretrained_kwargs["attn_implementation"])
473+
except Exception as e:
474+
if args.attn_impl != "auto":
475+
raise # explicit backend requested — don't mask the failure
476+
log.info("flash_attention_2 unavailable (%s); using sdpa", e)
477+
from_pretrained_kwargs["attn_implementation"] = "sdpa"
478+
model = AutoModelForCausalLM.from_pretrained(
479+
args.model, **from_pretrained_kwargs,
480+
)
481+
log.info("Attention backend: sdpa")
447482
engine = InferenceEngine(
448483
model=model,
449484
tokenizer=tokenizer,
@@ -457,6 +492,7 @@ def load_engine(args):
457492
device=device,
458493
speculative_config=_build_speculative_config(args),
459494
tree_speculative_config=_build_tree_speculative_config(args),
495+
compile_model=args.compile,
460496
)
461497

462498
log.info("Model loaded.")

0 commit comments

Comments
 (0)