Background
We need two sets of kernels for MLA:
1. self-attention on ragged tensor, w/o matrix absorption: head_dim_qk=192, head_dim_vo=128
2. cross-attention on paged KV-cache, w/ matrix absorption: head_dim_qk=576, head_dim_vo=512 (K=V)
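Aside from paging and matrix absorption, the structural difference between the two variants is that head_dim_qk and head_dim_vo differ. A minimal PyTorch sketch of attention with asymmetric head dims (a hypothetical `ref_attention` helper, no masking, paging, or causal logic; for the absorbed variant the latent c_kv acts as a shared K=V and would be broadcast over heads before such a call):

```python
import math
import torch

def ref_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sm_scale=None):
    """Reference attention where Q/K and V use different head dims.

    q: (n, num_heads, head_dim_qk)    e.g. head_dim_qk = 192 (variant 1) or 576 (variant 2)
    k: (n_kv, num_heads, head_dim_qk)
    v: (n_kv, num_heads, head_dim_vo) e.g. head_dim_vo = 128 (variant 1) or 512 (variant 2)
    Returns o: (n, num_heads, head_dim_vo) and lse: (n, num_heads).
    """
    if sm_scale is None:
        # With matrix absorption the softmax scale typically stays that of the
        # original (non-absorbed) qk head dim, so callers pass it explicitly.
        sm_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * sm_scale  # (num_heads, n, n_kv)
    lse = torch.logsumexp(scores, dim=-1).transpose(0, 1)                   # (n, num_heads)
    p = torch.softmax(scores, dim=-1)
    o = torch.einsum("hqk,khd->qhd", p, v.float())                          # (n, num_heads, head_dim_vo)
    return o.to(v.dtype), lse
```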
Serving engines are expected to use different kernels depending on the use case:
- For decoding, use kernel 2.
- For prefilling (w/o prefix-caching), use kernel 1.
- For incremental prefilling / chunked-prefill, use kernel 1 followed by kernel 2 and merge the two partial attention states:
o_1, lse_1 = cross_attention(c_q, q_pe, c_kv)
  (c_q: (n, 128, 512), q_pe: (n, 128, 64), c_kv: (n_kv, 576), o_1: (n, 128, 512), lse_1: (n, 128))
o_2, lse_2 = self_attention(q, k, v_new)
  (q: (n, 128, 192), k: (n, 128, 192), v_new: (n, 128, 128), o_2: (n, 128, 128), lse_2: (n, 128))
o, lse = merge(W_UV(o_1), lse_1, o_2, lse_2)
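The merge step combines the two partial softmax attention states via their log-sum-exp values. A minimal sketch of such a merge, assuming o_1 has already been projected back through W_UV so both outputs share the (n, 128, 128) shape (the function name here is illustrative, not a specific library API):

```python
import torch

def merge_states(o_1: torch.Tensor, lse_1: torch.Tensor,
                 o_2: torch.Tensor, lse_2: torch.Tensor):
    """Merge two partial attention outputs computed over disjoint KV ranges.

    o_1, o_2:     (n, num_heads, head_dim_vo) partial outputs in the same value space.
    lse_1, lse_2: (n, num_heads) log-sum-exp of the (scaled) attention scores per part.
    """
    lse = torch.logaddexp(lse_1, lse_2)            # combined normalizer, (n, num_heads)
    w_1 = torch.exp(lse_1 - lse).unsqueeze(-1)     # each part's share of the softmax mass
    w_2 = torch.exp(lse_2 - lse).unsqueeze(-1)
    o = w_1 * o_1 + w_2 * o_2                      # (n, num_heads, head_dim_vo)
    return o, lse
```

This is the standard online-softmax combine rule, which is why both kernels must return lse alongside the output.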
Milestone
- Prefill attention kernel for self-attention (w/o matrix absorption): feat: support deepseek prefill attention shape #765 (@yzh119)
- Decode attention kernel (w/ matrix absorption, page table) on CUDA cores: feat: support MLA decode #551 (@tsu-bin)
- General (prefill+append+decode) MLA attention kernel on the FA2 template (w/ matrix absorption, page table): perf: memory efficient deepseek mla fused page-attention kernel #804 (@yzh119)
- General (prefill+append+decode) MLA attention kernel on the FA3 template (w/ matrix absorption, page table) (@yzh119)