Skip to content

Commit 6ecf38f

Browse files
[Paged KV] Add paged attention deterministic smoke test (#138)
## Summary - Add `test_paged_deterministic.py`: 5-prompt smoke test using vLLM offline inference (temp=0, greedy) against hardcoded golden token IDs from Qwen3-0.6B - Golden values generated on `main` from both MLX inline cache and HF paged KV cache paths - Add `tools/gen_golden.py` helper to regenerate golden values ## Motivation Prerequisite for the native Metal kernel PR (#136). After inlining the vendored Metal shaders, paged attention output must remain identical to the current HF kernel baseline. This test anchors that. ## Test - `python -m pytest tests/test_paged_deterministic.py -v -s` (paged path by default) - Passes on `main` with HF kernel: 5/5 ## Relevant Issue & PR * Issue #119 * PR #136 : The inline metal kernel needs to either pass this test, or explain the possible non-determinism from the kernel. upstream batch invariant feature * blog: https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ * main feature: vllm-project/vllm#27433 * vllm upstream batch invariant feature is only compatible with H / B series NVIDIA GPUs. A100 does not work. See my exp results https://github.com/WindChimeRan/spec_deterministic * community work: vllm-project/vllm#30018 Batch invariance is hardware & kernel dependent. Supporting this feature is non-trivial on metal. output example: <img width="1061" height="721" alt="image" src="https://github.com/user-attachments/assets/bf423b90-c567-408b-8682-e2c36050fb8f" /> --------- Signed-off-by: ran <hzz5361@psu.edu> Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com> Co-authored-by: Yuan Lik Xun <lxyuan0420@gmail.com>
1 parent e0f97be commit 6ecf38f

2 files changed

Lines changed: 212 additions & 0 deletions

File tree

tests/test_paged_deterministic.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Deterministic smoke test: vLLM offline inference with golden token comparison.
3+
4+
Golden token IDs were generated on the main branch using vLLM offline inference
5+
with temperature=0 (greedy decoding) on Qwen/Qwen3-0.6B, running one sequence
6+
at a time (max_num_seqs=1) to avoid batch-invariance issues on Metal.
7+
8+
Findings from golden generation (main branch, HF paged-attention kernel):
9+
- The HF kernel paged KV path produces correct, coherent output.
10+
- 4/5 prompts are identical to the MLX inline cache path.
11+
- 1/5 ("The capital of France is") diverges at token 5 — both continuations
12+
are valid English ("France is also the capital" vs "Italy is Rome. The").
13+
Likely caused by floating-point non-determinism in the attention kernel
14+
where top-2 logits are very close.
15+
16+
The assert accepts EITHER golden set (mlx-cache or paged-cache) and prints
17+
which path matched.
18+
19+
Run (paged KV path, the default):
20+
python -m pytest tests/test_paged_deterministic.py -v -s
21+
22+
To test the MLX inline cache path instead, pass env vars explicitly:
23+
VLLM_METAL_USE_PAGED_ATTENTION=0 VLLM_METAL_MEMORY_FRACTION=auto \
24+
python -m pytest tests/test_paged_deterministic.py -v -s
25+
26+
Note: MLX requires VLLM_METAL_MEMORY_FRACTION=auto (numeric fractions are
27+
only valid for the paged attention path).
28+
"""
29+
30+
from __future__ import annotations
31+
32+
import os
33+
34+
import pytest
35+
from vllm import LLM, SamplingParams
36+
37+
# Model and decoding configuration shared by every test in this module.
MODEL_NAME = "Qwen/Qwen3-0.6B"
MAX_TOKENS = 10
# String-valued env-var defaults; applied only when the variable is absent
# (the user's own env settings always win).
DEFAULT_USE_PAGED_ATTENTION = "1"
DEFAULT_PAGED_MEMORY_FRACTION = "0.2"
DEFAULT_MLX_MEMORY_FRACTION = "auto"

# Five short factual prompts; each has a hardcoded golden continuation below.
PROMPTS = [
    "The capital of France is",
    "The weather today is not",
    "One plus one equals",
    "The largest planet in our solar system is",
    "Water boils at a temperature of",
]

# fmt: off
# Golden token IDs from MLX inline cache (default path), greedy decoding.
# Generated on main branch via: VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py
GOLDEN_MLX = {
    "The capital of France is": [12095, 13, 576, 6722, 315, 9625, 374, 1083, 279, 6722],
    "The weather today is not": [1661, 13, 576, 9315, 374, 220, 17, 15, 12348, 13],
    "One plus one equals": [825, 11, 825, 5519, 825, 16819, 1378, 13, 2055, 11],
    "The largest planet in our solar system is": [1112, 30, 362, 13, 43562, 425, 13, 48976, 356, 13],
    "Water boils at a temperature of": [220, 16, 15, 15, 30937, 13, 3555, 374, 279, 9315],
}

# Golden token IDs from paged KV cache (HF kernel on main branch), greedy decoding.
# Only the "capital of France" prompt diverges from GOLDEN_MLX (see module docstring).
# Generated on main branch via: VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 \
#     VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py
GOLDEN_PAGED = {
    "The capital of France is": [12095, 13, 576, 6722, 315, 15344, 374, 21718, 13, 576],
    "The weather today is not": [1661, 13, 576, 9315, 374, 220, 17, 15, 12348, 13],
    "One plus one equals": [825, 11, 825, 5519, 825, 16819, 1378, 13, 2055, 11],
    "The largest planet in our solar system is": [1112, 30, 362, 13, 43562, 425, 13, 48976, 356, 13],
    "Water boils at a temperature of": [220, 16, 15, 15, 30937, 13, 3555, 374, 279, 9315],
}
# fmt: on
73+
74+
75+
def _setenv_default(mp: pytest.MonkeyPatch, key: str, default: str) -> str:
76+
"""Set an env var only when absent and return the effective value."""
77+
value = os.environ.get(key)
78+
if value is None:
79+
mp.setenv(key, default)
80+
return default
81+
return value
82+
83+
84+
@pytest.fixture(autouse=True, scope="module")
def _set_env():
    """Configure vLLM-Metal env vars for this module, restoring them afterwards.

    MonkeyPatch.context() reverts every change once the module finishes, so
    other test modules see no side effects. The paged KV cache path is the
    default (to make sure the paged attention kernel is actually exercised),
    but any env var the caller exported beforehand takes precedence — e.g.
    to run the MLX path instead.
    """
    with pytest.MonkeyPatch.context() as mp:
        mp.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

        # Paged attention by default; an explicit caller setting wins.
        use_paged = _setenv_default(
            mp,
            "VLLM_METAL_USE_PAGED_ATTENTION",
            DEFAULT_USE_PAGED_ATTENTION,
        )

        # The memory default depends on which path is active; again the
        # caller's own value, if any, is preserved.
        if use_paged == "1":
            memory_default = DEFAULT_PAGED_MEMORY_FRACTION
        else:
            memory_default = DEFAULT_MLX_MEMORY_FRACTION
        _setenv_default(mp, "VLLM_METAL_MEMORY_FRACTION", memory_default)

        yield
113+
114+
115+
@pytest.fixture(scope="module")
def vllm_outputs():
    """Generate completions for every prompt with one shared LLM instance.

    max_num_seqs=1 keeps sequences from batching together, avoiding
    batch-invariance non-determinism on Metal.
    """
    engine = LLM(model=MODEL_NAME, max_model_len=512, max_num_seqs=1)
    params = SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
    completions = engine.generate(PROMPTS, params)
    return {c.prompt: c for c in completions}
125+
126+
127+
class TestPagedDeterministic:
    @pytest.mark.slow
    @pytest.mark.parametrize("prompt", PROMPTS)
    def test_generate_matches_golden(self, vllm_outputs, prompt):
        """Assert the generated token IDs equal one of the two golden sets.

        Either the mlx-cache or the paged-cache golden is accepted (see the
        module docstring for why one prompt legitimately diverges); the
        matching path is printed for visibility under ``-s``.
        """
        completion = vllm_outputs[prompt].outputs[0]
        got_ids = list(completion.token_ids)

        want_mlx = GOLDEN_MLX[prompt]
        want_paged = GOLDEN_PAGED[prompt]
        matched_mlx = got_ids == want_mlx
        matched_paged = got_ids == want_paged

        print(f"\n prompt: {prompt!r}")
        print(f" output: {completion.text!r}")
        print(f" ids: {got_ids}")
        if matched_mlx:
            print(" result: MATCHED mlx-cache golden")
        elif matched_paged:
            print(" result: MATCHED paged-cache golden")
        else:
            print(" result: NO MATCH")
            print(f" expected (mlx): {want_mlx}")
            print(f" expected (paged): {want_paged}")

        assert matched_mlx or matched_paged, (
            f"Output for {prompt!r} matched neither golden set.\n"
            f"Got: {got_ids}\n"
            f"Expected (mlx): {want_mlx}\n"
            f"Expected (pgd): {want_paged}"
        )
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#!/usr/bin/env python3
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Generate golden token IDs for the deterministic smoke test.
4+
5+
Runs vLLM offline inference (greedy, max_num_seqs=1) and prints golden
6+
token-ID dicts to paste into test_paged_deterministic.py.
7+
8+
Usage:
9+
# MLX inline cache (default):
10+
VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py
11+
12+
# Paged KV cache:
13+
VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.3 \
14+
VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py
15+
16+
Note: MLX path requires VLLM_METAL_MEMORY_FRACTION=auto (the default).
17+
Numeric fractions are only valid for the paged attention path.
18+
"""
19+
20+
import os

# Set before the vllm import so the default is already in place when the
# engine package loads; an explicitly exported value is left untouched.
os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

from vllm import LLM, SamplingParams
26+
# Model / decoding config — must stay in sync with test_paged_deterministic.py
# so the regenerated golden values match what the test expects.
MODEL = "Qwen/Qwen3-0.6B"
MAX_TOKENS = 10

# Same five prompts as the smoke test.
PROMPTS = [
    "The capital of France is",
    "The weather today is not",
    "One plus one equals",
    "The largest planet in our solar system is",
    "Water boils at a temperature of",
]
36+
37+
if __name__ == "__main__":
    # The active KV-cache path (exported by the caller) decides which golden
    # dict is being regenerated and how its printed name is labelled.
    paged = os.environ.get("VLLM_METAL_USE_PAGED_ATTENTION", "0") == "1"
    label = "PAGED" if paged else "MLX"
    print(f"\n--- Generating golden values for {label} path ---\n")

    # max_num_seqs=1: one sequence at a time, matching the deterministic
    # setup used by the smoke test; temperature=0 selects greedy decoding.
    llm = LLM(model=MODEL, max_model_len=512, max_num_seqs=1)
    sp = SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
    outputs = llm.generate(PROMPTS, sp)

    # Emit a dict literal ready to paste into test_paged_deterministic.py,
    # with each decoded continuation shown as a trailing comment.
    print(f"\nGOLDEN_{label} = {{")
    for o in outputs:
        prompt = o.prompt
        ids = list(o.outputs[0].token_ids)
        text = o.outputs[0].text
        # Column-align the token lists. max(1, ...) fixes the original bug
        # where a prompt of 45+ characters produced pad <= 0 and the ':' and
        # the list ran together with no separator.
        pad = max(1, 45 - len(prompt))
        print(f" {prompt!r}:{' ' * pad}{ids},")
        print(f" # → {text!r}")
    print("}")

0 commit comments

Comments
 (0)