RNNs are natural models for sequential data, but the
Transformers underlie large language models (LLMs) like OpenAI's ChatGPT and Google's Gemini. They are also widely used in computer vision and other domains of machine learning. This lecture will walk through the basic building blocks of a transformer: self-attention, token-wise nonlinear transformations, layer norm, and positional encodings. We will focus on modeling sequential data. We will follow the presentation of {cite:t}`turner2023introduction`, but we will make some slight modifications to the notation to be consistent with our previous notes and Homework 4.
Let
The output of the transformer will be another matrix of the same shape,
The output results from a stack of transformer blocks,
\begin{align*}
\mbX^{(m)} &= \texttt{transformer-block}(\mbX^{(m-1)})
\end{align*}
Each block consists of two stages: one that operates vertically, combining information across the sequence length; another that operates horizontally, combining information across feature dimensions.
The first stage combines information across sequence length using a mechanism called attention. Mathematically, attention is just a weighted average,
\begin{align*}
\mbY^{(m)} &= \mbA^{(m)} \mbX^{(m-1)},
\end{align*}
where
When we are using transformers for autoregressive sequence modeling, we constrain the attention matrix to be causal by requiring
Where does the attention matrix come from? In a transformer, the attention weights are determined by the pairwise similarity of tokens in the sequence. The simplest instantiation of this idea would be something like,
\begin{align*}
A_{t,s} &\propto \exp \left\{ \mbx_t^\top \mbx_s \right\}
\end{align*}
Once normalized,
\begin{align*}
A_{t,s} &= \frac{\exp \{ \mbx_t^\top \mbx_s \}}{\sum_{s'=1}^T \exp \{ \mbx_t^\top \mbx_{s'} \}}.
\end{align*}
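The normalized weights above can be sketched in a few lines of NumPy. This is a minimal illustration, not the course's reference implementation. It assumes the tokens are stacked as rows of a $T \times D$ array `X` (the sizes below are hypothetical), and the helper names `softmax_rows` and `dot_product_attention` are made up for this example.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax, shifted by each row's max for numerical stability."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

def dot_product_attention(X):
    """Return (A, Y) with A[t, s] proportional to exp(x_t . x_s) and Y = A X."""
    A = softmax_rows(X @ X.T)   # (T, T); each row sums to one
    Y = A @ X                   # (T, D); row t is a weighted average of tokens
    return A, Y

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))     # hypothetical sizes: T = 5 tokens, D = 3 features
A, Y = dot_product_attention(X)
```

Each row of `A` is a probability distribution over positions in the sequence, so each row of `Y` is a convex combination of the input tokens.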
(Note: we have dropped the superscript
This approach implies that attention depends equally on all
Still, the numerator in this attention weight is symmetric. If we think of the attention weight as specifying how relevant token
Transformers use a more general form of attention to address this asymmetry,
\begin{align*}
A_{t,s} &= \frac{\exp \{ (\mbU_q \mbx_t)^\top (\mbU_k \mbx_s) \}}{\sum_{s'=1}^T \exp \{ (\mbU_q \mbx_t)^\top (\mbU_k \mbx_{s'}) \}},
\end{align*}
where
The parameters
:::{admonition} Causal attention
To enforce causality in the attention layer, we simply zero out the strictly upper triangular part of the attention matrix (the entries with $s > t$) and normalize the rows appropriately,
\begin{align*}
A_{t,s} &= \frac{\exp \{ (\mbU_q \mbx_t)^\top (\mbU_k \mbx_s) \}}{\textcolor{red}{\sum_{s'=1}^{t}} \exp \{ (\mbU_q \mbx_t)^\top (\mbU_k \mbx_{s'}) \}} \cdot \bbI[t \geq s].
\end{align*}
:::
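As a rough sketch, causal attention can be computed by setting the masked scores to $-\infty$ before exponentiating, which zeroes the corresponding weights and restricts the normalizing sum to $s' \leq t$. The code below assumes tokens are rows of a $T \times D$ array and that `U_q` and `U_k` are $K \times D$ matrices. These shapes, the sizes, and the function name `causal_attention` are assumptions made for illustration.

```python
import numpy as np

def causal_attention(X, U_q, U_k):
    """Causal attention weights A[t, s], nonzero only where s <= t."""
    T = X.shape[0]
    Q = X @ U_q.T                 # row t is U_q x_t (the query for token t)
    K = X @ U_k.T                 # row s is U_k x_s (the key for token s)
    scores = Q @ K.T              # scores[t, s] = (U_q x_t)^T (U_k x_s)
    mask = np.tril(np.ones((T, T), dtype=bool))   # True where t >= s
    scores = np.where(mask, scores, -np.inf)      # future positions get -inf
    # Each row keeps its diagonal entry, so the row max is always finite.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)      # exp(-inf) = 0 removes the masked entries
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
T, D, K = 5, 4, 2                 # hypothetical sizes
X = rng.normal(size=(T, D))
U_q = rng.normal(size=(K, D))
U_k = rng.normal(size=(K, D))
A = causal_attention(X, U_q, U_k)
```

Because `U_q` and `U_k` differ, the scores are no longer symmetric in $t$ and $s$. The first token can only attend to itself, so the first row of `A` puts all of its weight on position one.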
:::{admonition} Comparison to RNNs
Compare the self-attention mechanism to the information processing in an RNN. Rather than propagating information about past tokens via a hidden state, a transformer with causal attention can directly attend to any one of the past tokens.
:::
:::{admonition} Connection to Convolutional Neural Networks (CNNs)
If the attention weights were only a function of the distance between tokens,
Just as in a CNN each layer performs convolutions with a bank of filters in parallel, in a transformer each block uses a bank of
The outputs of the attention heads are either concatenated or linearly combined,
\begin{align*}
\mbY^{(m)} &= \sum_{h=1}^H \mbY^{(m,h)} (\mbV^{(m,h)})^\top \\
&\triangleq \texttt{mhsa}(\mbX^{(m-1)}),
\end{align*}
where
:::{admonition} Queries, Keys, and Values
The original transformer paper presents the output of a single head as a set of values,
\begin{align*}
\mbY^{(m,h)} &= \mbA^{(m,h)} (\mbX^{(m-1)} (\mbU_v^{(m,h)})^\top) \in \reals^{T \times K}
\end{align*}
where
The final output is projected back into the original token dimension and linearly combined,
\begin{align*}
\mbY^{(m)} &= \sum_{h=1}^H \mbY^{(m,h)} (\mbU_o^{(m,h)})^\top
\end{align*}
where
This formulation corresponds to a low-rank read-out matrix
After applying the multi-headed self-attention to obtain
:::{admonition} Computational Complexity
:class: warning
The MLP typically has hidden dimensions of at least
Rather than parameterizing
Transformers use residual connections for both the multi-headed self-attention step and the MLP. So,
\begin{align*}
\mbY^{(m)} &= \mbX^{(m-1)} + \texttt{mhsa}(\mbX^{(m-1)}) \\
\mbX^{(m)} &= \mbY^{(m)} + \texttt{mlp}(\mbY^{(m)})
\end{align*}
Finally, it is important to use other deep learning "tricks" like LayerNorm to stabilize training. In a transformer, LayerNorm amounts to z-scoring each token
LayerNorm is typically applied before the multi-headed self-attention and MLP steps,
\begin{align*}
\overline{\mbX}^{(m-1)} &= \texttt{layer-norm}(\mbX^{(m-1)}) \\
\mbY^{(m)} &= \overline{\mbX}^{(m-1)} + \texttt{mhsa}(\overline{\mbX}^{(m-1)}) \\
\overline{\mbY}^{(m)} &= \texttt{layer-norm}(\mbY^{(m)}) \\
\mbX^{(m)} &= \overline{\mbY}^{(m)} + \texttt{mlp}(\overline{\mbY}^{(m)})
\end{align*}
This defines one
A transformer stacks
Except for the lower triangular constraint on the attention matrices, the transformer architecture knows nothing about the relative positions of the tokens. Absent this constraint, the transformer essentially treats the data as an unordered set of tokens. This can actually be a feature! It allows the transformer to act on a wide range of datasets aside from just sequences. For example, transformers are often applied to images by chunking the image up into patches and embedding each one.
However, when the data possess some spatial or temporal structure, it is helpful to include that information in the embedding. A simple way to do so is to add position and content in the token,
\begin{align*}
\mbx_t^{(0)} &= \mbc_t + \mbp_t,
\end{align*}
where
To use a transformer for autoregressive modeling, we need to make predictions from the final layer's representations. If the goal is to predict the next word label
Training deep neural networks is something of a dark art. Standard practice is to use the Adam optimizer with a bag of tricks including gradient clipping, learning rate annealing schedules, increasing minibatch sizes, dropout, etc. Generally, treat these algorithmic decisions as hyperparameters to be tuned. We won't try to put any details in writing lest you overfit to them.
Transformers are a workhorse of modern machine learning and key to many of the impressive advances over recent years. However, there are still areas for improvement. For example, the computational cost of attention is