Skip to content

Commit fd288dc

Browse files
authored
Adding Architecture Adapter Creation Guide to Docs (#1274)
* Adding architecture Adapter creation guide, add split QKV example to quantized LLaMA demo * ignore docs/build from black linting
1 parent d95bd96 commit fd288dc

7 files changed

Lines changed: 1008 additions & 103 deletions

File tree

demos/LLaMA2_GPU_Quantized.ipynb

Lines changed: 78 additions & 101 deletions
Large diffs are not rendered by default.
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""<MODEL_NAME> architecture adapter.
2+
3+
TODO: Replace <MODEL_NAME> with the actual model name throughout this file.
4+
"""
5+
6+
from typing import Any
7+
8+
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
9+
from transformer_lens.model_bridge.generalized_components import (
10+
BlockBridge,
11+
EmbeddingBridge,
12+
GatedMLPBridge,
13+
LinearBridge,
14+
PositionEmbeddingsAttentionBridge,
15+
RMSNormalizationBridge,
16+
RotaryEmbeddingBridge,
17+
UnembeddingBridge,
18+
)
19+
20+
21+
class ModelNameArchitectureAdapter(ArchitectureAdapter):
22+
"""Architecture adapter for <MODEL_NAME> models.
23+
24+
TODO: Document which parameters are optional (missing biases, etc.)
25+
26+
Optional Parameters (may not exist in state_dict):
27+
-------------------------------------------------
28+
TODO: List parameters that may not exist. Example for models without biases:
29+
30+
- blocks.{i}.attn.b_Q - No bias on query projection
31+
- blocks.{i}.attn.b_K - No bias on key projection
32+
- blocks.{i}.attn.b_V - No bias on value projection
33+
- blocks.{i}.attn.b_O - No bias on output projection
34+
- blocks.{i}.mlp.b_in - No bias on MLP input
35+
- blocks.{i}.mlp.b_gate - No bias on MLP gate projection
36+
- blocks.{i}.mlp.b_out - No bias on MLP output
37+
- blocks.{i}.ln1.b - RMSNorm has no bias
38+
- blocks.{i}.ln2.b - RMSNorm has no bias
39+
- ln_final.b - RMSNorm has no bias
40+
"""
41+
42+
def __init__(self, cfg: Any) -> None:
43+
"""Initialize the <MODEL_NAME> architecture adapter."""
44+
super().__init__(cfg)
45+
46+
# =====================================================================
47+
# 1. CONFIG ATTRIBUTES
48+
# Set these based on the HuggingFace model's architecture.
49+
# =====================================================================
50+
51+
# TODO: Set normalization type
52+
# "RMS" for RMSNorm (Llama, Qwen, Gemma, etc.)
53+
# "LN" for LayerNorm (GPT-2, GPT-J, etc.)
54+
self.cfg.normalization_type = "RMS"
55+
56+
# TODO: Set positional embedding type
57+
# "rotary" for RoPE (Llama, Qwen, Mistral, etc.)
58+
# "standard" for learned positional embeddings (GPT-2)
59+
self.cfg.positional_embedding_type = "rotary"
60+
61+
# TODO: Set these flags
62+
self.cfg.final_rms = True # True if final layer norm is RMSNorm
63+
self.cfg.gated_mlp = True # True if MLP has gate projection (SwiGLU)
64+
self.cfg.attn_only = False # True only for attention-only models (rare)
65+
self.cfg.uses_rms_norm = True # Should match normalization_type
66+
67+
# TODO: Set the epsilon attribute name used by this model's normalization
68+
# Check the HF model's norm layer to find the correct attribute name
69+
self.cfg.eps_attr = "variance_epsilon" # or "layer_norm_eps", "rms_norm_eps", etc.
70+
71+
# TODO: Handle GQA if applicable
72+
# If the model uses Grouped Query Attention (n_key_value_heads < n_heads):
73+
if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
74+
self.cfg.n_key_value_heads = cfg.n_key_value_heads
75+
76+
# =====================================================================
77+
# 2. WEIGHT PROCESSING CONVERSIONS
78+
# Defines how to reshape weights from HF format to TL format.
79+
# For most models with separate Q/K/V/O, use the built-in helper.
80+
# =====================================================================
81+
82+
self.weight_processing_conversions = {
83+
**self._qkvo_weight_conversions(),
84+
# TODO: Add any model-specific weight conversions here
85+
}
86+
87+
# =====================================================================
88+
# 3. COMPONENT MAPPING
89+
# Maps TransformerLens canonical names to HuggingFace module paths.
90+
# The `name=` parameter is the HF path relative to the model root
91+
# (for top-level) or relative to the block (for block submodules).
92+
# =====================================================================
93+
94+
# TODO: Replace all HF paths (name="...") with actual paths from the model.
95+
# Inspect the HF model's named_modules() or config to find the correct paths.
96+
self.component_mapping = {
97+
# Token embedding
98+
"embed": EmbeddingBridge(name="model.embed_tokens"),
99+
# Rotary position embeddings (remove if model uses standard pos embeddings)
100+
"rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
101+
# Transformer blocks
102+
"blocks": BlockBridge(
103+
name="model.layers", # TODO: HF path to the layer list
104+
submodules={
105+
# Pre-attention layer norm
106+
"ln1": RMSNormalizationBridge(
107+
name="input_layernorm", # TODO: HF name within block
108+
config=self.cfg,
109+
),
110+
# Post-attention layer norm
111+
"ln2": RMSNormalizationBridge(
112+
name="post_attention_layernorm", # TODO: HF name within block
113+
config=self.cfg,
114+
),
115+
# Self-attention
116+
"attn": PositionEmbeddingsAttentionBridge(
117+
name="self_attn", # TODO: HF name within block
118+
config=self.cfg,
119+
submodules={
120+
"q": LinearBridge(name="q_proj"), # TODO: HF projection names
121+
"k": LinearBridge(name="k_proj"),
122+
"v": LinearBridge(name="v_proj"),
123+
"o": LinearBridge(name="o_proj"),
124+
},
125+
requires_attention_mask=True,
126+
requires_position_embeddings=True,
127+
),
128+
# MLP (gated)
129+
"mlp": GatedMLPBridge(
130+
name="mlp", # TODO: HF name within block
131+
config=self.cfg,
132+
submodules={
133+
"gate": LinearBridge(name="gate_proj"), # TODO: HF projection names
134+
"in": LinearBridge(name="up_proj"),
135+
"out": LinearBridge(name="down_proj"),
136+
},
137+
),
138+
},
139+
),
140+
# Final layer norm
141+
"ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
142+
# Output head (unembedding)
143+
"unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
144+
}
145+
146+
def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
147+
"""Set up model-specific references for component testing.
148+
149+
TODO: Required for RoPE models. Remove if model uses standard positional embeddings.
150+
"""
151+
# Get rotary embedding instance from the HF model
152+
rotary_emb = hf_model.model.rotary_emb # TODO: Adjust path if different
153+
154+
# Set rotary_emb on actual bridge instances
155+
if bridge_model is not None and hasattr(bridge_model, "blocks"):
156+
for block in bridge_model.blocks:
157+
if hasattr(block, "attn"):
158+
block.attn.set_rotary_emb(rotary_emb)
159+
160+
# Set on template for get_generalized_component() calls
161+
attn_bridge = self.get_generalized_component("blocks.0.attn")
162+
attn_bridge.set_rotary_emb(rotary_emb)

0 commit comments

Comments
 (0)