#!/usr/bin/env python3
"""
AFM7 Probing CLI Tool

Command-line interface for probing AFM7 model internals during inference.
Outputs probe results to JSON for visualization in probe_viewer.html.

Usage:
    python probe_cli.py --model-path /path/to/mlx_afm7 --prompt "Your prompt here"
    python probe_cli.py --model-path /path/to/mlx_afm7 --interactive

Examples:
    # Single prompt with all probes
    python probe_cli.py -m mlx_afm7 -p "What is machine learning?"

    # Interactive mode with specific layers
    python probe_cli.py -m mlx_afm7 --interactive --layers 0,17,34,55

    # Export detailed probe data
    python probe_cli.py -m mlx_afm7 -p "Hello" --output probe_results.json --verbose
"""
import argparse
import json
import sys
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, Dict, List, Tuple

import numpy as np

# Ensure this script's directory is on sys.path so sibling modules import cleanly
sys.path.insert(0, str(Path(__file__).parent))
@dataclass
class ProbeConfig:
"""Configuration for probing."""
capture_embeddings: bool = True
capture_layer_outputs: bool = True
capture_ffn_activations: bool = True
capture_logits: bool = True
capture_token_probs: bool = True
layer_indices: Optional[List[int]] = None
max_sequence_positions: int = 512
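# Illustrative usage (hypothetical values, not defaults enforced anywhere):
#   config = ProbeConfig(capture_ffn_activations=False, layer_indices=[0, 17, 34, 55])
# captures everything except FFN gate activations, and only for the listed layers.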
@dataclass
class ProbeResults:
"""Container for probe results."""
input_tokens: List[int] = field(default_factory=list)
input_text: str = ""
generated_tokens: List[int] = field(default_factory=list)
generated_text: str = ""
top_k_tokens: List[Tuple[int, float, str]] = field(default_factory=list)
# Shapes only (for JSON export)
layer_outputs_shapes: Dict[str, List[int]] = field(default_factory=dict)
kv_reuse_layer_outputs_shapes: Dict[str, List[int]] = field(default_factory=dict)
# Statistics
embedding_stats: Dict[str, float] = field(default_factory=dict)
logits_stats: Dict[str, float] = field(default_factory=dict)
ffn_stats: Dict[str, Dict[str, float]] = field(default_factory=dict)
def to_dict(self):
"""Convert to serializable dictionary."""
return {
"input_text": self.input_text,
"input_tokens": self.input_tokens,
"generated_text": self.generated_text,
"generated_tokens": self.generated_tokens,
"top_k_tokens": self.top_k_tokens,
"layer_outputs_shapes": self.layer_outputs_shapes,
"kv_reuse_layer_outputs_shapes": self.kv_reuse_layer_outputs_shapes,
"embedding_stats": self.embedding_stats,
"logits_stats": self.logits_stats,
"ffn_stats": self.ffn_stats,
}
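# Note: to_dict() emits only JSON-serializable types (tuples become lists), so
# json.dump(results.to_dict(), f) needs no custom encoder; the raw numpy arrays
# are kept on the prober instance and never serialized.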
class CLIProber:
"""Command-line probing tool for AFM7."""
def __init__(self, model, tokenizer, config: ProbeConfig, verbose: bool = False):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.verbose = verbose
self.results = ProbeResults()
# Storage for raw arrays (not exported to JSON)
self._embeddings = None
self._layer_outputs = {}
self._kv_reuse_outputs = {}
self._ffn_gates = {}
self._logits = None
def _log(self, msg: str):
if self.verbose:
print(f"[PROBE] {msg}")
def _should_capture_layer(self, idx: int) -> bool:
if self.config.layer_indices is None:
return True
return idx in self.config.layer_indices
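    # e.g. layer_indices=[0, 17] captures only layers 0 and 17;
    # layer_indices=None (the default) captures every layer.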
def _compute_stats(self, arr: np.ndarray) -> Dict[str, float]:
"""Compute basic statistics for an array."""
return {
"mean": float(np.mean(arr)),
"std": float(np.std(arr)),
"min": float(np.min(arr)),
"max": float(np.max(arr)),
"l2_norm": float(np.linalg.norm(arr)),
}
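    # Example (illustrative): _compute_stats(np.array([3.0, 4.0])) returns
    #   {"mean": 3.5, "std": 0.5, "min": 3.0, "max": 4.0, "l2_norm": 5.0}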
def probe_generation(
self,
prompt: str,
max_tokens: int = 100,
temperature: float = 0.0,
top_k: int = 10
) -> ProbeResults:
"""Run generation with probing enabled."""
import mlx.core as mx
self.results = ProbeResults()
self.results.input_text = prompt
# Encode
tokens = self.tokenizer.encode(prompt)
self.results.input_tokens = tokens
self._log(f"Input: {len(tokens)} tokens")
# Get model internals
afm_model = self.model.model
# Create cache
cache = self.model.make_cache()
x = mx.array([tokens])
        # Embed the prompt once; capture the embedding output if requested
        h = afm_model.embedding(x)
        if self.config.capture_embeddings:
            mx.eval(h)
            self._embeddings = np.array(h)
            self.results.embedding_stats = self._compute_stats(self._embeddings)
            self._log(f"Embeddings shape: {self._embeddings.shape}")
        # Forward pass through Block 1 layers. mask stays None here, deferring
        # to each layer's own default (causal) masking.
        mask = None
for i, layer in enumerate(afm_model.layers):
if self._should_capture_layer(i):
# Capture layer output
h_out = layer(h, mask=mask, cache=cache[i] if cache else None)
mx.eval(h_out)
if self.config.capture_layer_outputs:
arr = np.array(h_out)
self._layer_outputs[i] = arr
self.results.layer_outputs_shapes[str(i)] = list(arr.shape)
                # Capture FFN gate activations. Note: the gate is probed on the
                # layer *input* h (not the post-attention hidden state), so these
                # stats approximate, rather than exactly match, the FFN's input.
                if self.config.capture_ffn_activations:
                    gate = layer.mlp.gate_proj(h)
                    mx.eval(gate)
                    gate_arr = np.array(gate)
                    self._ffn_gates[i] = gate_arr
                    # Sparsity: fraction of post-SiLU gate values near zero.
                    # SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)) -- this is the
                    # exact formula, not an approximation (np.exp may warn on
                    # large-magnitude negative inputs, but the values are correct).
                    silu_gate = gate_arr / (1 + np.exp(-gate_arr))
                    sparsity = float((np.abs(silu_gate) < 0.1).mean())
                    mean_act = float(np.abs(silu_gate).mean())
self.results.ffn_stats[str(i)] = {
"sparsity": sparsity,
"mean_activation": mean_act
}
                h = h_out
                if self.config.capture_layer_outputs:
                    # arr is only defined when layer outputs were captured above
                    self._log(f"Layer {i}: shape {arr.shape}, norm {np.linalg.norm(arr):.2f}")
else:
h = layer(h, mask=mask, cache=cache[i] if cache else None)
        # Get K,V from the last Block 1 layer's cache for the KV-reuse block.
        # Index explicitly rather than via cache[-1], in case make_cache()
        # allocates entries beyond the Block 1 layers.
        last_block1_cache = cache[len(afm_model.layers) - 1]
        if hasattr(last_block1_cache, 'state') and last_block1_cache.state is not None:
            keys, values = last_block1_cache.state
        else:
            keys, values = None, None
# Forward through Block 2 (KV-reuse) layers
for i, layer in enumerate(afm_model.kv_reuse_layers):
block2_idx = len(afm_model.layers) + i
if self._should_capture_layer(block2_idx):
if keys is not None:
h_out = layer(h, keys, values, mask=mask)
else:
h_out = layer(h, mask=mask)
mx.eval(h_out)
if self.config.capture_layer_outputs:
arr = np.array(h_out)
self._kv_reuse_outputs[i] = arr
self.results.kv_reuse_layer_outputs_shapes[str(i)] = list(arr.shape)
# FFN stats for Block 2
if self.config.capture_ffn_activations:
gate = layer.mlp.gate_proj(h)
mx.eval(gate)
gate_arr = np.array(gate)
                    silu_gate = gate_arr / (1 + np.exp(-gate_arr))  # exact SiLU, as in the Block 1 loop
sparsity = float((np.abs(silu_gate) < 0.1).mean())
mean_act = float(np.abs(silu_gate).mean())
self.results.ffn_stats[str(block2_idx)] = {
"sparsity": sparsity,
"mean_activation": mean_act
}
                h = h_out
                if self.config.capture_layer_outputs:
                    # arr is only defined when layer outputs were captured above
                    self._log(f"KV-Reuse {i}: shape {arr.shape}")
else:
if keys is not None:
h = layer(h, keys, values, mask=mask)
else:
h = layer(h, mask=mask)
# Final norm and logits
h = afm_model.output_norm(h)
logits = afm_model.embedding.as_linear(h)
mx.eval(logits)
if self.config.capture_logits:
self._logits = np.array(logits[0, -1, :])
self.results.logits_stats = self._compute_stats(self._logits)
self._log(f"Logits shape: {self._logits.shape}")
# Get token probabilities
if self.config.capture_token_probs:
probs = mx.softmax(logits[0, -1, :], axis=-1)
probs_np = np.array(probs)
top_indices = np.argsort(probs_np)[-top_k:][::-1]
for idx in top_indices:
prob = float(probs_np[idx])
token_text = self.tokenizer.decode([int(idx)])
self.results.top_k_tokens.append((int(idx), prob, token_text))
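            # Each stored entry is (token_id, probability, decoded_text),
            # e.g. (42, 0.31, " the") -- illustrative values only.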
# Generate tokens (simplified - full generation would continue the loop)
generated = []
for _ in range(max_tokens):
if temperature == 0:
next_token = mx.argmax(logits[0, -1, :], axis=-1)
            else:
                # mx.random.categorical samples from softmax of *logits*; pass
                # temperature-scaled logits, not pre-softmaxed probabilities
                next_token = mx.random.categorical(logits[0, -1, :] / temperature)
next_token_int = int(next_token)
generated.append(next_token_int)
if next_token_int == 150001: # <turn_end>
break
x = next_token.reshape(1, 1)
logits = self.model(x, cache=cache)
mx.eval(logits)
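        # Note: the layer/FFN probes above cover only the prefill pass; decode
        # steps run through self.model.__call__ and are not re-probed per token.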
self.results.generated_tokens = generated
self.results.generated_text = self.tokenizer.decode(generated)
return self.results
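# Programmatic usage sketch (hypothetical, mirroring what main() does below):
#   model, tokenizer = load_model("mlx_afm7")
#   prober = CLIProber(model, tokenizer, ProbeConfig(), verbose=True)
#   results = prober.probe_generation(format_prompt("What is AI?"), max_tokens=32)
#   json.dump(results.to_dict(), open("probe_results.json", "w"), indent=2)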
def format_prompt(user_message: str, system_prompt: str = "You are a helpful assistant.") -> str:
"""Format prompt for AFM7."""
return f"system\n{system_prompt}<turn_end> user\n {user_message}<turn_end> assistant\n"
def print_results_summary(results: ProbeResults):
"""Print a summary of probe results."""
print("\n" + "=" * 60)
print("PROBE RESULTS SUMMARY")
print("=" * 60)
print(f"\n📥 INPUT ({len(results.input_tokens)} tokens)")
print(f" {results.input_text[:100]}{'...' if len(results.input_text) > 100 else ''}")
print(f"\n📤 OUTPUT ({len(results.generated_tokens)} tokens)")
print(f" {results.generated_text[:200]}{'...' if len(results.generated_text) > 200 else ''}")
if results.embedding_stats:
print(f"\n📊 EMBEDDING STATS")
for k, v in results.embedding_stats.items():
print(f" {k}: {v:.4f}")
if results.layer_outputs_shapes:
print(f"\n🔢 BLOCK 1 LAYERS CAPTURED: {len(results.layer_outputs_shapes)}")
for layer, shape in list(results.layer_outputs_shapes.items())[:5]:
print(f" Layer {layer}: {shape}")
if len(results.layer_outputs_shapes) > 5:
print(f" ... and {len(results.layer_outputs_shapes) - 5} more")
if results.kv_reuse_layer_outputs_shapes:
print(f"\n🔄 BLOCK 2 (KV-REUSE) LAYERS CAPTURED: {len(results.kv_reuse_layer_outputs_shapes)}")
if results.ffn_stats:
print(f"\n⚡ FFN GATE STATS (sample)")
for layer, stats in list(results.ffn_stats.items())[:3]:
print(f" Layer {layer}: sparsity={stats['sparsity']:.2%}, mean_act={stats['mean_activation']:.3f}")
if results.top_k_tokens:
print(f"\n🎯 TOP-K NEXT TOKEN PREDICTIONS")
for i, (tok_id, prob, tok_text) in enumerate(results.top_k_tokens[:5]):
print(f" {i+1}. {repr(tok_text)} (id={tok_id}): {prob:.2%}")
if results.logits_stats:
print(f"\n📈 LOGITS STATS")
for k, v in results.logits_stats.items():
print(f" {k}: {v:.4f}")
print("\n" + "=" * 60)
def main():
parser = argparse.ArgumentParser(
description="AFM7 Probing CLI - Inspect model internals during inference",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python probe_cli.py -m mlx_afm7 -p "What is AI?"
python probe_cli.py -m mlx_afm7 --interactive --layers 0,17,34,55
python probe_cli.py -m mlx_afm7 -p "Hello" -o results.json -v
"""
)
parser.add_argument("-m", "--model-path", default="mlx_afm7",
help="Path to MLX model directory")
parser.add_argument("-p", "--prompt", type=str,
help="User prompt for single inference")
parser.add_argument("-s", "--system-prompt", default="You are a helpful assistant.",
help="System prompt")
parser.add_argument("--interactive", action="store_true",
help="Run in interactive mode")
parser.add_argument("--max-tokens", type=int, default=100,
help="Maximum tokens to generate")
parser.add_argument("--temperature", type=float, default=0.0,
help="Sampling temperature (0 = greedy)")
parser.add_argument("--layers", type=str, default=None,
help="Comma-separated layer indices to capture (default: all)")
parser.add_argument("-o", "--output", type=str,
help="Output JSON file for probe results")
parser.add_argument("-v", "--verbose", action="store_true",
help="Verbose output during probing")
parser.add_argument("--no-embeddings", action="store_true",
help="Skip embedding capture")
parser.add_argument("--no-ffn", action="store_true",
help="Skip FFN activation capture")
args = parser.parse_args()
# Validate args
if not args.prompt and not args.interactive:
parser.error("Must specify --prompt or --interactive")
# Import and apply patches
from generate import load_model, apply_mlx_lm_patches
apply_mlx_lm_patches()
# Load model
print(f"Loading model from {args.model_path}...")
try:
model, tokenizer = load_model(args.model_path)
print("✓ Model loaded successfully")
except Exception as e:
print(f"✗ Failed to load model: {e}")
sys.exit(1)
# Configure probing
config = ProbeConfig(
capture_embeddings=not args.no_embeddings,
capture_ffn_activations=not args.no_ffn,
layer_indices=[int(x) for x in args.layers.split(",")] if args.layers else None
)
prober = CLIProber(model, tokenizer, config, verbose=args.verbose)
if args.interactive:
print("\n🔬 AFM7 Probing CLI - Interactive Mode")
print("Type 'quit' or 'exit' to end session")
print("Type 'export <filename>' to save last results to JSON")
print("-" * 40)
last_results = None
while True:
try:
user_input = input("\n📝 You: ").strip()
except (EOFError, KeyboardInterrupt):
print("\nExiting...")
break
if not user_input:
continue
if user_input.lower() in ('quit', 'exit'):
break
if user_input.lower().startswith('export '):
if last_results:
filename = user_input[7:].strip()
with open(filename, 'w') as f:
json.dump(last_results.to_dict(), f, indent=2)
print(f"✓ Exported to {filename}")
else:
print("No results to export yet")
continue
formatted = format_prompt(user_input, args.system_prompt)
print("\n⏳ Running inference with probing...")
results = prober.probe_generation(
formatted,
max_tokens=args.max_tokens,
temperature=args.temperature
)
last_results = results
print(f"\n🤖 Assistant: {results.generated_text}")
print_results_summary(results)
else:
# Single prompt mode
formatted = format_prompt(args.prompt, args.system_prompt)
print(f"\n⏳ Running inference with probing...")
results = prober.probe_generation(
formatted,
max_tokens=args.max_tokens,
temperature=args.temperature
)
print(f"\n🤖 Response: {results.generated_text}")
print_results_summary(results)
# Export if requested
if args.output:
with open(args.output, 'w') as f:
json.dump(results.to_dict(), f, indent=2)
print(f"\n✓ Results exported to {args.output}")
print(f" View in browser: open probe_viewer.html and load the JSON file")
if __name__ == "__main__":
main()