Skip to content

Commit 4b3166d

Browse files
authored
[varlen Kernel] Add test spec for unified varlen paged attention (#147)
## Goal This is the spec to smoothly migrate v1 kernel (only paged) to v2 kernel (paged + varlen), without break anything on main, but enable us to support continuous batching. ## Summary Adds `test_metal_unified_attention.py`, adapted from vLLM upstream's `test_triton_unified_attention.py`. Three test functions form a triangle validation: ``` v1 kernel (production) / \ PASS / \ xfail (runs / \ (v2 not built) now) / \ v v ref_paged_attn ----> v2 kernel (unified varlen) (naive MLX) xfail ``` - `test_v1_kernel_vs_reference` - v1 == ref (12 cases, runs now, validates the reference) - `test_metal_unified_attn_decode_only` - v2 == v1 (24 cases, xfail, proves v2 is a drop-in replacement for decode) - `test_metal_unified_attn` - v2 == ref (192 cases, xfail, full varlen with mixed prefill+decode, GQA, sliding window, soft cap) ## Migration plan ``` Step 1: Build v2 kernel, test decode-only (q_len=1) Scaffolding test goes green: v2 == v1 | Step 2: Extend v2 to handle varlen (q_len > 1) Full test goes green: v2 == ref | Step 3: Delete scaffolding test, replace v1 with v2 in production ``` --------- Signed-off-by: ran <hzz5361@psu.edu>
1 parent 1603a16 commit 4b3166d

1 file changed

Lines changed: 470 additions & 0 deletions

File tree

0 commit comments

Comments
 (0)