Commit 7f7527c

minor nits to make the dimensions clearer
Signed-off-by: Alex Chi <[email protected]>
1 parent f72fb35 commit 7f7527c

1 file changed: +55 −40 lines

book/src/week1-01-attention.md

@@ -12,79 +12,94 @@ token embeddings. The output of the model is the most likely next token ID.
[📚 Reading: LLM Inference, the Decode Phase](https://huggingface.co/learn/llm-course/chapter1/8)

Back to the attention layer. The attention layer takes a query, a key, and a value. In a classic implementation, all
-of them are of the same shape: `N.. x H x L x D`.
+of them are of the same shape: `N.. x L x D`.

-`N..` is zero or some number of dimensions for batches. Within each batch, `H` is the number of heads, `L` is the
-sequence length, and `D` is the dimension of the embedding for a given head in the sequence.
+`N..` is zero or some number of dimensions for batches. Within each batch, `L` is the sequence length and `D` is
+the dimension of the embedding for a given head in the sequence.

-So, for example, if we have a sequence of 1024 tokens, where each token has a 4096-dimensional embedding, we split
-this 4096-dimensional embedding into 8 heads using the upper layer (i.e., the multi-head attention layer), so each head
-will have a 512-dimensional embedding. For week 1, we assume each batch will only have one sequence. In this case,
-we will pass a tensor of the shape `1 x 8 x 1024 x 512` to the attention layer.
+So, for example, if we have a sequence of 1024 tokens, where each token has a 512-dimensional embedding (head_dim),
+we will pass a tensor of the shape `N.. x 1024 x 512` to the attention layer.
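
To make the example concrete, here is roughly what such an input looks like in MLX (the shape numbers are just the example values from the text above):

```python
import mlx.core as mx

# One batch axis, a sequence of 1024 tokens, a 512-dimensional per-head embedding.
x = mx.random.normal(shape=(1, 1024, 512))
print(x.shape)  # (1, 1024, 512)
```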

## Task 1: Implement `scaled_dot_product_attention`

+**📚 Readings**
+
+* [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)
+* [PyTorch Scaled Dot Product Attention API](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (assume `enable_gqa=False`, assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
+* [MLX Scaled Dot Product Attention API](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
+* [Attention is All You Need](https://arxiv.org/abs/1706.03762)
+
Implement `scaled_dot_product_attention`. The function takes key, value, and query of the same dimensions.

```
-K: N.. x H x L x D
-V: N.. x H x L x D
-Q: N.. x H x L x D
+L is seq_len, in PyTorch API it's S (source len)
+D is head_dim
+
+K: N.. x L x D
+V: N.. x L x D
+Q: N.. x L x D
+output: N.. x L x D
```

You may use `softmax` provided by mlx and implement it later in week 2.
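
For orientation, here is a minimal sketch of such a function in MLX. It is not the book's reference solution; the function name, the `mask=None` default, and the additive-mask convention are assumptions made for this sketch:

```python
import math

import mlx.core as mx


def scaled_dot_product_attention_sketch(
    query: mx.array,
    key: mx.array,
    value: mx.array,
    mask: mx.array | None = None,
) -> mx.array:
    """query/key/value: N.. x L x D  ->  output: N.. x L x D."""
    d = query.shape[-1]  # D, the per-head dimension
    # scores: N.. x L x L, the dot product of every query with every key, scaled by sqrt(D)
    scores = mx.matmul(query, mx.swapaxes(key, -2, -1)) / math.sqrt(d)
    if mask is not None:
        # assumed additive mask: large negative values suppress masked positions
        scores = scores + mask
    weights = mx.softmax(scores, axis=-1)  # normalize over the key positions
    return mx.matmul(weights, value)
```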

-**📚 Readings**
+Because we are always using the attention layer within the multi-head attention layer, the actual tensor shape when serving
+the model will be:

-* [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)
-* [PyTorch Scaled Dot Product Attention API](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (assume `enable_gqa=False`, assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
-* [MLX Scaled Dot Product Attention API](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
-* [Attention is All You Need](https://arxiv.org/abs/1706.03762)
+```
+K: 1 x H x L x D
+V: 1 x H x L x D
+Q: 1 x H x L x D
+output: 1 x H x L x D
+```
+
+.. though the attention layer only cares about the last two dimensions. The test cases will use arbitrary shapes for the batching dimensions.
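
Since only the last two axes enter the math, the Task 1 sketch above works unchanged when a head axis (or any other batch axes) is prepended. The `H = 8` below is just an assumed example value:

```python
import mlx.core as mx

# Hypothetical serving-time shapes: batch 1, H = 8 heads, L = 1024 tokens, D = 512 per head.
q = mx.random.normal(shape=(1, 8, 1024, 512))
k = mx.random.normal(shape=(1, 8, 1024, 512))
v = mx.random.normal(shape=(1, 8, 1024, 512))

out = scaled_dot_product_attention_sketch(q, k, v)
print(out.shape)  # (1, 8, 1024, 512) -- the leading batch/head axes pass through untouched
```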
+
+At the end of this task, you should be able to pass the following tests:
+
+```
+poetry run pytest tests -k test_attention_simple
+poetry run pytest tests -k test_attention_with_mask
+```

## Task 2: Implement `MultiHeadAttention`

+**📚 Readings**
+
+* [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)
+* [PyTorch MultiHeadAttention API](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
+* [MLX MultiHeadAttention API](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.MultiHeadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
+
Implement `MultiHeadAttention`. The layer takes a batch of vectors `x`, maps it through the K, V, Q weight matrices, and
uses the attention function we implemented in day 1 to compute the result. The output needs to be mapped using the O
-weight matrix. You will also need to implement the `linear` function.
+weight matrix.

-For `linear`, it takes a tensor of the shape `N.. x I`, a weight matrix of the shape `O x I`, and a bias vector of
-the shape `O`. The output is of the shape `N.. x O`. `I` is the input dimension and `O` is the output dimension.
+You will also need to implement the `linear` function first. `linear` takes a tensor of the shape `N.. x I`, a weight matrix of the shape `O x I`, and a bias vector of the shape `O`. The output is of the shape `N.. x O`. `I` is the input dimension and `O` is the output dimension.
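
As a rough sketch (again, not the reference solution; the function name and the optional bias are assumptions), `linear` in MLX could look like:

```python
import mlx.core as mx


def linear_sketch(x: mx.array, w: mx.array, bias: mx.array | None = None) -> mx.array:
    """x: N.. x I, w: O x I, bias: O  ->  N.. x O."""
    # x @ w^T maps the last dimension from I to O; broadcasting handles the N.. batch axes.
    y = mx.matmul(x, mx.swapaxes(w, -2, -1))
    if bias is not None:
        y = y + bias  # the bias broadcasts over the leading axes
    return y
```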

For the `MultiHeadAttention` layer, the input tensor `x` has the shape `N x L x E`, where `E` is the dimension of the
-embedding for a given head in the sequence. The `K/Q/V` weight matrices will map the tensor into key, value, and query
-separately, where the dimension `E` will be mapped into a dimension of size `H x D`. Then, you will need to reshape it
-to `H, D`.
+embedding for a given token in the sequence. The `K/Q/V` weight matrices will map the tensor into key, value, and query
+separately, where the dimension `E` will be mapped into a dimension of size `H x D`, which means that the token embedding
+gets mapped into `H` heads, each with a dimension of `D`. You can directly reshape the tensor to split the `H x D` dimension
+into two dimensions of `H` and `D` to get `H` heads for the token. Then, apply the attention function to each head
+(this requires a transpose, using `swapaxes` in mlx). The attention function takes `N.. x H x L x D` as input so that it
+produces an output for each head of each token. Then, you can transpose it into `N.. x L x H x D` and reshape it
+so that all heads get merged back together with a shape of `N.. x L x (H x D)`. Map it through the output weight matrix to get
+the final output.

```
E is hidden_size or embed_dim or dims or model_dim
H is num_heads
D is head_dim
L is seq_len, in PyTorch API it's S (source len)

-x: N x L x E
-w_q/k/v: E x (H x D)
-q/k/v = linear(x, w_q/w_k/w_v) = N x L x (H x D)
-then, reshape it into N x L x H x D then transpose it to get N x H x L x D as the input of the attention function.
-o = attention(q, k, v) = N x H x L x D
-w_o: (H x D) x O
-result = linear(reshaped o, w_o) = N x L x O
+W_q/k/v: E x (H x D)
+output/x: N x L x E
+W_o: (H x D) x E
```
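
Putting the pieces together, here is a rough sketch of the layer in MLX, reusing the `linear_sketch` and `scaled_dot_product_attention_sketch` helpers from above. The class name, the constructor arguments, and the choice to store the weights in `linear`'s `O x I` layout are assumptions for this sketch, not the book's API:

```python
import mlx.core as mx


class MultiHeadAttentionSketch:
    def __init__(self, hidden_size: int, num_heads: int,
                 wq: mx.array, wk: mx.array, wv: mx.array, wo: mx.array):
        # wq/wk/wv: (H * D) x E and wo: E x (H * D) in linear's O x I layout;
        # the shape block above writes the same maps as E x (H x D) and (H x D) x E.
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads                # H
        self.head_dim = hidden_size // num_heads  # D
        self.wq, self.wk, self.wv, self.wo = wq, wk, wv, wo

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        N, L, _ = x.shape                         # x: N x L x E
        H, D = self.num_heads, self.head_dim

        def split_heads(t: mx.array) -> mx.array:
            # N x L x (H x D) -> N x L x H x D -> N x H x L x D
            return mx.swapaxes(mx.reshape(t, (N, L, H, D)), 1, 2)

        q = split_heads(linear_sketch(x, self.wq))
        k = split_heads(linear_sketch(x, self.wk))
        v = split_heads(linear_sketch(x, self.wv))

        o = scaled_dot_product_attention_sketch(q, k, v, mask)  # N x H x L x D
        o = mx.reshape(mx.swapaxes(o, 1, 2), (N, L, H * D))     # merge the heads back
        return linear_sketch(o, self.wo)                        # N x L x E
```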

-You can then directly split the `q/k/v` into `H` heads by reshaping the last dimension into `H x D` and apply the
-attention function on it. Note that the attention function takes `N.. x H x L x D` as input, so you will need to
-transpose it to get the right shape.
-
-**📚 Readings**
-
-* [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)
-* [PyTorch MultiHeadAttention API](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
-* [MLX MultiHeadAttention API](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.MultiHeadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
-
At the end of the day, you should be able to pass the following tests:

```
-poetry run pytest tests -k test_attention_simple
-poetry run pytest tests -k test_attention_with_mask
poetry run pytest tests -k test_multi_head_attention
```
