-
Notifications
You must be signed in to change notification settings - Fork 586
Expand file tree
/
Copy pathregistry_io.py
More file actions
321 lines (274 loc) · 9.69 KB
/
registry_io.py
File metadata and controls
321 lines (274 loc) · 9.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
"""Shared I/O functions for reading and writing model registry data files.
Consolidates the load-modify-save pattern used by verify_models.py and
main_benchmark.py into a single module that properly uses the
VerificationRecord/VerificationHistory dataclasses.
"""
import json
import logging
from datetime import date
from pathlib import Path
from typing import Callable, Optional
from .verification import VerificationHistory, VerificationRecord
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# The JSON data files live in the ``data`` directory next to this module.
_DATA_DIR = Path(__file__).parent / "data"
_SUPPORTED_MODELS_PATH = _DATA_DIR / "supported_models.json"
_VERIFICATION_HISTORY_PATH = _DATA_DIR / "verification_history.json"
# Status codes stored in the ``status`` field of a registry entry
# (see ``update_model_status``).
STATUS_UNVERIFIED = 0
STATUS_VERIFIED = 1
STATUS_SKIPPED = 2
STATUS_FAILED = 3
# HF-loadable quantization formats. Admitted to the registry; verification gates
# on `required_quant_library_for_model()` at run time.
# Each entry is tested as a case-sensitive substring of the model ID (see
# `is_hf_loadable_quantized`), so upper- and lower-case spellings are listed
# separately.
_HF_LOADABLE_QUANT_PATTERNS = [
    "-awq",
    "_awq",
    "-AWQ",
    "_AWQ",
    "-gptq",
    "_gptq",
    "-GPTQ",
    "_GPTQ",
    "GPTQ",  # bare marker: matches any ID containing "GPTQ", including the two above
    "-bnb-",
    "_bnb_",
    "bnb-4bit",
    "bnb-8bit",
    "-4bit",
    "_4bit",
    "-8bit",
    "_8bit",
    "-int4",
    "_int4",
    "-int8",
    "_int8",
    "-w4a16",
    "-w8a8",
    "-W4A16",
    "-W8A8",
    ".w4a16",
    ".W4A16",
    "-hqq",
    "_hqq",
    "-HQQ",
    "_HQQ",
    "-3bit",
    "_3bit",
    "-2bit",
    "_2bit",
    "-5bit",
    "-6bit",
    "-oQ",
    "_oQ",
    "-quantized.",
    "_Quantized",
    "-Quantized",
]
# Formats that need a non-HF loader (GGUF→llama.cpp, MLX→Apple, FP4/FP8→NVIDIA).
# Each entry is tested as a case-sensitive substring of the model ID (see
# `is_incompatible_quantized`).
_INCOMPATIBLE_QUANT_PATTERNS = [
    "-gguf",
    "_gguf",
    "-GGUF",
    "_GGUF",
    "mlx-community/",  # organisation prefix rather than a suffix marker
    "-mlx",
    "-MLX",
    "_mlx",
    "_MLX",
    ".mlx",
    ".MLX",
    "-fp8",
    "_fp8",
    "-FP8",
    "_FP8",
    "-nvfp4",
    "_nvfp4",
    "-NVFP4",
    "_NVFP4",
    "-mxfp4",
    "_mxfp4",
    "-MXFP4",
    "_MXFP4",
]
# Values are Python import names, not PyPI package names. Order matters: explicit
# format markers must precede generic bit-width markers (HQQ-4bit IDs match both).
# Each row is (substring patterns, import name); `required_quant_library_for_model`
# returns the import name of the first row that has a matching pattern.
_QUANT_LIBRARY_BY_PATTERN: list[tuple[tuple[str, ...], str]] = [
    (("-hqq", "_hqq", "-HQQ", "_HQQ"), "hqq"),
    (("-gptq", "_gptq", "-GPTQ", "_GPTQ", "GPTQ"), "auto_gptq"),
    (("-awq", "_awq", "-AWQ", "_AWQ"), "awq"),
    (("-w4a16", "-w8a8", "-W4A16", "-W8A8", ".w4a16", ".W4A16"), "auto_gptq"),
    (("-bnb-", "_bnb_", "bnb-4bit", "bnb-8bit"), "bitsandbytes"),
    (("-4bit", "_4bit", "-8bit", "_8bit", "-int4", "_int4", "-int8", "_int8"), "bitsandbytes"),
]
# Note text for quantized formats that HF transformers cannot load. It is not
# referenced by any function in this module.
QUANTIZED_NOTE = "Quantized format not loadable by HF transformers"
def is_incompatible_quantized(model_id: str) -> bool:
    """Tell whether ``model_id`` carries a marker for a format the bridge can't ingest.

    The check is a case-sensitive substring search for each entry of
    ``_INCOMPATIBLE_QUANT_PATTERNS`` (GGUF, MLX, FP4/FP8 markers).

    Args:
        model_id: Model identifier to inspect.

    Returns:
        True as soon as one marker is found inside ``model_id``, else False.
    """
    for marker in _INCOMPATIBLE_QUANT_PATTERNS:
        if marker in model_id:
            return True
    return False
def is_hf_loadable_quantized(model_id: str) -> bool:
    """Tell whether ``model_id`` looks like an HF-loadable quantized checkpoint.

    "HF-loadable" means loadable by HF transformers together with a
    quantization library. The check is a case-sensitive substring search for
    each entry of ``_HF_LOADABLE_QUANT_PATTERNS``.

    Args:
        model_id: Model identifier to inspect.

    Returns:
        True as soon as one marker is found inside ``model_id``, else False.
    """
    for marker in _HF_LOADABLE_QUANT_PATTERNS:
        if marker in model_id:
            return True
    return False
def required_quant_library_for_model(model_id: str) -> Optional[str]:
    """Look up the Python import name of the quantization library for a model.

    Rows of ``_QUANT_LIBRARY_BY_PATTERN`` are scanned in order and the first
    row with a marker occurring in ``model_id`` wins.

    Args:
        model_id: Model identifier to inspect.

    Returns:
        The import name from the first matching row, or None when no row
        matches (the model is treated as unquantized).
    """
    matching_libraries = (
        library
        for markers, library in _QUANT_LIBRARY_BY_PATTERN
        if any(marker in model_id for marker in markers)
    )
    # Lazy generator: scanning stops at the first matching row.
    return next(matching_libraries, None)
def is_quantized_model(model_id: str) -> bool:
    """Back-compat alias that forwards to ``is_incompatible_quantized``.

    Args:
        model_id: Model identifier to inspect.

    Returns:
        Whatever ``is_incompatible_quantized(model_id)`` returns.
    """
    incompatible = is_incompatible_quantized(model_id)
    return incompatible
def load_supported_models_raw() -> dict:
    """Load supported_models.json as a raw dict.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's locale encoding.

    Returns:
        The parsed JSON content of ``_SUPPORTED_MODELS_PATH``.

    Raises:
        FileNotFoundError: If the registry file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(_SUPPORTED_MODELS_PATH, encoding="utf-8") as f:
        return json.load(f)
def save_supported_models_raw(data: dict) -> None:
    """Save raw dict back to supported_models.json.

    The output is indented by two spaces and ends with a newline.

    Args:
        data: JSON-serializable registry content.

    Raises:
        TypeError: If ``data`` contains values ``json`` cannot serialize. The
            existing file is left untouched in that case.
    """
    # Serialize before opening the file: opening with "w" truncates it, so a
    # serialization error afterwards would leave a half-written registry.
    payload = json.dumps(data, indent=2) + "\n"
    with open(_SUPPORTED_MODELS_PATH, "w", encoding="utf-8") as f:
        f.write(payload)
def load_verification_history() -> VerificationHistory:
    """Load verification_history.json into a VerificationHistory dataclass.

    Returns:
        The history built with ``VerificationHistory.from_dict`` from the file
        content, or an empty ``VerificationHistory()`` when the file does not
        exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    # Open directly instead of checking exists() first: this avoids a race if
    # the file disappears between the check and the open.
    try:
        with open(_VERIFICATION_HISTORY_PATH, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return VerificationHistory()
    return VerificationHistory.from_dict(data)
def save_verification_history(history: VerificationHistory) -> None:
    """Save VerificationHistory dataclass to verification_history.json.

    The output is the JSON form of ``history.to_dict()``, indented by two
    spaces and ending with a newline.

    Args:
        history: The history object to persist.

    Raises:
        TypeError: If ``history.to_dict()`` contains values ``json`` cannot
            serialize. The existing file is left untouched in that case.
    """
    # Serialize before opening the file: opening with "w" truncates it, so a
    # failure in to_dict() or in serialization afterwards would destroy the
    # previously stored history.
    payload = json.dumps(history.to_dict(), indent=2) + "\n"
    with open(_VERIFICATION_HISTORY_PATH, "w", encoding="utf-8") as f:
        f.write(payload)
def _get_tl_version() -> Optional[str]:
"""Get the current TransformerLens version, or None."""
try:
import transformer_lens
return getattr(transformer_lens, "__version__", None)
except Exception:
return None
# Phase numbers that have a ``phase<N>_score`` field in a registry entry.
_PHASE_NUMBERS = (1, 2, 3, 4, 7, 8)

# Canonical key order for an updated registry entry. Keys not listed here are
# kept after these, in their existing relative order.
_ENTRY_KEY_ORDER = (
    "architecture_id",
    "model_id",
    "status",
    "verified_date",
    "metadata",
    "note",
    "phase1_score",
    "phase2_score",
    "phase3_score",
    "phase4_score",
    "phase7_score",
    "phase8_score",
)


def _reorder_entry_keys(entry: dict) -> None:
    """Rewrite ``entry`` in place so its keys follow ``_ENTRY_KEY_ORDER``."""
    reordered = {k: entry[k] for k in _ENTRY_KEY_ORDER if k in entry}
    for key, value in entry.items():
        # setdefault leaves the canonical keys untouched and appends the rest.
        reordered.setdefault(key, value)
    entry.clear()
    entry.update(reordered)


def update_model_status(
    model_id: str,
    arch_id: str,
    status: Optional[int] = None,
    note: Optional[str] = None,
    phase_scores: Optional[dict[int, Optional[float]]] = None,
    sanitize_fn: Optional[Callable[[Optional[str]], Optional[str]]] = None,
) -> bool:
    """Update a single model entry in supported_models.json.

    The entry is matched on both ``model_id`` and ``architecture_id``; only the
    first match is updated. If the model is not found in the registry and
    status == STATUS_VERIFIED, a new entry is appended (also when the registry
    has no "models" list yet).

    When status is None (partial-phase update), only the provided phase_scores
    are updated — status and other scores are preserved. The note is preserved
    too, except that an explicit ``note`` overwrites it, and writing scores
    clears an existing note containing "exceeds" (a stale memory-skip note).

    Whenever an entry is updated or created, the ``total_verified``,
    ``total_models`` and ``total_architectures`` summary fields are recomputed
    and the registry is written back to disk.

    Args:
        model_id: The model to update
        arch_id: Architecture of the model
        status: New status code (0-3), or None for score-only updates
        note: Optional note for skip/fail reason
        phase_scores: Phase score dict keyed by phase number (1, 2, 3, 4, 7, 8);
            other keys are ignored
        sanitize_fn: Optional callable to sanitize note strings; applied only
            when ``note`` is non-empty

    Returns:
        True if entry was found/created and updated, False otherwise (nothing
        is written in that case)

    Raises:
        KeyError: If a registry entry lacks ``model_id`` or ``architecture_id``.
    """
    if phase_scores is None:
        phase_scores = {}
    if sanitize_fn and note:
        note = sanitize_fn(note)

    data = load_supported_models_raw()
    # setdefault rather than get: with a missing "models" key, appending to
    # ``data.get("models", [])`` would add to a throwaway list and the new
    # entry would be lost while the function still reported success.
    models = data.setdefault("models", [])

    updated = False
    for entry in models:
        if entry["model_id"] != model_id or entry["architecture_id"] != arch_id:
            continue

        if status is not None:
            entry["status"] = status
            # Resetting to "unverified" also clears the verification date.
            entry["verified_date"] = (
                date.today().isoformat() if status != STATUS_UNVERIFIED else None
            )
            entry["note"] = note
        elif note is not None:
            # Score-only update with an explicit note — overwrite stale notes
            entry["note"] = note
        elif phase_scores and "exceeds" in (entry.get("note") or "").lower():
            # Writing real scores clears a stale memory-skip note
            entry["note"] = None

        for phase_num in _PHASE_NUMBERS:
            key = f"phase{phase_num}_score"
            if phase_num in phase_scores:
                entry[key] = phase_scores[phase_num]
            elif key not in entry:
                # Make sure every entry carries all score fields.
                entry[key] = None

        # Reorder keys so phase scores are always in numerical order
        _reorder_entry_keys(entry)
        updated = True
        break

    if not updated and status == STATUS_VERIFIED:
        # Model not in registry -- add it
        new_entry = {
            "model_id": model_id,
            "architecture_id": arch_id,
            "status": status,
            "verified_date": date.today().isoformat(),
            "metadata": None,
            "note": note,
        }
        for phase_num in _PHASE_NUMBERS:
            new_entry[f"phase{phase_num}_score"] = phase_scores.get(phase_num)
        models.append(new_entry)
        updated = True

    if updated:
        data["total_verified"] = sum(1 for m in models if m.get("status", 0) == STATUS_VERIFIED)
        data["total_models"] = len(models)
        data["total_architectures"] = len({m["architecture_id"] for m in models})
        save_supported_models_raw(data)
    return updated
def add_verification_record(
    model_id: str,
    arch_id: str,
    notes: Optional[str] = None,
    verified_by: str = "verify_models",
    sanitize_fn: Optional[Callable[[Optional[str]], Optional[str]]] = None,
) -> None:
    """Record one verification in verification_history.json.

    Builds a ``VerificationRecord`` dated today and stamped with the current
    TransformerLens version (None if unavailable), adds it to the loaded
    history, and writes the history back to disk.

    Args:
        model_id: The verified model
        arch_id: Architecture type
        notes: Optional verification notes
        verified_by: Who/what performed the verification
        sanitize_fn: Optional callable to sanitize note strings; applied only
            when ``notes`` is non-empty
    """
    cleaned_notes = sanitize_fn(notes) if (sanitize_fn and notes) else notes

    today = date.today()
    tl_version = _get_tl_version()
    new_record = VerificationRecord(
        model_id=model_id,
        architecture_id=arch_id,
        verified_date=today,
        verified_by=verified_by,
        transformerlens_version=tl_version,
        notes=cleaned_notes,
    )

    # Load-modify-save round trip on the history file.
    stored_history = load_verification_history()
    stored_history.add_record(new_record)
    save_verification_history(stored_history)