Commit 4b3166d
authored
[varlen Kernel] Add test spec for unified varlen paged attention (#147)
## Goal
This is the spec to smoothly migrate v1 kernel (only paged) to v2 kernel
(paged + varlen), without break anything on main, but enable us to
support continuous batching.
## Summary
Adds `test_metal_unified_attention.py`, adapted from vLLM upstream's
`test_triton_unified_attention.py`.
Three test functions form a triangle validation:
```
v1 kernel (production)
/ \
PASS / \ xfail
(runs / \ (v2 not built)
now) / \
v v
ref_paged_attn ----> v2 kernel (unified varlen)
(naive MLX) xfail
```
- `test_v1_kernel_vs_reference` - v1 == ref (12 cases, runs now,
validates the reference)
- `test_metal_unified_attn_decode_only` - v2 == v1 (24 cases, xfail,
proves v2 is a drop-in replacement for decode)
- `test_metal_unified_attn` - v2 == ref (192 cases, xfail, full varlen
with mixed prefill+decode, GQA, sliding window, soft cap)
## Migration plan
```
Step 1: Build v2 kernel, test decode-only (q_len=1)
Scaffolding test goes green: v2 == v1
|
Step 2: Extend v2 to handle varlen (q_len > 1)
Full test goes green: v2 == ref
|
Step 3: Delete scaffolding test, replace v1 with v2 in production
```
---------
Signed-off-by: ran <hzz5361@psu.edu>1 parent 1603a16 commit 4b3166d
1 file changed
Lines changed: 470 additions & 0 deletions
0 commit comments