#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""GPT-OSS 20B smoke test: mlx_lm ground truth for sink attention work (#148).

Loads openai/gpt-oss-20b, generates with greedy decoding, and compares the
output against golden token IDs. Not run in CI since it requires the
~21.5 GB model.

Run:
    VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/test_gpt_oss_smoke.py
"""

import os
import sys

os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
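# NB: VLLM_ENABLE_V1_MULTIPROCESSING=0 should keep vLLM's V1 engine in-process
# (no separate EngineCore process). The Run command in the docstring sets the
# same variable explicitly; this setdefault just makes the script self-contained.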

from transformers import AutoTokenizer  # noqa: E402
from vllm import LLM, SamplingParams  # noqa: E402

MODEL_NAME = "openai/gpt-oss-20b"
MAX_TOKENS = 100

PROMPTS = [
    "The capital of France is",
    "The weather today is not",
    "One plus one equals",
    "The largest planet in our solar system is",
    "Water boils at a temperature of",
]

# fmt: off
# Golden token IDs from the MLX inline cache, greedy decoding (openai/gpt-oss-20b).
# Generated via:
#   VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/gen_golden_token_ids_for_deterministics.py \
#     --model openai/gpt-oss-20b --max-tokens 100 --chat-template
#
# Note: FP non-determinism at longer sequences may cause 2-3 prompts to diverge
# after ~25 tokens across runs. Regenerate with the command above if needed.
GOLDEN_MLX = {
    "The capital of France is": [200005, 35644, 200008, 976, 1825, 5003, 25, 392, 976, 9029, 328, 10128, 382, 4050, 3164, 6960, 1682, 290, 6052, 25, 392, 72782, 4050, 2632, 9570, 483, 392, 72782, 4050, 63659, 1327, 6052, 13, 200007, 200006, 173781, 200005, 17196, 200008, 72782, 200002],
    "The weather today is not": [200005, 35644, 200008, 976, 1825, 5003, 25, 392, 976, 11122, 4044, 382, 625, 4050, 4569, 7890, 60592, 13, 3164, 3572, 413, 8601, 261, 21872, 25, 392, 976, 11122, 4044, 382, 625, 723, 64493, 49706, 889, 1023, 9289, 9115, 13, 3164, 3572, 413, 16054, 395, 3543, 30, 2604, 10112, 1023, 1682, 316, 1761, 290, 11122, 30, 623, 1825, 5003, 392, 976, 11122, 4044, 382, 625, 4050, 4569, 382, 60592, 13, 1416, 1309, 316, 9570, 54286, 13, 1416, 2023, 3810, 395, 108041, 25, 392, 4827, 1481, 481, 1299, 316, 1761, 1078, 290, 11122, 16842, 2604, 581, 2023, 18135, 484, 1023, 1682, 316],
    "One plus one equals": [200005, 35644, 200008, 976, 1825, 5003, 25, 392, 5045, 2932, 1001, 29702, 4050, 3164, 6960, 1682, 290, 6052, 25, 220, 17, 13, 3072, 10112, 1023, 1682, 261, 945, 65742, 6052, 30, 623, 1825, 3572, 413, 11493, 13, 623, 63122, 6052, 25, 220, 17, 13, 3072, 10112, 1023, 1682, 261, 15681, 30, 623, 21179, 25, 392, 3575, 553, 17554, 162016, 11, 261, 4410, 6439, 2359, 22203, 656, 7788, 17527, 3692, 32711, 860, 3582, 21179, 13, 2632, 6052, 25, 220, 17, 13, 200007, 200006, 173781, 200005, 17196, 200008, 5045, 2932, 1001, 29702, 6240, 17, 410, 13, 200002],
    "The largest planet in our solar system is": [200005, 35644, 200008, 976, 1825, 31064, 25, 392, 976, 10574, 17921, 306, 1039, 17624, 2420, 382, 4050, 3164, 6960, 1682, 290, 6052, 25, 79575, 13, 3164, 3572, 1682, 261, 18128, 13, 2632, 6052, 25, 79575, 13, 138743, 8633, 4275, 290, 10574, 13, 2632, 9570, 25, 79575, 13, 200007, 200006, 173781, 200005, 17196, 200008, 976, 10574, 17921, 306, 1039, 17624, 2420, 382, 6240, 41, 26451, 410, 13, 200002],
    "Water boils at a temperature of": [200005, 35644, 200008, 976, 1825, 5003, 25, 392, 27874, 165683, 540, 261, 12088, 328, 4050, 3164, 6960, 1682, 290, 79667, 2438, 328, 3411, 13, 3072, 290, 4928, 382, 60592, 25, 392, 27874, 165683, 540, 261, 12088, 328, 4050, 3164, 3572, 1682, 290, 6052, 25, 220, 1353, 26557, 540, 220, 16, 83327, 11, 503, 220, 19584, 68854, 13, 3072, 10112, 1023, 1682, 290, 12088, 306, 181775, 25, 220, 33797, 13, 1055, 658, 13, 623, 1825, 3572, 413, 35885, 261, 52077, 6052, 13, 623, 4928, 382, 60592, 889, 6960, 1023, 1682, 290, 79667, 2438, 13, 2632, 6052, 25, 220, 1353, 26557, 350],
}
# fmt: on


def _apply_chat_template(model_name, prompts):
    """Apply chat template and return (formatted_prompts, reverse_map)."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    formatted = []
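    # Map formatted prompt -> raw prompt: vLLM echoes back the chat-templated
    # string, while GOLDEN_MLX is keyed by the raw prompt.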
    reverse_map = {}
    for prompt in prompts:
        messages = [{"role": "user", "content": prompt}]
        fmt = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        formatted.append(fmt)
        reverse_map[fmt] = prompt
    return formatted, reverse_map

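
# Sketch of a diagnostic helper (an addition here, not part of the recorded
# golden workflow): find the first index at which two token-ID sequences
# diverge, so a FAIL below can report where FP non-determinism kicked in
# instead of only dumping both full ID lists.
def _first_mismatch(got, expected):
    """Return the first index where the sequences differ, or -1 if identical."""
    for i, (g, e) in enumerate(zip(got, expected)):
        if g != e:
            return i
    # No mismatch in the overlap: sequences are equal, or one is a prefix.
    return -1 if len(got) == len(expected) else min(len(got), len(expected))
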

if __name__ == "__main__":
    formatted_prompts, reverse_map = _apply_chat_template(MODEL_NAME, PROMPTS)

    llm = LLM(model=MODEL_NAME, max_model_len=512, max_num_seqs=1)
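    # temperature=0 selects greedy decoding, so the emitted token IDs are
    # directly comparable to GOLDEN_MLX (modulo the FP caveat noted above).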
    sp = SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
    outputs = llm.generate(formatted_prompts, sp)

    passed = 0
    failed = 0
    for o in outputs:
        prompt = reverse_map[o.prompt]
        token_ids = list(o.outputs[0].token_ids)
        text = o.outputs[0].text
        expected = GOLDEN_MLX[prompt]
        matched = token_ids == expected

        status = "PASS" if matched else "FAIL"
        print(f"  [{status}] {prompt!r}")
        print(f"      output: {text!r}")
        if not matched:
            print(f"      got: {token_ids}")
            print(f"      expected: {expected}")
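            # Pinpoint where divergence starts, using the _first_mismatch
            # helper sketched above.
            print(f"      diverges at token index {_first_mismatch(token_ids, expected)}")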
            failed += 1
        else:
            passed += 1

    print(f"\n{passed} passed, {failed} failed")
    sys.exit(1 if failed else 0)