Skip to content

Commit 6f56518

Browse files
authored
Merge pull request #1277 from TransformerLensOrg/dev
TransformerLens 3.1.0
2 parents 58b007f + 0a5218c commit 6f56518

48 files changed

Lines changed: 5138 additions & 463 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
5050
logits, activations = bridge.run_with_cache("Hello World")
5151
```
5252

53-
`TransformerBridge` is the recommended 3.0 path and supports 50+ architectures. The legacy `HookedTransformer.from_pretrained` API is still available through a compatibility layer but is deprecated - see the [Migrating to TransformerLens 3](https://TransformerLensOrg.github.io/TransformerLens/content/migrating_to_v3.html) guide for conversion recipes.
53+
`TransformerBridge` is the recommended 3.0 path and supports 50+ architectures. By default it preserves raw HuggingFace weights – logits and activations match HF, *not* legacy `HookedTransformer` (which folds LayerNorm and centers weights by default). Call `bridge.enable_compatibility_mode()` after booting for HookedTransformer-equivalent numerics. The legacy `HookedTransformer.from_pretrained` API is still available but deprecated see the [Migrating to TransformerLens 3](https://TransformerLensOrg.github.io/TransformerLens/content/migrating_to_v3.html) guide.
5454

5555
## Key Tutorials
5656

demos/LLaMA2_GPU_Quantized.ipynb

Lines changed: 78 additions & 101 deletions
Large diffs are not rendered by default.

demos/T5.ipynb

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,14 @@
8888
"generated token: \",\", token id: 6\n",
8989
"generated token: \"comment\", token id: 1670\n",
9090
"generated token: \"\", token id: 3\n",
91-
"generated token: \"\u00eates\", token id: 6738\n",
91+
"generated token: \"êtes\", token id: 6738\n",
9292
"generated token: \"-\", token id: 18\n",
9393
"generated token: \"vous\", token id: 3249\n",
9494
"generated token: \"\", token id: 3\n",
9595
"generated token: \"?\", token id: 58\n",
9696
"generated token: \"</s>\", token id: 1\n",
9797
"translate English to French: Hello, how are you? \n",
98-
" Bonjour, comment \u00eates-vous?\n"
98+
" Bonjour, comment êtes-vous?\n"
9999
]
100100
}
101101
],
@@ -206,7 +206,7 @@
206206
},
207207
{
208208
"cell_type": "code",
209-
"execution_count": 8,
209+
"execution_count": null,
210210
"metadata": {
211211
"execution": {
212212
"iopub.execute_input": "2026-03-05T18:28:00.478310Z",
@@ -215,21 +215,8 @@
215215
"shell.execute_reply": "2026-03-05T18:28:00.629766Z"
216216
}
217217
},
218-
"outputs": [
219-
{
220-
"name": "stdout",
221-
"output_type": "stream",
222-
"text": [
223-
"Hallo, magst du Bananen?\n"
224-
]
225-
}
226-
],
227-
"source": [
228-
"prompt=\"translate English to German: Hello, do you like bananas?\"\n",
229-
"\n",
230-
"output = model.generate(prompt, do_sample=False, max_new_tokens=20)\n",
231-
"print(output)"
232-
]
218+
"outputs": [],
219+
"source": "prompt=\"translate English to German: Hello, do you like bananas?\"\n\noutput = model.generate(prompt, do_sample=False, max_new_tokens=20, verbose=False)\nprint(output)"
233220
},
234221
{
235222
"cell_type": "markdown",
@@ -928,7 +915,7 @@
928915
"outputs": [],
929916
"source": [
930917
"encoder_attn_pattern = cache[\"encoder_blocks.0.attn.hook_pattern\"]\n",
931-
"input_str_tokens = [w.lstrip(\"\u2581\") for w in tokenizer.convert_ids_to_tokens(input_ids[0])]"
918+
"input_str_tokens = [w.lstrip(\"\") for w in tokenizer.convert_ids_to_tokens(input_ids[0])]"
932919
]
933920
},
934921
{
@@ -993,14 +980,14 @@
993980
"data": {
994981
"text/plain": [
995982
"['<pad>',\n",
996-
" '\u2581Bonjour',\n",
983+
" '▁Bonjour',\n",
997984
" ',',\n",
998-
" '\u2581comment',\n",
999-
" '\u2581',\n",
1000-
" '\u00eates',\n",
985+
" '▁comment',\n",
986+
" '',\n",
987+
" 'êtes',\n",
1001988
" '-',\n",
1002989
" 'vous',\n",
1003-
" '\u2581',\n",
990+
" '',\n",
1004991
" '?',\n",
1005992
" '</s>']"
1006993
]
@@ -1143,4 +1130,4 @@
11431130
},
11441131
"nbformat": 4,
11451132
"nbformat_minor": 2
1146-
}
1133+
}
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""<MODEL_NAME> architecture adapter.
2+
3+
TODO: Replace <MODEL_NAME> with the actual model name throughout this file.
4+
"""
5+
6+
from typing import Any
7+
8+
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
9+
from transformer_lens.model_bridge.generalized_components import (
10+
BlockBridge,
11+
EmbeddingBridge,
12+
GatedMLPBridge,
13+
LinearBridge,
14+
PositionEmbeddingsAttentionBridge,
15+
RMSNormalizationBridge,
16+
RotaryEmbeddingBridge,
17+
UnembeddingBridge,
18+
)
19+
20+
21+
class ModelNameArchitectureAdapter(ArchitectureAdapter):
22+
"""Architecture adapter for <MODEL_NAME> models.
23+
24+
TODO: Document which parameters are optional (missing biases, etc.)
25+
26+
Optional Parameters (may not exist in state_dict):
27+
-------------------------------------------------
28+
TODO: List parameters that may not exist. Example for models without biases:
29+
30+
- blocks.{i}.attn.b_Q - No bias on query projection
31+
- blocks.{i}.attn.b_K - No bias on key projection
32+
- blocks.{i}.attn.b_V - No bias on value projection
33+
- blocks.{i}.attn.b_O - No bias on output projection
34+
- blocks.{i}.mlp.b_in - No bias on MLP input
35+
- blocks.{i}.mlp.b_gate - No bias on MLP gate projection
36+
- blocks.{i}.mlp.b_out - No bias on MLP output
37+
- blocks.{i}.ln1.b - RMSNorm has no bias
38+
- blocks.{i}.ln2.b - RMSNorm has no bias
39+
- ln_final.b - RMSNorm has no bias
40+
"""
41+
42+
def __init__(self, cfg: Any) -> None:
43+
"""Initialize the <MODEL_NAME> architecture adapter."""
44+
super().__init__(cfg)
45+
46+
# =====================================================================
47+
# 1. CONFIG ATTRIBUTES
48+
# Set these based on the HuggingFace model's architecture.
49+
# =====================================================================
50+
51+
# TODO: Set normalization type
52+
# "RMS" for RMSNorm (Llama, Qwen, Gemma, etc.)
53+
# "LN" for LayerNorm (GPT-2, GPT-J, etc.)
54+
self.cfg.normalization_type = "RMS"
55+
56+
# TODO: Set positional embedding type
57+
# "rotary" for RoPE (Llama, Qwen, Mistral, etc.)
58+
# "standard" for learned positional embeddings (GPT-2)
59+
self.cfg.positional_embedding_type = "rotary"
60+
61+
# TODO: Set these flags
62+
self.cfg.final_rms = True # True if final layer norm is RMSNorm
63+
self.cfg.gated_mlp = True # True if MLP has gate projection (SwiGLU)
64+
self.cfg.attn_only = False # True only for attention-only models (rare)
65+
self.cfg.uses_rms_norm = True # Should match normalization_type
66+
67+
# TODO: Set the epsilon attribute name used by this model's normalization
68+
# Check the HF model's norm layer to find the correct attribute name
69+
self.cfg.eps_attr = "variance_epsilon" # or "layer_norm_eps", "rms_norm_eps", etc.
70+
71+
# TODO: Handle GQA if applicable
72+
# If the model uses Grouped Query Attention (n_key_value_heads < n_heads):
73+
if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
74+
self.cfg.n_key_value_heads = cfg.n_key_value_heads
75+
76+
# =====================================================================
77+
# 2. WEIGHT PROCESSING CONVERSIONS
78+
# Defines how to reshape weights from HF format to TL format.
79+
# For most models with separate Q/K/V/O, use the built-in helper.
80+
# =====================================================================
81+
82+
self.weight_processing_conversions = {
83+
**self._qkvo_weight_conversions(),
84+
# TODO: Add any model-specific weight conversions here
85+
}
86+
87+
# =====================================================================
88+
# 3. COMPONENT MAPPING
89+
# Maps TransformerLens canonical names to HuggingFace module paths.
90+
# The `name=` parameter is the HF path relative to the model root
91+
# (for top-level) or relative to the block (for block submodules).
92+
# =====================================================================
93+
94+
# TODO: Replace all HF paths (name="...") with actual paths from the model.
95+
# Inspect the HF model's named_modules() or config to find the correct paths.
96+
self.component_mapping = {
97+
# Token embedding
98+
"embed": EmbeddingBridge(name="model.embed_tokens"),
99+
# Rotary position embeddings (remove if model uses standard pos embeddings)
100+
"rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
101+
# Transformer blocks
102+
"blocks": BlockBridge(
103+
name="model.layers", # TODO: HF path to the layer list
104+
submodules={
105+
# Pre-attention layer norm
106+
"ln1": RMSNormalizationBridge(
107+
name="input_layernorm", # TODO: HF name within block
108+
config=self.cfg,
109+
),
110+
# Post-attention layer norm
111+
"ln2": RMSNormalizationBridge(
112+
name="post_attention_layernorm", # TODO: HF name within block
113+
config=self.cfg,
114+
),
115+
# Self-attention
116+
"attn": PositionEmbeddingsAttentionBridge(
117+
name="self_attn", # TODO: HF name within block
118+
config=self.cfg,
119+
submodules={
120+
"q": LinearBridge(name="q_proj"), # TODO: HF projection names
121+
"k": LinearBridge(name="k_proj"),
122+
"v": LinearBridge(name="v_proj"),
123+
"o": LinearBridge(name="o_proj"),
124+
},
125+
requires_attention_mask=True,
126+
requires_position_embeddings=True,
127+
),
128+
# MLP (gated)
129+
"mlp": GatedMLPBridge(
130+
name="mlp", # TODO: HF name within block
131+
config=self.cfg,
132+
submodules={
133+
"gate": LinearBridge(name="gate_proj"), # TODO: HF projection names
134+
"in": LinearBridge(name="up_proj"),
135+
"out": LinearBridge(name="down_proj"),
136+
},
137+
),
138+
},
139+
),
140+
# Final layer norm
141+
"ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
142+
# Output head (unembedding)
143+
"unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
144+
}
145+
146+
def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
147+
"""Set up model-specific references for component testing.
148+
149+
TODO: Required for RoPE models. Remove if model uses standard positional embeddings.
150+
"""
151+
# Get rotary embedding instance from the HF model
152+
rotary_emb = hf_model.model.rotary_emb # TODO: Adjust path if different
153+
154+
# Set rotary_emb on actual bridge instances
155+
if bridge_model is not None and hasattr(bridge_model, "blocks"):
156+
for block in bridge_model.blocks:
157+
if hasattr(block, "attn"):
158+
block.attn.set_rotary_emb(rotary_emb)
159+
160+
# Set on template for get_generalized_component() calls
161+
attn_bridge = self.get_generalized_component("blocks.0.attn")
162+
attn_bridge.set_rotary_emb(rotary_emb)

0 commit comments

Comments
 (0)