|
6 | 6 | # * GPU, no nvcc -> CUDA torch + flash-attn (prebuilt wheel) + FlashInfer |
7 | 7 | # * GPU + nvcc (CUDA 12.x / 13.x) -> full path incl. bundled kernel + flash-attn |
8 | 8 | # |
9 | | -# flash-attn is REQUIRED on any GPU box (it's the prefill backend you want): a |
10 | | -# matching prebuilt wheel is installed when available (no nvcc needed), with a |
11 | | -# source build as fallback. The install fails loudly if it can't be installed, |
12 | | -# so you never silently end up on SDPA. The bundled kernel and FlashInfer stay |
13 | | -# best-effort (the repo falls back to SDPA for those). Use --skip-flash-attn to |
14 | | -# opt out. Every build is time-boxed and logged to install_deps.log. |
| 9 | +# The primary accelerated prefill path on a GPU box is now Triton: the 'sage' |
| 10 | +# (INT8 SageAttention) and 'triton_flash' (FP16 flash) backends JIT-compile |
| 11 | +# through the CUDA driver — no nvcc, no prebuilt-wheel matching, no multi-arch |
| 12 | +# source build. Triton ships with the CUDA torch wheel on Linux; we verify it |
| 13 | +# imports and pip-install it best-effort if missing. flash-attn is now OPTIONAL |
| 14 | +# and best-effort (NEVER fatal): a matching prebuilt wheel is installed when |
| 15 | +# available, with a source build as fallback, but if neither works the runtime |
| 16 | +# simply uses SDPA / the Triton kernels. The bundled CUDA kernel and FlashInfer |
| 17 | +# are also best-effort. Use --skip-flash-attn to opt out. Every build is |
| 18 | +# time-boxed and logged to install_deps.log. |
15 | 19 | # |
16 | 20 | # Usage |
17 | 21 | # ----- |
@@ -288,19 +292,34 @@ if (( CAN_BUILD_EXT == 1 )); then |
288 | 292 | fi |
289 | 293 |
|
290 | 294 | if [[ "${MODE}" == "cuda" ]]; then |
291 | | - # FlashAttention-2 — REQUIRED (the prefill backend you want). Prebuilt wheel |
292 | | - # first (no nvcc needed), source build as fallback. Fatal if it can't go in. |
| 295 | + # Triton — the PRIMARY accelerated kernel path (SageAttention INT8 prefill + |
| 296 | + # FP16 'triton_flash'). JIT-compiles via the CUDA driver: no nvcc, no wheel |
| 297 | + # matching, no multi-arch source build. Ships with the CUDA torch wheel on |
| 298 | + # Linux; verify it imports and install best-effort if somehow missing. |
| 299 | + if python -c 'import triton' 2>/dev/null; then |
| 300 | + log "Triton present ($(python -c 'import triton; print(triton.__version__)')) — 'sage' / 'triton_flash' backends enabled." |
| 301 | + else |
| 302 | + warn "Triton not importable (unexpected with a CUDA torch wheel); installing best-effort." |
| 303 | + python -m pip install -q triton \ |
| 304 | + || warn "triton install failed; 'sage'/'triton_flash' will fall back to SDPA." |
| 305 | + fi |
| 306 | + |
| 307 | + # FlashAttention-2 — OPTIONAL / best-effort now (Triton 'sage'/'triton_flash' |
| 308 | + # is the recommended prefill path on Ampere). Prebuilt wheel first (no nvcc), |
| 309 | + # source build as fallback. NEVER fatal: on failure the runtime uses the |
| 310 | + # Triton kernels or SDPA. |
293 | 311 | if (( SKIP_FLASH_ATTN == 1 )); then |
294 | | - warn "Skipping flash-attn at your request (--skip-flash-attn); prefill uses torch SDPA." |
| 312 | + warn "Skipping flash-attn at your request (--skip-flash-attn); use --attn-impl sage / triton_flash, or SDPA." |
295 | 313 | elif install_flash_attn; then |
296 | 314 | log "FlashAttention-2 ready ($(python -c 'import flash_attn; print(flash_attn.__version__)'))" |
297 | 315 | else |
298 | | - fail "flash-attn could not be installed (you asked for it explicitly). |
299 | | - See ${BUILD_LOG} for the exact build error. Most common cause: the torch |
300 | | - pulled from ${TORCH_CUDA_TAG:-the CUDA index} is newer than any published |
301 | | - flash-attn wheel. Fixes: |
302 | | - * pin torch to a release that has wheels: TORCH_SPEC=torch==2.7.1 ./install_deps.sh |
303 | | - * or pin a flash-attn version: FLASH_ATTN_SPEC=flash-attn==2.7.4.post1 ./install_deps.sh" |
| 316 | + warn "flash-attn could not be installed (optional). See ${BUILD_LOG}. |
| 317 | + This is fine — the recommended path no longer needs it: run the server with |
| 318 | + --attn-impl sage (INT8 SageAttention prefill via Triton) or --attn-impl |
| 319 | + triton_flash (FP16). To install flash-attn anyway, the usual fixes are: |
| 320 | + * pin torch to a release with wheels: TORCH_SPEC=torch==2.7.1 ./install_deps.sh |
| 321 | + * or pin a flash-attn version: FLASH_ATTN_SPEC=flash-attn==2.7.4.post1 ./install_deps.sh |
| 322 | + * or limit the source build to Ampere: FLASH_ATTN_CUDA_ARCHS=80 ./install_deps.sh" |
304 | 323 | fi |
305 | 324 |
|
306 | 325 | # FlashInfer (decode attention) — JIT, best-effort, works without nvcc. |
@@ -355,8 +374,16 @@ def have(mod): |
355 | 374 |
|
356 | 375 | fa2 = have("flash_attn") |
357 | 376 | fi = have("flashinfer") |
| 377 | +tri = have("triton") |
358 | 378 | kern = have("kvboost._flash_attn_cuda") |
| 379 | +try: |
| 380 | + from kvboost.kernels import sage_available |
| 381 | + sage = sage_available() |
| 382 | +except Exception: |
| 383 | + sage = False |
359 | 384 | print(f" info prefill backend : {'flash_attention_2' if fa2 else 'torch SDPA (flash-attn not installed)'}") |
| 385 | +print(f" info sage/triton flash: {'available (--attn-impl sage | triton_flash)' if sage else 'unavailable (triton missing → SDPA)'}") |
| 386 | +print(f" info triton : {'present' if tri else 'absent'}") |
360 | 387 | print(f" info decode backend : {'flashinfer' if fi else 'torch SDPA (flashinfer not installed)'}") |
361 | 388 | print(f" info bundled kernel : {'kvboost._flash_attn_cuda' if kern else 'not built (SDPA patch path)'}") |
362 | 389 |
|
|
0 commit comments