Skip to content

Commit 5499009

Browse files
slindermanclaude
andcommitted
Add head dimension admonition with Qwen3-8B example to transformer chapter
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 15c8ad4 commit 5499009

1 file changed

Lines changed: 30 additions & 30 deletions

File tree

chapters/04_05_transformers.md

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ Before a transformer can process text, the raw string must be converted into a s
1717

1818
Modern LLMs use **subword tokenization**, which decomposes text into vocabulary units that are between characters and words in granularity. The dominant algorithm is **Byte Pair Encoding (BPE)** [@sennrich2016bpe], which starts from a character vocabulary and iteratively merges the most frequent adjacent pair of symbols until the vocabulary reaches a target size $|\cV|$ (typically 32k–128k tokens). Common words become single tokens; rare words are split into subword pieces. Because BPE operates on bytes, it handles any Unicode text without unknown-token issues.
1919

20-
Each discrete token index $\ell_t \in \{1, \ldots, |\cV|\}$ is then mapped to a continuous vector via a learned **embedding matrix** $\mbE \in \reals^{|\cV| \times D}$, giving $\mbc_t = \mbE_{\ell_t} \in \reals^D$.
20+
Each discrete token index $z_t \in \{1, \ldots, |\cV|\}$ is then mapped to a continuous vector via a learned **embedding matrix** $\mbW_e \in \reals^{|\cV| \times D}$, giving $\mbx_t^{(0)} = (\mbW_e)_{z_t} \in \reals^D$.
2121

22-
Let $\mbX^{(0)} \in \reals^{T \times D}$ denote our data matrix with row $\mbx_t^{(0)} \in \reals^D$ representing the $t$-th token embedding. The embeddings may be fixed or learned as part of the model.
22+
Let $\mbX^{(0)} \in \reals^{T \times D}$ denote the matrix of token embeddings with $\mbx_t^{(0)}$ as its $t$-th row.
2323

24-
The output of the transformer will be another matrix of the same shape, $\mbX^{(M)} \in \reals^{T \times D}$. These output features can be used for downstream tasks like sentiment classification, machine translation, or autoregressive modeling.
24+
The output of the transformer will be another matrix of the same shape, $\mbX^{(L)} \in \reals^{T \times D}$. These output features can be used for downstream tasks like sentiment classification, machine translation, or autoregressive modeling.
2525

2626
The output results from a stack of transformer blocks,
2727
$$
28-
\mbX^{(m)} = \texttt{transformer-block}(\mbX^{(m-1)}).
28+
\mbX^{(\ell)} = \texttt{transformer-block}(\mbX^{(\ell-1)}).
2929
$$
3030
Each block consists of two stages: one that operates vertically, combining information across the sequence length; another that operates horizontally, combining information across feature dimensions.
3131

@@ -35,11 +35,11 @@ Each block consists of two stages: one that operates vertically, combining infor
3535

3636
The first stage combines information across sequence length using a mechanism called **attention**. Mathematically, attention is a weighted average,
3737
$$
38-
\mbY^{(m)} = \mbA^{(m)} \mbX^{(m-1)},
38+
\mbY^{(\ell)} = \mbA^{(\ell)} \mbX^{(\ell-1)},
3939
$$
40-
where $\mbA^{(m)} \in \reals_+^{T \times T}$ is a row-stochastic attention matrix: $\sum_{s} A_{ts}^{(m)} = 1$ for all $t$. Intuitively, $A_{t,s}^{(m)}$ indicates how much output location $t$ attends to input location $s$.
40+
where $\mbA^{(\ell)} \in \reals_+^{T \times T}$ is a row-stochastic attention matrix: $\sum_{s} A_{ts}^{(\ell)} = 1$ for all $t$. Intuitively, $A_{t,s}^{(\ell)}$ indicates how much output location $t$ attends to input location $s$.
4141

42-
When using transformers for autoregressive sequence modeling, we constrain the attention matrix to be **causal** by requiring $A_{t,s}^{(m)} = 0$ for all $s > t$ — i.e., the matrix is **lower triangular**.
42+
When using transformers for autoregressive sequence modeling, we constrain the attention matrix to be **causal** by requiring $A_{t,s}^{(\ell)} = 0$ for all $s > t$ — i.e., the matrix is **lower triangular**.
4343

4444
![Self Attention](../figures/14_transformers/self-attention-1.png)
4545

@@ -49,25 +49,25 @@ Where does the attention matrix come from? In a transformer, the attention weigh
4949
$$
5050
A_{t,s} = \frac{\exp \{ \mbx_t^\top \mbx_s\}}{\sum_{s'=1}^T \exp\{\mbx_t^\top \mbx_{s'}\}}.
5151
$$
52-
(We drop the superscript ${}^{(m)}$ for clarity in this section.)
52+
(We drop the superscript ${}^{(\ell)}$ for clarity in this section.)
5353

5454
In practice, different feature dimensions convey different kinds of information. Transformers use separate linear projections for the **queries** and **keys**:
5555
$$
56-
A_{t,s} = \frac{\exp \{ (\mbU_q \mbx_t)^\top (\mbU_k \mbx_s) / \sqrt{K}\}}{\sum_{s'=1}^T \exp\{(\mbU_q \mbx_t)^\top (\mbU_k \mbx_{s'}) / \sqrt{K}\}},
56+
A_{t,s} = \frac{\exp \{ (\mbW_q \mbx_t)^\top (\mbW_k \mbx_s) / \sqrt{K}\}}{\sum_{s'=1}^T \exp\{(\mbW_q \mbx_t)^\top (\mbW_k \mbx_{s'}) / \sqrt{K}\}},
5757
$$
58-
where $\mbU_q \mbx_t \in \reals^{K}$ are the **queries**, $\mbU_k \mbx_s \in \reals^{K}$ are the **keys**, and the $1/\sqrt{K}$ factor prevents the dot products from growing large in magnitude [@vaswani2017attention].
58+
where $\mbW_q \mbx_t \in \reals^{K}$ are the **queries**, $\mbW_k \mbx_s \in \reals^{K}$ are the **keys**, and the $1/\sqrt{K}$ factor prevents the dot products from growing large in magnitude [@vaswani2017attention].
5959

6060
![Self Attention with Queries and Keys](../figures/14_transformers/self-attention-2.png)
6161

6262
:::{admonition} Causal attention
6363
To enforce causality, we zero out the upper triangular entries of the attention matrix and renormalize:
6464
$$
65-
A_{t,s} = \frac{\exp \{ (\mbU_q \mbx_t)^\top (\mbU_k \mbx_s) / \sqrt{K}\}}{\sum_{s'=1}^{t} \exp\{(\mbU_q \mbx_t)^\top (\mbU_k \mbx_{s'}) / \sqrt{K}\}} \cdot \bbI[t \geq s].
65+
A_{t,s} = \frac{\exp \{ (\mbW_q \mbx_t)^\top (\mbW_k \mbx_s) / \sqrt{K}\}}{\sum_{s'=1}^{t} \exp\{(\mbW_q \mbx_t)^\top (\mbW_k \mbx_{s'}) / \sqrt{K}\}} \cdot \bbI[t \geq s].
6666
$$
6767
:::
6868

6969
:::{admonition} Comparison to RNNs
70-
Rather than propagating information about past tokens via a hidden state, a transformer with causal attention can directly attend to any past token.
70+
Rather than propagating information about past tokens via a hidden state, a transformer with causal attention can directly attend to any past token. However, we will make a precise connection in the chapter on [Deep SSMs and Linear Attention](04_06_linear_attention): linearizing the attention kernel recovers a recurrent model, showing that are more closely connected than you might think.
7171
:::
7272

7373
:::{admonition} Connection to Convolutional Neural Networks (CNNs)
@@ -76,29 +76,29 @@ If the attention weights were only a function of the distance between tokens, $A
7676

7777
### The KV Cache
7878

79-
During autoregressive generation, the model produces one token at a time. At step $t$, all keys $\mbU_k \mbx_{s}$ and values $\mbU_v \mbx_{s}$ for $s < t$ were already computed at previous steps. Rather than recomputing them, an efficient implementation stores them in a **KV cache** and retrieves them when generating each new token. This reduces the per-step cost from $O(t D^2)$ to $O(D^2)$, but the cache grows linearly with context length — a key memory bottleneck at inference time.
79+
During autoregressive generation, the model produces one token at a time. At step $t$, all keys $\mbW_k \mbx_{s}$ and values $\mbW_v \mbx_{s}$ for $s < t$ were already computed at previous steps. Rather than recomputing them, an efficient implementation stores them in a **KV cache** and retrieves them when generating each new token. This reduces the per-step cost from $O(t D^2)$ to $O(D^2)$, but the cache grows linearly with context length — a key memory bottleneck at inference time.
8080

8181
### Multi-Headed Self-Attention
8282

8383
Just as a CNN uses a bank of filters in parallel, a transformer block uses $H$ **attention heads** in parallel. Let
8484
$$
85-
\mbY^{(m,h)} = \mbA^{(m,h)} \mbX^{(m-1)} \mbU_v^{(m,h)\top} \in \reals^{T \times K},
85+
\mbY^{(\ell,h)} = \mbA^{(\ell,h)} \mbX^{(\ell-1)} \mbW_v^{(\ell,h)\top} \in \reals^{T \times K},
8686
$$
8787
where
8888
$$
89-
A_{t,s}^{(m,h)} =
90-
\frac{\exp \{ (\mbU_q^{(m,h)} \mbx_t^{(m-1)})^\top (\mbU_k^{(m,h)} \mbx_s^{(m-1)}) / \sqrt{K}\}}{\sum_{s'=1}^T \exp\{(\mbU_q^{(m,h)} \mbx_t^{(m-1)})^\top (\mbU_k^{(m,h)} \mbx_{s'}^{(m-1)}) / \sqrt{K}\}}
89+
A_{t,s}^{(\ell,h)} =
90+
\frac{\exp \{ (\mbW_q^{(\ell,h)} \mbx_t^{(\ell-1)})^\top (\mbW_k^{(\ell,h)} \mbx_s^{(\ell-1)}) / \sqrt{K}\}}{\sum_{s'=1}^T \exp\{(\mbW_q^{(\ell,h)} \mbx_t^{(\ell-1)})^\top (\mbW_k^{(\ell,h)} \mbx_{s'}^{(\ell-1)}) / \sqrt{K}\}}
9191
$$
9292
for $h = 1, \ldots, H$. The outputs are projected and summed:
9393
$$
94-
\mbY^{(m)} = \sum_{h=1}^H \mbY^{(m,h)} \mbU_o^{(m,h)\top} \triangleq \texttt{mhsa}(\mbX^{(m-1)}),
94+
\mbY^{(\ell)} = \sum_{h=1}^H \mbY^{(\ell,h)} \mbW_o^{(\ell,h)\top} \triangleq \texttt{mhsa}(\mbX^{(\ell-1)}),
9595
$$
96-
where $\mbU_o^{(m,h)} \in \reals^{D \times K}$ maps each head's output back to the token dimension.
96+
where $\mbW_o^{(\ell,h)} \in \reals^{D \times K}$ maps each head's output back to the token dimension.
9797

9898
![Multi-Headed Self Attention](../figures/14_transformers/mhsa-2.png)
9999

100100
:::{admonition} Head dimension in practice
101-
:class: note
101+
:class: note dropdown
102102
The standard convention is $K = D / H$, so the per-head query/key/value dimension is much smaller than the full token dimension. In Qwen3-8B [@qwen3], for example, $D = 4096$, $H = 32$, giving $K = 128$ — a factor of 32 smaller than $D$. This keeps the total Q, K, and V parameter count at $3 \times H \times D \times K = 3D^2$, the same as a single $D \times D$ projection regardless of how many heads are used. It also means each head's read and write operations on the residual stream are genuinely low-rank: the head projects the $D$-dimensional stream down to $K$ dimensions to compute attention, then projects the result back up.
103103
:::
104104

@@ -111,21 +111,21 @@ Standard multi-head attention maintains $H$ independent sets of key and value pr
111111

112112
The residual connections in a transformer — introduced as an optimization technique in @he2016deep — have a deeper architectural interpretation. Writing out the full forward pass,
113113
$$
114-
\mbx_t^{(M)} = \mbx_t^{(0)} + \sum_{m=1}^{M} \left[ \texttt{mhsa}^{(m)}(\mbX^{(m-1)})_t + \texttt{mlp}^{(m)}(\mby_t^{(m)}) \right],
114+
\mbx_t^{(L)} = \mbx_t^{(0)} + \sum_{\ell=1}^{L} \left[ \texttt{mhsa}^{(\ell)}(\mbX^{(\ell-1)})_t + \texttt{mlp}^{(\ell)}(\mby_t^{(\ell)}) \right],
115115
$$
116116
reveals that every component — every attention head and every MLP — adds its output directly onto a shared **residual stream** that begins as the token embedding and accumulates updates across all layers [@elhage2021mathematical].
117117

118118
This perspective has several consequences:
119119

120120
- **All components read from and write to the same space.** Attention heads and MLPs interact not through layer boundaries but through their shared updates to the stream. One head can write information that a later head in a different layer reads directly.
121-
- **Each head is a low-rank read-write.** With the Q/K/V/O factorization, each attention head reads from the stream via $\mbU_q$ and $\mbU_k$, computes an update, and writes it back via $\mbU_o \mbU_v$ — a rank-$K$ operation on the $D$-dimensional stream.
121+
- **Each head is a low-rank read-write.** With the Q/K/V/O factorization, each attention head reads from the stream via $\mbW_q$ and $\mbW_k$, computes an update, and writes it back via $\mbW_o \mbW_v$ — a rank-$K$ operation on the $D$-dimensional stream.
122122
- **Residuals are structural, not incidental.** The residual stream framing is the foundation of mechanistic interpretability: by analyzing which subspaces different heads and MLPs read from and write to, researchers have identified circuits that implement specific algorithms (induction, copying, retrieval) inside trained transformers.
123123

124124
## Token-wise Nonlinearity
125125

126126
After the multi-headed self-attention step, the transformer applies a token-wise nonlinear transformation to mix feature dimensions. This is done with a feedforward network applied identically at each position,
127127
$$
128-
\mbx_t^{(m)} = \texttt{mlp}(\mby_t^{(m)}).
128+
\mbx_t^{(\ell)} = \texttt{mlp}(\mby_t^{(\ell)}).
129129
$$
130130

131131
:::{admonition} Computational Complexity
@@ -164,11 +164,11 @@ $$
164164
where $\mbbeta, \mbgamma \in \reals^D$ are learned parameters. LayerNorm is applied before each sub-layer (Pre-LN), which yields more stable training than the original Post-LN design:
165165
$$
166166
\begin{aligned}
167-
\mbY^{(m)} &= \mbX^{(m-1)} + \texttt{mhsa}(\texttt{layer-norm}(\mbX^{(m-1)})) \\
168-
\mbX^{(m)} &= \mbY^{(m)} + \texttt{mlp}(\texttt{layer-norm}(\mbY^{(m)})).
167+
\mbY^{(\ell)} &= \mbX^{(\ell-1)} + \texttt{mhsa}(\texttt{layer-norm}(\mbX^{(\ell-1)})) \\
168+
\mbX^{(\ell)} &= \mbY^{(\ell)} + \texttt{mlp}(\texttt{layer-norm}(\mbY^{(\ell)})).
169169
\end{aligned}
170170
$$
171-
This defines one $\texttt{transformer-block}$. A transformer stacks $M$ such blocks to produce a deep sequence-to-sequence model.
171+
This defines one $\texttt{transformer-block}$. A transformer stacks $L$ such blocks to produce a deep sequence-to-sequence model.
172172

173173
## Positional Encodings
174174

@@ -178,9 +178,9 @@ Without explicit position information, a transformer treats its inputs as an **u
178178

179179
The original transformer adds a fixed position vector to each token embedding:
180180
$$
181-
\mbx_t^{(0)} = \mbc_t + \mbp_t,
181+
\mbx_t^{(0)} \leftarrow \mbx_t^{(0)} + \mbp_t,
182182
$$
183-
where $\mbc_t \in \reals^D$ is the content embedding and $\mbp_t \in \reals^D$ encodes the position using sinusoidal basis functions [@vaswani2017attention]. Learned absolute position embeddings are also common.
183+
where $\mbp_t \in \reals^D$ encodes the position using sinusoidal basis functions [@vaswani2017attention]. Learned absolute position embeddings are also common.
184184

185185
### Rotary Position Embeddings (RoPE)
186186

@@ -198,11 +198,11 @@ Because $(\mbR_t \mbq_t)^\top (\mbR_s \mbk_s) = \mbq_t^\top \mbR_t^\top \mbR_s \
198198

199199
## Autoregressive Modeling
200200

201-
To use a transformer for autoregressive modeling, predictions are read from the final layer's representations. To predict the next token label $\ell_{t+1} \in \{1,\ldots,V\}$ given past tokens $\mbx_{1:t}^{(0)}$:
201+
To use a transformer for autoregressive modeling, predictions are read from the final layer's representations. To predict the next token label $z_{t+1} \in \{1,\ldots,V\}$ given past tokens $z_{1:t}$:
202202
$$
203-
\ell_{t+1} \sim \mathrm{Cat}(\mathrm{softmax}(\mbW \mbx_t^{(M)})),
203+
z_{t+1} \sim \mathrm{Cat}(\mathrm{softmax}(\mbW_u \mbx_t^{(L)})),
204204
$$
205-
where $\mbW \in \reals^{V \times D}$. Like hidden states in an RNN, the final-layer representations $\mbx_t^{(M)}$ aggregate information from all tokens up to index $t$.
205+
where $\mbW_u \in \reals^{V \times D}$ is the **unembedding matrix**. Like hidden states in an RNN, the final-layer representations $\mbx_t^{(L)}$ aggregate information from all tokens up to index $t$.
206206

207207
## Training
208208

0 commit comments

Comments
 (0)