Skip to content

Commit dbe4aaf

Browse files
committed
Remove cuda graph decode
1 parent 5600a8a commit dbe4aaf

8 files changed

Lines changed: 529 additions & 26 deletions

File tree

benchmarks_and_experiments/coding_vs_vllm/start_kvboost.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ echo
6565
exec python -m kvboost.server \
6666
--model "$MODEL" \
6767
--dtype float16 \
68-
--cuda-graph-decode \
6968
--attn-impl auto \
7069
--recompute-strategy cacheblend_sparse \
7170
--chunk-boundary-window 32\

install_deps.sh

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
# * GPU, no nvcc -> CUDA torch + flash-attn (prebuilt wheel) + FlashInfer
77
# * GPU + nvcc (CUDA 12.x / 13.x) -> full path incl. bundled kernel + flash-attn
88
#
9-
# flash-attn is REQUIRED on any GPU box (it's the prefill backend you want): a
10-
# matching prebuilt wheel is installed when available (no nvcc needed), with a
11-
# source build as fallback. The install fails loudly if it can't be installed,
12-
# so you never silently end up on SDPA. The bundled kernel and FlashInfer stay
13-
# best-effort (the repo falls back to SDPA for those). Use --skip-flash-attn to
14-
# opt out. Every build is time-boxed and logged to install_deps.log.
9+
# The primary accelerated prefill path on a GPU box is now Triton: the 'sage'
10+
# (INT8 SageAttention) and 'triton_flash' (FP16 flash) backends JIT-compile
11+
# through the CUDA driver — no nvcc, no prebuilt-wheel matching, no multi-arch
12+
# source build. Triton ships with the CUDA torch wheel on Linux; we just verify
13+
# it imports. flash-attn is now OPTIONAL and best-effort (NEVER fatal): a
14+
# matching prebuilt wheel is installed when available, with a source build as
15+
# fallback, but if neither works the runtime simply uses SDPA / the Triton
16+
# kernels. The bundled CUDA kernel and FlashInfer are also best-effort. Use
17+
# --skip-flash-attn to opt out. Every build is time-boxed and logged to
18+
# install_deps.log.
1519
#
1620
# Usage
1721
# -----
@@ -288,19 +292,34 @@ if (( CAN_BUILD_EXT == 1 )); then
288292
fi
289293

290294
if [[ "${MODE}" == "cuda" ]]; then
291-
# FlashAttention-2 — REQUIRED (the prefill backend you want). Prebuilt wheel
292-
# first (no nvcc needed), source build as fallback. Fatal if it can't go in.
295+
# Triton — the PRIMARY accelerated kernel path (SageAttention INT8 prefill +
296+
# FP16 'triton_flash'). JIT-compiles via the CUDA driver: no nvcc, no wheel
297+
# matching, no multi-arch source build. Ships with the CUDA torch wheel on
298+
# Linux; verify it imports and install best-effort if somehow missing.
299+
if python -c 'import triton' 2>/dev/null; then
300+
log "Triton present ($(python -c 'import triton; print(triton.__version__)')) — 'sage' / 'triton_flash' backends enabled."
301+
else
302+
warn "Triton not importable (unexpected with a CUDA torch wheel); installing best-effort."
303+
python -m pip install -q triton \
304+
|| warn "triton install failed; 'sage'/'triton_flash' will fall back to SDPA."
305+
fi
306+
307+
# FlashAttention-2 — OPTIONAL / best-effort now (Triton 'sage'/'triton_flash'
308+
# is the recommended prefill path on Ampere). Prebuilt wheel first (no nvcc),
309+
# source build as fallback. NEVER fatal: on failure the runtime uses the
310+
# Triton kernels or SDPA.
293311
if (( SKIP_FLASH_ATTN == 1 )); then
294-
warn "Skipping flash-attn at your request (--skip-flash-attn); prefill uses torch SDPA."
312+
warn "Skipping flash-attn at your request (--skip-flash-attn); use --attn-impl sage / triton_flash, or SDPA."
295313
elif install_flash_attn; then
296314
log "FlashAttention-2 ready ($(python -c 'import flash_attn; print(flash_attn.__version__)'))"
297315
else
298-
fail "flash-attn could not be installed (you asked for it explicitly).
299-
See ${BUILD_LOG} for the exact build error. Most common cause: the torch
300-
pulled from ${TORCH_CUDA_TAG:-the CUDA index} is newer than any published
301-
flash-attn wheel. Fixes:
302-
* pin torch to a release that has wheels: TORCH_SPEC=torch==2.7.1 ./install_deps.sh
303-
* or pin a flash-attn version: FLASH_ATTN_SPEC=flash-attn==2.7.4.post1 ./install_deps.sh"
316+
warn "flash-attn could not be installed (optional). See ${BUILD_LOG}.
317+
This is fine — the recommended path no longer needs it: run the server with
318+
--attn-impl sage (INT8 SageAttention prefill via Triton) or --attn-impl
319+
triton_flash (FP16). To install flash-attn anyway, the usual fixes are:
320+
* pin torch to a release with wheels: TORCH_SPEC=torch==2.7.1 ./install_deps.sh
321+
* or pin a flash-attn version: FLASH_ATTN_SPEC=flash-attn==2.7.4.post1 ./install_deps.sh
322+
* or limit the source build to Ampere: FLASH_ATTN_CUDA_ARCHS=80 ./install_deps.sh"
304323
fi
305324

306325
# FlashInfer (decode attention) — JIT, best-effort, works without nvcc.
@@ -355,8 +374,16 @@ def have(mod):
355374
356375
fa2 = have("flash_attn")
357376
fi = have("flashinfer")
377+
tri = have("triton")
358378
kern = have("kvboost._flash_attn_cuda")
379+
try:
380+
from kvboost.kernels import sage_available
381+
sage = sage_available()
382+
except Exception:
383+
sage = False
359384
print(f" info prefill backend : {'flash_attention_2' if fa2 else 'torch SDPA (flash-attn not installed)'}")
385+
print(f" info sage/triton flash: {'available (--attn-impl sage | triton_flash)' if sage else 'unavailable (triton missing → SDPA)'}")
386+
print(f" info triton : {'present' if tri else 'absent'}")
360387
print(f" info decode backend : {'flashinfer' if fi else 'torch SDPA (flashinfer not installed)'}")
361388
print(f" info bundled kernel : {'kvboost._flash_attn_cuda' if kern else 'not built (SDPA patch path)'}")
362389

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ dev = [
4040
]
4141
cuda = [
4242
"ninja>=1.11",
43+
# Triton backs the 'sage' (INT8 SageAttention) and 'triton_flash' kernels.
44+
# JIT-compiled via the CUDA driver — no nvcc, no flash-attn-style wheel
45+
# build. Ships with the CUDA torch wheel on Linux; pinned here so it's
46+
# explicit. (Linux-only: Triton has no macOS/Windows wheels.)
47+
"triton>=2.1 ; platform_system=='Linux'",
4348
]
4449
streaming = [
4550
"safetensors>=0.4",

src/kvboost/engine.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,14 @@ def from_pretrained(
371371
and silently falls back to ``"sdpa"`` if FA2 isn't
372372
installed/supported. Pass ``"flash_attention_2"`` to
373373
require it (raises if unavailable), or ``"sdpa"`` /
374-
``"eager"`` to force a backend. Ignored on the streaming
374+
``"eager"`` to force a backend. ``"sage"`` runs INT8
375+
SageAttention on prefill via a Triton kernel (no nvcc /
376+
no flash-attn build needed; INT8 tensor-core QKᵀ on
377+
Ampere, SDPA fallback for decode); ``"triton_flash"`` is
378+
the FP16 Triton flash baseline; ``"flashinfer"`` routes
379+
decode attention through FlashInfer. Each JIT/optional
380+
backend self-checks against SDPA on first use and
381+
disables itself on mismatch. Ignored on the streaming
375382
path. To load a **quantized** checkpoint (AWQ/GPTQ →
376383
Marlin int4 GEMM on Ampere, ~4× less weight bandwidth →
377384
higher decode tok/s), just pass a quantized ``model_name``;

src/kvboost/kernels/__init__.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,60 @@
1-
"""Proven external inference kernels kvboost routes to at runtime.
1+
"""Proven external/JIT inference kernels kvboost routes to at runtime.
22
3-
Currently: FlashInfer decode-attention (see ``flashinfer_attn``). Each kernel
4-
is gated on availability + a numerical self-check, and falls back to PyTorch
5-
SDPA so a missing or misbehaving kernel never corrupts output.
3+
* FlashInfer decode-attention (``flashinfer_attn``) — ``--attn-impl flashinfer``.
4+
* SageAttention INT8 prefill + FP16 Triton flash (``sage_attn``) —
5+
``--attn-impl sage`` / ``--attn-impl triton_flash``.
6+
7+
Each kernel is gated on availability + a one-time numerical self-check, and
8+
falls back to PyTorch SDPA so a missing or misbehaving kernel never corrupts
9+
output. ``resolve_attn_impl`` registers the requested backend with HuggingFace
10+
(if its dependency is present) before model load, else downgrades to ``sdpa``.
611
"""
7-
from .flashinfer_attn import (
8-
flashinfer_available,
9-
install_flashinfer_attention,
10-
resolve_attn_impl,
12+
import logging
13+
14+
from .flashinfer_attn import flashinfer_available, install_flashinfer_attention
15+
from .sage_attn import (
16+
install_sage_attention,
17+
sage_attention_forward,
18+
sage_available,
19+
triton_available,
20+
triton_flash_attention_forward,
1121
)
1222

23+
_log = logging.getLogger("kvboost.kernels")
24+
25+
26+
def resolve_attn_impl(requested: str) -> str:
27+
"""Map a requested attn-impl to one HF can actually load.
28+
29+
Registers the backend with HuggingFace if its dependency is importable,
30+
otherwise falls back to ``"sdpa"`` with a warning. ``"auto"`` and stock
31+
impls (``"sdpa"``, ``"eager"``, ``"flash_attention_2"``) pass through.
32+
"""
33+
if requested == "flashinfer":
34+
if install_flashinfer_attention():
35+
return "flashinfer"
36+
_log.warning("attn-impl 'flashinfer' requested but unavailable; using sdpa.")
37+
return "sdpa"
38+
39+
if requested in ("sage", "triton_flash"):
40+
if install_sage_attention():
41+
return requested
42+
_log.warning(
43+
"attn-impl '%s' requested but Triton is unavailable; using sdpa.",
44+
requested,
45+
)
46+
return "sdpa"
47+
48+
return requested
49+
50+
1351
__all__ = [
1452
"flashinfer_available",
1553
"install_flashinfer_attention",
54+
"install_sage_attention",
55+
"sage_attention_forward",
56+
"sage_available",
57+
"triton_available",
58+
"triton_flash_attention_forward",
1659
"resolve_attn_impl",
1760
]

0 commit comments

Comments
 (0)