You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: chapters/04_05_transformers.md
+30-30Lines changed: 30 additions & 30 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -17,15 +17,15 @@ Before a transformer can process text, the raw string must be converted into a s
17
17
18
18
Modern LLMs use **subword tokenization**, which decomposes text into vocabulary units that are between characters and words in granularity. The dominant algorithm is **Byte Pair Encoding (BPE)**[@sennrich2016bpe], which starts from a character vocabulary and iteratively merges the most frequent adjacent pair of symbols until the vocabulary reaches a target size $|\cV|$ (typically 32k–128k tokens). Common words become single tokens; rare words are split into subword pieces. Because BPE operates on bytes, it handles any Unicode text without unknown-token issues.
19
19
20
-
Each discrete token index $\ell_t \in \{1, \ldots, |\cV|\}$ is then mapped to a continuous vector via a learned **embedding matrix** $\mbE \in \reals^{|\cV| \times D}$, giving $\mbc_t = \mbE_{\ell_t} \in \reals^D$.
20
+
Each discrete token index $z_t \in \{1, \ldots, |\cV|\}$ is then mapped to a continuous vector via a learned **embedding matrix** $\mbW_e \in \reals^{|\cV| \times D}$, giving $\mbx_t^{(0)} = (\mbW_e)_{z_t} \in \reals^D$.
21
21
22
-
Let $\mbX^{(0)} \in \reals^{T \times D}$ denote our data matrix with row $\mbx_t^{(0)} \in \reals^D$ representing the $t$-th token embedding. The embeddings may be fixed or learned as part of the model.
22
+
Let $\mbX^{(0)} \in \reals^{T \times D}$ denote the matrix of token embeddings with $\mbx_t^{(0)}$ as its $t$-th row.
23
23
24
-
The output of the transformer will be another matrix of the same shape, $\mbX^{(M)} \in \reals^{T \times D}$. These output features can be used for downstream tasks like sentiment classification, machine translation, or autoregressive modeling.
24
+
The output of the transformer will be another matrix of the same shape, $\mbX^{(L)} \in \reals^{T \times D}$. These output features can be used for downstream tasks like sentiment classification, machine translation, or autoregressive modeling.
25
25
26
26
The output results from a stack of transformer blocks,
Each block consists of two stages: one that operates vertically, combining information across the sequence length; another that operates horizontally, combining information across feature dimensions.
31
31
@@ -35,11 +35,11 @@ Each block consists of two stages: one that operates vertically, combining infor
35
35
36
36
The first stage combines information across sequence length using a mechanism called **attention**. Mathematically, attention is a weighted average,
37
37
$$
38
-
\mbY^{(m)} = \mbA^{(m)} \mbX^{(m-1)},
38
+
\mbY^{(\ell)} = \mbA^{(\ell)} \mbX^{(\ell-1)},
39
39
$$
40
-
where $\mbA^{(m)} \in \reals_+^{T \times T}$ is a row-stochastic attention matrix: $\sum_{s} A_{ts}^{(m)} = 1$ for all $t$. Intuitively, $A_{t,s}^{(m)}$ indicates how much output location $t$ attends to input location $s$.
40
+
where $\mbA^{(\ell)} \in \reals_+^{T \times T}$ is a row-stochastic attention matrix: $\sum_{s} A_{ts}^{(\ell)} = 1$ for all $t$. Intuitively, $A_{t,s}^{(\ell)}$ indicates how much output location $t$ attends to input location $s$.
41
41
42
-
When using transformers for autoregressive sequence modeling, we constrain the attention matrix to be **causal** by requiring $A_{t,s}^{(m)} = 0$ for all $s > t$ — i.e., the matrix is **lower triangular**.
42
+
When using transformers for autoregressive sequence modeling, we constrain the attention matrix to be **causal** by requiring $A_{t,s}^{(\ell)} = 0$ for all $s > t$ — i.e., the matrix is **lower triangular**.
(We drop the superscript ${}^{(m)}$ for clarity in this section.)
52
+
(We drop the superscript ${}^{(\ell)}$ for clarity in this section.)
53
53
54
54
In practice, different feature dimensions convey different kinds of information. Transformers use separate linear projections for the **queries** and **keys**:
where $\mbU_q \mbx_t \in \reals^{K}$ are the **queries**, $\mbU_k \mbx_s \in \reals^{K}$ are the **keys**, and the $1/\sqrt{K}$ factor prevents the dot products from growing large in magnitude [@vaswani2017attention].
58
+
where $\mbW_q \mbx_t \in \reals^{K}$ are the **queries**, $\mbW_k \mbx_s \in \reals^{K}$ are the **keys**, and the $1/\sqrt{K}$ factor prevents the dot products from growing large in magnitude [@vaswani2017attention].
59
59
60
60

61
61
62
62
:::{admonition} Causal attention
63
63
To enforce causality, we zero out the upper triangular entries of the attention matrix and renormalize:
Rather than propagating information about past tokens via a hidden state, a transformer with causal attention can directly attend to any past token.
70
+
Rather than propagating information about past tokens via a hidden state, a transformer with causal attention can directly attend to any past token. However, we will make a precise connection in the chapter on [Deep SSMs and Linear Attention](04_06_linear_attention): linearizing the attention kernel recovers a recurrent model, showing that are more closely connected than you might think.
71
71
:::
72
72
73
73
:::{admonition} Connection to Convolutional Neural Networks (CNNs)
@@ -76,29 +76,29 @@ If the attention weights were only a function of the distance between tokens, $A
76
76
77
77
### The KV Cache
78
78
79
-
During autoregressive generation, the model produces one token at a time. At step $t$, all keys $\mbU_k \mbx_{s}$ and values $\mbU_v \mbx_{s}$ for $s < t$ were already computed at previous steps. Rather than recomputing them, an efficient implementation stores them in a **KV cache** and retrieves them when generating each new token. This reduces the per-step cost from $O(t D^2)$ to $O(D^2)$, but the cache grows linearly with context length — a key memory bottleneck at inference time.
79
+
During autoregressive generation, the model produces one token at a time. At step $t$, all keys $\mbW_k \mbx_{s}$ and values $\mbW_v \mbx_{s}$ for $s < t$ were already computed at previous steps. Rather than recomputing them, an efficient implementation stores them in a **KV cache** and retrieves them when generating each new token. This reduces the per-step cost from $O(t D^2)$ to $O(D^2)$, but the cache grows linearly with context length — a key memory bottleneck at inference time.
80
80
81
81
### Multi-Headed Self-Attention
82
82
83
83
Just as a CNN uses a bank of filters in parallel, a transformer block uses $H$ **attention heads** in parallel. Let
The standard convention is $K = D / H$, so the per-head query/key/value dimension is much smaller than the full token dimension. In Qwen3-8B [@qwen3], for example, $D = 4096$, $H = 32$, giving $K = 128$ — a factor of 32 smaller than $D$. This keeps the total Q, K, and V parameter count at $3 \times H \times D \times K = 3D^2$, the same as a single $D \times D$ projection regardless of how many heads are used. It also means each head's read and write operations on the residual stream are genuinely low-rank: the head projects the $D$-dimensional stream down to $K$ dimensions to compute attention, then projects the result back up.
103
103
:::
104
104
@@ -111,21 +111,21 @@ Standard multi-head attention maintains $H$ independent sets of key and value pr
111
111
112
112
The residual connections in a transformer — introduced as an optimization technique in @he2016deep — have a deeper architectural interpretation. Writing out the full forward pass,
reveals that every component — every attention head and every MLP — adds its output directly onto a shared **residual stream** that begins as the token embedding and accumulates updates across all layers [@elhage2021mathematical].
117
117
118
118
This perspective has several consequences:
119
119
120
120
-**All components read from and write to the same space.** Attention heads and MLPs interact not through layer boundaries but through their shared updates to the stream. One head can write information that a later head in a different layer reads directly.
121
-
-**Each head is a low-rank read-write.** With the Q/K/V/O factorization, each attention head reads from the stream via $\mbU_q$ and $\mbU_k$, computes an update, and writes it back via $\mbU_o \mbU_v$ — a rank-$K$ operation on the $D$-dimensional stream.
121
+
-**Each head is a low-rank read-write.** With the Q/K/V/O factorization, each attention head reads from the stream via $\mbW_q$ and $\mbW_k$, computes an update, and writes it back via $\mbW_o \mbW_v$ — a rank-$K$ operation on the $D$-dimensional stream.
122
122
-**Residuals are structural, not incidental.** The residual stream framing is the foundation of mechanistic interpretability: by analyzing which subspaces different heads and MLPs read from and write to, researchers have identified circuits that implement specific algorithms (induction, copying, retrieval) inside trained transformers.
123
123
124
124
## Token-wise Nonlinearity
125
125
126
126
After the multi-headed self-attention step, the transformer applies a token-wise nonlinear transformation to mix feature dimensions. This is done with a feedforward network applied identically at each position,
127
127
$$
128
-
\mbx_t^{(m)} = \texttt{mlp}(\mby_t^{(m)}).
128
+
\mbx_t^{(\ell)} = \texttt{mlp}(\mby_t^{(\ell)}).
129
129
$$
130
130
131
131
:::{admonition} Computational Complexity
@@ -164,11 +164,11 @@ $$
164
164
where $\mbbeta, \mbgamma \in \reals^D$ are learned parameters. LayerNorm is applied before each sub-layer (Pre-LN), which yields more stable training than the original Post-LN design:
This defines one $\texttt{transformer-block}$. A transformer stacks $M$ such blocks to produce a deep sequence-to-sequence model.
171
+
This defines one $\texttt{transformer-block}$. A transformer stacks $L$ such blocks to produce a deep sequence-to-sequence model.
172
172
173
173
## Positional Encodings
174
174
@@ -178,9 +178,9 @@ Without explicit position information, a transformer treats its inputs as an **u
178
178
179
179
The original transformer adds a fixed position vector to each token embedding:
180
180
$$
181
-
\mbx_t^{(0)} = \mbc_t + \mbp_t,
181
+
\mbx_t^{(0)} \leftarrow \mbx_t^{(0)} + \mbp_t,
182
182
$$
183
-
where $\mbc_t \in \reals^D$ is the content embedding and $\mbp_t \in \reals^D$ encodes the position using sinusoidal basis functions [@vaswani2017attention]. Learned absolute position embeddings are also common.
183
+
where $\mbp_t \in \reals^D$ encodes the position using sinusoidal basis functions [@vaswani2017attention]. Learned absolute position embeddings are also common.
To use a transformer for autoregressive modeling, predictions are read from the final layer's representations. To predict the next token label $\ell_{t+1} \in \{1,\ldots,V\}$ given past tokens $\mbx_{1:t}^{(0)}$:
201
+
To use a transformer for autoregressive modeling, predictions are read from the final layer's representations. To predict the next token label $z_{t+1} \in \{1,\ldots,V\}$ given past tokens $z_{1:t}$:
where $\mbW \in \reals^{V \times D}$. Like hidden states in an RNN, the final-layer representations $\mbx_t^{(M)}$ aggregate information from all tokens up to index $t$.
205
+
where $\mbW_u \in \reals^{V \times D}$ is the **unembedding matrix**. Like hidden states in an RNN, the final-layer representations $\mbx_t^{(L)}$ aggregate information from all tokens up to index $t$.
0 commit comments