-
Notifications
You must be signed in to change notification settings - Fork 586
Expand file tree
/
Copy pathtest_hook_completeness.py
More file actions
207 lines (163 loc) · 8.2 KB
/
test_hook_completeness.py
File metadata and controls
207 lines (163 loc) · 8.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""Test that all HookedTransformer hooks are available and functional in TransformerBridge.
The tests in this module check hook parity between the two implementations
across different model architectures. They verify that:
1. All hooks from HookedTransformer exist in TransformerBridge
2. All hooks actually fire during forward pass
3. Hook activations match between the two implementations
"""
import os
import pytest
from transformer_lens import HookedTransformer
from transformer_lens.benchmarks import benchmark_forward_hooks, benchmark_hook_registry
from transformer_lens.model_bridge import TransformerBridge
# Module-level pytest marker: tags every test collected from this file as `slow`.
pytestmark = pytest.mark.slow
# Model names used to parametrize the hook completeness tests.
# Constraint: the tests compare the bridge against the legacy HookedTransformer,
# so every entry must appear in HookedTransformer's OFFICIAL_MODEL_NAMES. Tiny
# C1-affected families (Llama/Qwen/Gemma under ~150M) aren't registered with HT;
# for those, tests/unit/model_bridge/test_component_hooks_fire.py (Tier 2)
# provides direct per-adapter hook-firing coverage without an HT counterpart.
MODELS_TO_TEST = [
    "gpt2",  # JointQKVAttentionBridge (standard decoder-only)
    "EleutherAI/pythia-14m",  # ParallelBlockBridge (C15 regression guard)
]

# Gemma2 (unique normalization setup) is too large for CI, so it is only added
# when the CI environment variable is unset or empty.
if not os.environ.get("CI"):
    MODELS_TO_TEST.extend(["google/gemma-2-2b-it"])
class TestHookCompleteness:
    """Test suite for verifying complete hook coverage across architectures.

    Every test is parametrized over ``MODELS_TO_TEST`` and loads its models on
    CPU. The bridge is always put in compatibility mode with
    ``no_processing=True``; where a reference is needed it is a
    ``HookedTransformer`` loaded via ``from_pretrained_no_processing``.
    """

    @staticmethod
    def _load_reference(model_name):
        """Load the legacy ``HookedTransformer`` for ``model_name`` on CPU, unprocessed."""
        return HookedTransformer.from_pretrained_no_processing(model_name, device="cpu")

    @staticmethod
    def _load_bridge(model_name):
        """Boot a CPU ``TransformerBridge`` and enable no-processing compatibility mode."""
        bridge = TransformerBridge.boot_transformers(model_name, device="cpu")
        bridge.enable_compatibility_mode(no_processing=True)
        return bridge

    @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
    def test_all_hooks_exist(self, model_name):
        """Test that TransformerBridge has all hooks that HookedTransformer has.

        This test verifies that the hook registry is complete - every hook in
        HookedTransformer must exist in TransformerBridge.
        """
        # Load both models (reference first, then the bridge)
        ht = self._load_reference(model_name)
        bridge = self._load_bridge(model_name)

        # Run benchmark
        result = benchmark_hook_registry(bridge, reference_model=ht)

        # Must pass - no missing hooks allowed
        assert result.passed, (
            f"Hook registry check failed for {model_name}:\n"
            f" {result.message}\n"
            f" Details: {result.details}"
        )

    @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
    def test_all_hooks_fire(self, model_name):
        """Test that all hooks actually fire during a forward pass.

        This test verifies that hooks don't just exist in the registry, but
        actually execute and capture activations during forward pass. This is
        critical because hooks that don't fire indicate architectural bugs
        (e.g., missing ln2 calls in patched forward methods).
        """
        # Load both models (reference first, then the bridge)
        ht = self._load_reference(model_name)
        bridge = self._load_bridge(model_name)

        # Use a short prompt to speed up testing
        test_text = "The quick brown fox"

        # Run benchmark - this will fail if hooks don't fire
        # tolerance=1e-2: some architectures (e.g., pythia) accumulate small floating-point
        # differences across layers that exceed 1e-3 but are not meaningful divergences.
        result = benchmark_forward_hooks(bridge, test_text, reference_model=ht, tolerance=1e-2)

        # Must pass - all hooks must fire
        assert result.passed, (
            f"Forward hooks check failed for {model_name}:\n"
            f" {result.message}\n"
            f" Details: {result.details}\n"
            f"\n"
            f"This likely means:\n"
            f" 1. Some hooks are missing from TransformerBridge, OR\n"
            f" 2. Some hooks exist but don't fire during forward pass, OR\n"
            f" 3. Hook activations don't match between implementations\n"
            f"\n"
            f"Check the 'missing_hooks' or 'didnt_fire_hooks' in details above."
        )

    @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
    def test_normalization_hooks_fire(self, model_name):
        """Test that all layer normalization hooks fire.

        This is a targeted test for normalization hooks because they're
        architecture-specific and prone to being missed (e.g., Gemma2's ln2).

        Only hook names present in ``bridge.hook_dict`` are registered. The test
        fails explicitly when none of the probed names are present, because an
        empty selection would otherwise make the final check pass vacuously.
        """
        bridge = self._load_bridge(model_name)
        test_text = "Hello world"

        # Names of the normalization hooks that were actually invoked
        norm_hooks_fired = set()

        def capture_hook(name):
            # Bind `name` per hook so each callback records its own registry key.
            def hook_fn(tensor, hook):
                norm_hooks_fired.add(name)
                return tensor

            return hook_fn

        # Candidate names: ln1/ln2 hook_normalized for each layer, then ln_final
        candidate_names = [
            f"blocks.{layer_idx}.{norm_name}.hook_normalized"
            for layer_idx in range(bridge.cfg.n_layers)
            for norm_name in ("ln1", "ln2")
        ]
        candidate_names.append("ln_final.hook_normalized")

        # Register only the candidates that this bridge exposes
        hook_dict = bridge.hook_dict
        hooks_to_test = [
            (hook_name, capture_hook(hook_name))
            for hook_name in candidate_names
            if hook_name in hook_dict
        ]

        # Guard against a vacuous pass: with nothing registered, nothing can be missing.
        assert hooks_to_test, (
            f"No normalization hooks found in hook_dict for {model_name}:\n"
            f" Looked for blocks.<i>.ln1/ln2.hook_normalized and ln_final.hook_normalized\n"
            f" across {bridge.cfg.n_layers} layers, but none are registered."
        )

        # Run forward pass with hooks
        bridge.run_with_hooks(test_text, fwd_hooks=hooks_to_test)

        # Verify all hooks fired
        expected_hooks = {name for name, _ in hooks_to_test}
        missing_hooks = expected_hooks - norm_hooks_fired
        assert not missing_hooks, (
            f"Normalization hooks didn't fire for {model_name}:\n"
            f" Missing: {sorted(missing_hooks)}\n"
            f" Total hooks tested: {len(expected_hooks)}\n"
            f" Hooks that fired: {len(norm_hooks_fired)}\n"
            f"\n"
            f"This indicates a bug in the block's patched forward method.\n"
            f"The normalization layers exist but aren't being called during forward pass."
        )
class TestArchitectureSpecificHooks:
    """Hook checks that apply only to particular model architectures."""

    @pytest.mark.skipif(bool(os.getenv("CI")), reason="Gemma2 is too large for CI")
    def test_gemma2_ln2_hook(self):
        """Regression test: every Gemma2 block's ln2 hook must fire.

        Gemma2 has unique architecture with 4 normalization layers per block.
        A hook is attached to ``blocks.{i}.ln2.hook_normalized`` (ln2 being the
        pre_feedforward_layernorm) for each layer, and the number of recorded
        firings must equal ``bridge.cfg.n_layers``.
        """
        bridge = TransformerBridge.boot_transformers("google/gemma-2-2b-it", device="cpu")
        bridge.enable_compatibility_mode(no_processing=True)

        fired_names = []

        def record_ln2(tensor, hook):
            # Log the firing hook's name and pass the activation through unchanged.
            fired_names.append(hook.name)
            return tensor

        # One ln2 hook per layer, all sharing the same recorder
        fwd_hooks = []
        for layer in range(bridge.cfg.n_layers):
            fwd_hooks.append((f"blocks.{layer}.ln2.hook_normalized", record_ln2))

        bridge.run_with_hooks("Test", fwd_hooks=fwd_hooks)

        # All ln2 hooks should fire
        expected_count = bridge.cfg.n_layers
        fired_count = len(fired_names)
        assert fired_count == expected_count, (
            f"Gemma2 ln2 hooks didn't fire!\n"
            f" Expected: {expected_count}\n"
            f" Got: {fired_count}\n"
            f" Fired: {fired_names}\n"
            f"\n"
            f"This is a regression - Gemma2's pre_feedforward_layernorm (ln2)\n"
            f"must be called in the block's patched forward method."
        )
if __name__ == "__main__":
    # Direct execution: run the three completeness checks against GPT-2 only.
    print("Testing hook completeness on gpt2...")
    suite = TestHookCompleteness()

    # (banner printed before, check to run, confirmation printed after)
    steps = (
        (
            "\n1. Testing hook registry...",
            suite.test_all_hooks_exist,
            " ✓ All hooks exist",
        ),
        (
            "\n2. Testing hooks fire during forward pass...",
            suite.test_all_hooks_fire,
            " ✓ All hooks fire and match",
        ),
        (
            "\n3. Testing normalization hooks...",
            suite.test_normalization_hooks_fire,
            " ✓ All normalization hooks fire",
        ),
    )
    for banner, check, confirmation in steps:
        print(banner)
        check("gpt2")
        print(confirmation)

    print("\n✅ All hook completeness tests passed for gpt2!")