forked from ml-explore/mlx-swift-lm
-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathgemma4_mtp_integration_test.py
More file actions
137 lines (113 loc) · 4.31 KB
/
gemma4_mtp_integration_test.py
File metadata and controls
137 lines (113 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python3
"""
Gemma4 MTP Speculative Decoding — Real Model Integration Test
=============================================================
Models:
Main: mlx-community/gemma-4-e2b-it-4bit (~3.4 GB)
Assistant: mlx-community/gemma-4-E2B-it-assistant-bf16 (~0.18 GB)
Total RAM: ~3.6 GB — safe on 64 GB M5 Pro
Safety limits:
max_tokens: 50
num_draft_tokens: 2 (MTP depth)
Usage:
python3 gemma4_mtp_integration_test.py
"""
import os
import sys
import time
import subprocess
from pathlib import Path
HF_CACHE = Path.home() / ".cache/huggingface/hub"
MAIN_ID = "mlx-community/gemma-4-e2b-it-4bit"
ASST_ID = "mlx-community/gemma-4-E2B-it-assistant-bf16"
PROMPT = "What is the capital of France? Answer in one word."
MAX_TOKENS = 50
NUM_DRAFT = 2
def find_snapshot(model_id: str) -> Path:
slug = "models--" + model_id.replace("/", "--")
snaps = list((HF_CACHE / slug / "snapshots").glob("*"))
if not snaps:
raise FileNotFoundError(f"Model not cached: {model_id}")
return snaps[0]
def check_mlx_lm() -> bool:
try:
import mlx_lm # noqa: F401
return True
except ImportError:
return False
def run_mlx_lm(model_dir: str, prompt: str, max_tokens: int,
draft_model: str | None = None) -> tuple[str, float]:
"""Run mlx_lm.generate and return (output_text, tps)."""
import mlx_lm
print(f" Loading model from: {model_dir}")
load_kwargs = {}
if draft_model:
load_kwargs["draft_model"] = draft_model
model, tokenizer = mlx_lm.load(model_dir, **load_kwargs)
prompt_tokens = tokenizer.encode(prompt, return_tensors="mlx")
t0 = time.perf_counter()
response = mlx_lm.generate(
model,
tokenizer,
prompt=prompt,
max_tokens=max_tokens,
verbose=False,
)
elapsed = time.perf_counter() - t0
# count output tokens (approximate)
output_tokens = len(tokenizer.encode(response))
tps = output_tokens / elapsed if elapsed > 0 else 0.0
return response, tps
def main():
print("=" * 55)
print(" Gemma 4 E2B — MTP Speculative Decoding Test")
print("=" * 55)
# Check model presence
try:
main_snap = find_snapshot(MAIN_ID)
asst_snap = find_snapshot(ASST_ID)
print(f"✅ Main model: {main_snap}")
print(f"✅ Assistant model: {asst_snap}")
except FileNotFoundError as e:
print(f"⚠️ {e}")
print(" Run: mlx_lm.convert --hf-path ... to download.")
sys.exit(1)
if not check_mlx_lm():
print("\n❌ mlx_lm not installed. Run:")
print(" pip install mlx-lm")
sys.exit(1)
print(f"\n📝 Prompt: \"{PROMPT}\"")
print(f" max_tokens={MAX_TOKENS}, num_draft={NUM_DRAFT}\n")
# --- Baseline (no MTP) ---
print("--- Baseline (no speculative decoding) ---")
base_text, base_tps = run_mlx_lm(str(main_snap), PROMPT, MAX_TOKENS)
print(f" Output: \"{base_text.strip()[:80]}\"")
print(f" Speed: {base_tps:.1f} tok/s\n")
# --- MTP speculative ---
print("--- MTP Speculative (draft_model=assistant) ---")
# mlx_lm draft model support: pass draft model path
try:
mtp_text, mtp_tps = run_mlx_lm(
str(main_snap), PROMPT, MAX_TOKENS,
draft_model=str(asst_snap)
)
speedup = mtp_tps / base_tps if base_tps > 0 else 0
print(f" Output: \"{mtp_text.strip()[:80]}\"")
print(f" Speed: {mtp_tps:.1f} tok/s")
print(f"\n{'='*55}")
print(f" Speedup: {speedup:.2f}x")
print(f" Baseline: {base_tps:.1f} tok/s")
print(f" MTP: {mtp_tps:.1f} tok/s")
print(f"{'='*55}")
# Validate output
assert "paris" in base_text.lower(), f"Baseline didn't say Paris: {base_text}"
assert "paris" in mtp_text.lower(), f"MTP didn't say Paris: {mtp_text}"
assert speedup >= 0.8, f"MTP regressed: {speedup:.2f}x"
print("\n✅ All assertions passed!")
except Exception as e:
print(f"\n⚠️ MTP generation with draft_model failed: {e}")
print(" This may mean mlx_lm doesn't yet support draft_model= for MTP.")
print(" The Swift MTPTokenIterator integration test validates the pipeline directly.")
print(" Run: bash run_tests.sh Gemma4Tests (unit tests already pass)")
if __name__ == "__main__":
main()