
Add ATOM_CK_FREE=1 switch for CK-free e2e inference#20

Merged
sunway513 merged 3 commits into main from feat/ck-free-mode
Feb 23, 2026

Conversation


@sunway513 sunway513 commented Feb 23, 2026

Summary

  • Add ATOM_CK_FREE=1 environment variable — a single switch that auto-enables all non-CK paths for e2e inference
  • MOE: force Triton/FlyDSL backends when CK-free
  • MHA attention: force Triton attention path (bypasses fused_qk_norm_rope_cache_quant_shuffle)
  • MLA decode: force Triton MLA decode path
  • MLA prefill cache ops handled by AITER-side fallbacks (see Add CK-free fallbacks for cache ops and RoPE auto-detection aiter#27)
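The single-switch idea above can be sketched as a small helper; the function name below is illustrative, not atom's actual envs.py API:

```python
import os

def ck_free_enabled() -> bool:
    """Return True when the ATOM_CK_FREE=1 switch is set in the environment.

    Hypothetical sketch: the real envs.py may cache or validate the value
    differently; only the env var name comes from the PR.
    """
    return os.environ.get("ATOM_CK_FREE", "0") == "1"
```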

Dependencies

Test plan

  • Build AITER CK-free: ENABLE_CK=0 pip install -e .
  • Set ATOM_CK_FREE=1 and run DeepSeek inference (MLA path)
  • Set ATOM_CK_FREE=1 and run Llama inference (MHA path)
  • Verify all ops route to non-CK backends via log messages

Single env var that auto-enables all non-CK paths:
- envs.py: Add ATOM_CK_FREE environment variable
- moe.py: Force Triton/FlyDSL MOE when ATOM_CK_FREE=1
- attention_mha.py: Force Triton attention when ATOM_CK_FREE=1
- attention_mla.py: Force Triton MLA decode when ATOM_CK_FREE=1
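The three routing changes above share one pattern: when ATOM_CK_FREE=1, any CK-preferring op family falls through to a non-CK backend. A minimal sketch of that pattern, assuming hypothetical function and backend names (only the env var and backend families are from the PR):

```python
import os

def select_backend(op_family: str, preferred: str) -> str:
    """Route an op to a non-CK backend when ATOM_CK_FREE=1 (sketch).

    op_family: one of "moe", "mha", "mla_decode" (illustrative keys).
    preferred: the backend the normal heuristics would pick.
    """
    if os.environ.get("ATOM_CK_FREE", "0") != "1":
        return preferred
    # Assumed CK-free fallbacks per op family; the PR forces Triton
    # (or FlyDSL for MOE) on each of these paths.
    ck_free_fallbacks = {"moe": "triton", "mha": "triton", "mla_decode": "triton"}
    return ck_free_fallbacks.get(op_family, preferred)
```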

MLA prefill cache ops (concat_and_cache_mla, fused_qk_rope_concat_and_cache_mla)
are handled by AITER-side fallbacks (PyTorch/Triton) that activate automatically
when the CK module_cache JIT build fails.
Tests env var detection, MOE routing, MHA routing, and MLA routing
conditions without requiring a GPU or model weights. 11 tests total.
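The AITER-side fallback described above follows a try-the-JIT-build, fall-back-on-failure shape. A self-contained sketch with stand-in functions (none of these names are AITER's real internals):

```python
def _build_ck_module(name: str):
    """Stand-in for the CK module_cache JIT build; raises in an
    ENABLE_CK=0 install, as described in the PR."""
    raise RuntimeError(f"CK disabled: cannot JIT-build {name}")

def _torch_fallback(x):
    """Stand-in for the PyTorch/Triton fallback implementation."""
    return x

def concat_and_cache_mla(x):
    """Dispatch sketch: prefer the CK kernel, activate the fallback
    automatically when the CK JIT build fails."""
    try:
        op = _build_ck_module("module_cache")
    except RuntimeError:
        op = _torch_fallback
    return op(x)
```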
Comment on lines +10 to +11
import importlib


⚠️ [ruff] <F401> reported by reviewdog 🐶
importlib imported but unused

Suggested change: delete the unused import importlib line.

Previously use_triton_attn controlled both the cache update strategy and
the paged attention backend. In CK-free builds this forced Triton PA
even though ASM PA is CK-free and faster for decode.

Now: always use Triton fused rope+cache (fast, no module_cache JIT) and
independently select ASM PA for decode when head_dim=128 and no sliding
window. For fp8 KV cache, fill per-token scale buffers with the uniform
per-tensor scale so ASM PA can dequant correctly. Move kv_scale to CUDA
at init for graph capture compatibility.

Benchmark (Llama-3.1-8B, 1k/1k, con64):
  v7 Triton PA bf16 KV: 6,255 tok/s
  v9 ASM PA bf16 KV:    6,712 tok/s (+7.3%)
  v9 ASM PA fp8 KV:     6,830 tok/s (+9.2%)
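The two decisions in the commit message above can be sketched as standalone helpers; the signatures are assumed for illustration (a plain list stands in for the per-token scale tensor buffer):

```python
def asm_pa_eligible(head_dim: int, sliding_window) -> bool:
    """ASM paged attention is selected for decode only when head_dim=128
    and there is no sliding window, per the commit message."""
    return head_dim == 128 and sliding_window is None

def fill_per_token_scales(num_tokens: int, per_tensor_scale: float) -> list:
    """For an fp8 KV cache, replicate the uniform per-tensor scale into a
    per-token buffer so ASM PA can dequantize correctly (sketch)."""
    return [per_tensor_scale] * num_tokens
```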
@sunway513 sunway513 merged commit cf596f1 into main Feb 23, 2026
6 of 7 checks passed