Commit 71e8fac

add docs to week 1 day 2
Signed-off-by: Alex Chi <[email protected]>
1 parent 7f7527c commit 71e8fac

11 files changed: +154 -138 lines

book/src/SUMMARY.md

+1 -1
@@ -7,7 +7,7 @@
 
 - [Week 1: From Matmul to Text](./week1-overview.md)
 - [Attention and Multi-Head Attention](./week1-01-attention.md)
-- [Positional Embeddings and RoPE]()
+- [Positional Encodings and RoPE](./week1-02-positional-encodings.md)
 - [Grouped/Multi Query Attention]()
 - [Multilayer Perceptron Layer and Transformer]()
 - [Wiring the Qwen2 Model]()

book/src/week1-01-attention.md

+15 -1
@@ -22,6 +22,13 @@ we will pass a tensor of the shape `N.. x 1024 x 512` to the attention layer.
 
 ## Task 1: Implement `scaled_dot_product_attention`
 
+In this task, we will implement the scaled dot product attention function. You can test your implementation by running the following command:
+
+```
+poetry run pytest tests -k week_1_day_1_task_1 -v
+```
+
+
 **📚 Readings**
 
 * [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)
@@ -51,6 +58,7 @@ K: 1 x H x L x D
 V: 1 x H x L x D
 Q: 1 x H x L x D
 output: 1 x H x L x D
+mask: 1 x H x L x L
 ```
 
 .. though the attention layer only cares about the last two dimensions. The test case will test any shape of the batching dimension.
@@ -64,6 +72,12 @@ poetry run pytest tests -k test_attention_with_mask
 
 ## Task 2: Implement `MultiHeadAttention`
 
+In this task, we will implement the multi-head attention layer. You will need to modify the following file:
+
+```
+src/tiny_llm/attention.py
+```
+
 **📚 Readings**
 
 * [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)
@@ -100,7 +114,7 @@ W_o: (H x D) x E
 At the end of the day, you should be able to pass the following tests:
 
 ```
-poetry run pytest tests -k test_multi_head_attention
+poetry run pytest tests -k week_1_day_1_task_2 -v
 ```
 
 {{#include copyright.md}}
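
The shape conventions above are easier to follow in code. Below is a minimal numpy sketch of scaled dot product attention and of the head split in multi-head attention. It is only an illustration, not the reference solution: it assumes the projection weights act as `x @ w.T`, which should be checked against the course's `linear` helper in the starter code.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v, scale=None, mask=None):
    # q, k, v: (..., L, D); mask (additive): broadcastable to (..., L, L)
    d = q.shape[-1]
    scale = scale if scale is not None else 1.0 / np.sqrt(d)
    scores = (q @ np.swapaxes(k, -1, -2)) * scale  # (..., L, L)
    if mask is not None:
        scores = scores + mask
    return softmax(scores, axis=-1) @ v            # (..., L, D)

def multi_head_attention(x_q, x_k, x_v, wq, wk, wv, wo, num_heads, mask=None):
    # x_*: (N, L, E); w*: (E, E) with E = num_heads * head_dim, applied as x @ w.T
    N, L, E = x_q.shape
    D = E // num_heads

    def split_heads(x, w):
        # (N, L, E) -> (N, L, H, D) -> (N, H, L, D)
        return (x @ w.T).reshape(N, L, num_heads, D).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x_q, wq), split_heads(x_k, wk), split_heads(x_v, wv)
    out = scaled_dot_product_attention(q, k, v, scale=1.0 / np.sqrt(D), mask=mask)
    out = out.transpose(0, 2, 1, 3).reshape(N, L, E)  # merge heads back into E
    return out @ wo.T
```

The reference implementation operates on mlx arrays instead of numpy, but the shape bookkeeping is the same.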

book/src/week1-02-positional-embeddings.md

-1
This file was deleted.

book/src/week1-02-positional-encodings.md

+80
@@ -0,0 +1,80 @@
+# Week 1 Day 2: Positional Encodings and RoPE
+
+In day 2, we will implement the positional embedding used in the Qwen2 model: Rotary Positional Encoding. In a transformer
+model, we need a way to embed the information of the position of a token into the input of the attention layers. In Qwen2,
+the positional embedding is applied within the multi-head attention layer on the query and key vectors.
+
+**📚 Readings**
+
+- [You could have designed state of the art positional encoding](https://huggingface.co/blog/designing-positional-encoding)
+- [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864)
+
+## Task 1: Implement Rotary Positional Encoding "RoPE"
+
+You will need to modify the following file:
+
+```
+src/tiny_llm/positional_encoding.py
+```
+
+In traditional RoPE (as described in the readings), the positional encoding is applied to each head of the query and key vectors.
+You can pre-compute the frequencies when initializing the `RoPE` class.
+
+If `offset` is not provided, the positional encoding is applied to the entire sequence: the 0th frequency is applied to the
+0th token, up to the (L-1)-th token. Otherwise, the positional encoding is applied to the sequence according to the
+offset slice. If the offset slice is 5..10, the sequence length provided to the layer would be 5, and the 0th token
+will be applied with the 5th frequency.
+
+```
+x: (N, L, H, D)
+cos/sin_freqs: (MAX_SEQ_LEN, D // 2)
+```
+
+In the traditional form of RoPE, each head's `D` dimensions are viewed as consecutive complex pairs. That is to
+say, if D = 8, then x[0] and x[1] are a pair, x[2] and x[3] are another pair, and so on. A pair gets the same frequency
+from `cos/sin_freqs`.
+
+```
+output[0] = x[0] * cos_freqs[0] + x[1] * sin_freqs[0]
+output[1] = x[0] * -sin_freqs[0] + x[1] * cos_freqs[0]
+...and so on
+```
+
+You can do this by reshaping `x` to (N, L, H, D // 2, 2) and then applying the above formula to each pair.
+
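
To make the pairing concrete, here is a minimal numpy sketch of the traditional form. It follows the sign convention written above (reference implementations differ on which component carries the minus sign) and assumes `cos_freqs`/`sin_freqs` are precomputed `(MAX_SEQ_LEN, D // 2)` tables; the function name and the integer `offset` handling are illustrative, not the required API.

```python
import numpy as np

def rope_traditional(x, cos_freqs, sin_freqs, offset=0):
    # x: (N, L, H, D); cos/sin_freqs: (MAX_SEQ_LEN, D // 2)
    N, L, H, D = x.shape
    cos = cos_freqs[offset : offset + L][None, :, None, :]  # (1, L, 1, D // 2)
    sin = sin_freqs[offset : offset + L][None, :, None, :]

    pairs = x.reshape(N, L, H, D // 2, 2)   # consecutive (even, odd) pairs
    x0, x1 = pairs[..., 0], pairs[..., 1]

    # the formula from the text above, applied to every pair at once
    out0 = x0 * cos + x1 * sin
    out1 = -x0 * sin + x1 * cos

    return np.stack([out0, out1], axis=-1).reshape(N, L, H, D)
```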
+**📚 Readings**
+
+- [PyTorch RotaryPositionalEmbeddings API](https://pytorch.org/torchtune/stable/generated/torchtune.modules.RotaryPositionalEmbeddings.html)
+- [MLX Implementation of RoPE before the custom metal kernel implementation](https://github.com/ml-explore/mlx/pull/676/files)
+
+You can test your implementation by running the following command:
+
+```
+poetry run pytest tests -k week_1_day_2_task_1 -v
+```
+
+## Task 2: Implement `RoPE` in the non-traditional form
+
+The Qwen2 model uses a non-traditional form of RoPE. In this form, the head embedding dimension is split into two halves,
+and element `i` of the first half is paired with element `i` of the second half (rather than consecutive elements), with one frequency per pair.
+
+```
+output[0] = x[0] * cos_freqs[0] + x[HALF_DIM] * sin_freqs[0]
+output[HALF_DIM] = x[0] * -sin_freqs[0] + x[HALF_DIM] * cos_freqs[0]
+output[1] = x[1] * cos_freqs[1] + x[HALF_DIM + 1] * sin_freqs[1]
+output[HALF_DIM + 1] = x[1] * -sin_freqs[1] + x[HALF_DIM + 1] * cos_freqs[1]
+...and so on
+```
+
+You can do this by directly taking the first half and the second half of the embedding dimension of `x` and applying the
+frequencies to each half separately.
+
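
A matching numpy sketch of this half-split form, under the same assumptions as the previous sketch (precomputed `(MAX_SEQ_LEN, D // 2)` tables, integer `offset`, and the sign convention written above):

```python
import numpy as np

def rope_non_traditional(x, cos_freqs, sin_freqs, offset=0):
    # x: (N, L, H, D); cos/sin_freqs: (MAX_SEQ_LEN, D // 2)
    N, L, H, D = x.shape
    half = D // 2
    cos = cos_freqs[offset : offset + L][None, :, None, :]  # (1, L, 1, D // 2)
    sin = sin_freqs[offset : offset + L][None, :, None, :]

    x0 = x[..., :half]   # first half pairs element-wise with ...
    x1 = x[..., half:]   # ... the second half, one frequency per pair

    out0 = x0 * cos + x1 * sin
    out1 = -x0 * sin + x1 * cos

    return np.concatenate([out0, out1], axis=-1)  # (N, L, H, D)
```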
+You can test your implementation by running the following command:
+
+```
+poetry run pytest tests -k week_1_day_2_task_2 -v
+```
+
+**📚 Readings**
+
+- [vLLM implementation of RoPE](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py)

book/src/week1-overview.md

+17
@@ -28,4 +28,21 @@ To make the journey as interesting as possible, we will skip a few things for no
 * Loading the model weights -- I don't think it's an interesting thing to learn how to decode those tensor dump files, so
 we will use the `mlx_lm` to load the model and steal the weights from the loaded model into our layer implementations.
 
+## Qwen2 Models
+
+You can try the Qwen2 model with MLX/vLLM. You can read the blog post below to get some idea of what we will build
+within this course. At the end of this week, we will be able to chat with the model -- that is to say, use Qwen2 to
+generate text, as a causal language model.
+
+The reference implementation of the Qwen2 model can be found in Huggingface Transformers, vLLM, and mlx-lm. You may
+use these resources to better understand the internals of the model and what we will implement this week.
+
+**📚 Readings**
+
+- [Qwen2.5: A Party of Foundation Models!](https://qwenlm.github.io/blog/qwen2.5/)
+- [Key Concepts of the Qwen2 Model](https://qwen.readthedocs.io/en/latest/getting_started/concepts.html)
+- [Huggingface Transformers - Qwen2](https://github.com/huggingface/transformers/tree/main/src/transformers/models/qwen2)
+- [vLLM Qwen2](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2.py)
+- [mlx-lm Qwen2](https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py)
+
 {{#include copyright.md}}

src/tiny_llm/multi_head_attention.py

-28
This file was deleted.

src/tiny_llm_week1_ref/attention.py

+9 -9
@@ -81,22 +81,22 @@ def __call__(
         value: mx.array,
         mask: mx.array | None = None,
     ) -> mx.array:
-        n_batches = query.shape[0]
-        seq_len = query.shape[1]
+        N, L, E = query.shape
+        assert query.shape == key.shape == value.shape
         projection_q = (
             linear(query, self.wq)
-            .reshape(n_batches, self.num_heads, seq_len, self.head_dim)
-            .transpose(1, 0, 2, 3)
+            .reshape(N, L, self.num_heads, self.head_dim)
+            .transpose(0, 2, 1, 3)
         )
         projection_k = (
             linear(key, self.wk)
-            .reshape(n_batches, self.num_heads, seq_len, self.head_dim)
-            .transpose(1, 0, 2, 3)
+            .reshape(N, L, self.num_heads, self.head_dim)
+            .transpose(0, 2, 1, 3)
         )
         projection_v = (
             linear(value, self.wv)
-            .reshape(n_batches, self.num_heads, seq_len, self.head_dim)
-            .transpose(1, 0, 2, 3)
+            .reshape(N, L, self.num_heads, self.head_dim)
+            .transpose(0, 2, 1, 3)
         )
         x = scaled_dot_product_attention(
             projection_q,
@@ -105,5 +105,5 @@ def __call__(
             scale=self.scale,
             mask=mask,
         )
-        x = x.transpose(1, 0, 2).reshape(n_batches, seq_len, self.hidden_size)
+        x = x.transpose(0, 2, 1, 3).reshape(N, L, self.hidden_size)
         return linear(x, self.wo)
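
The shape fix is the substance of this diff: the projection output is `(N, L, H * D)`, so the head split has to reshape to `(N, L, H, D)` first and only then move the head axis in front of the sequence axis. A small numpy check of why the order matters (illustrative only, not repository code):

```python
import numpy as np

N, L, H, D = 1, 3, 2, 4
x = np.arange(N * L * H * D).reshape(N, L, H * D)    # (N, L, E) with E = H * D

# reshape first, then transpose: each token's slice stays together per head
heads = x.reshape(N, L, H, D).transpose(0, 2, 1, 3)  # (N, H, L, D)
print(np.array_equal(heads[0, 0, 1], x[0, 1, :D]))   # True

# reshaping straight to (N, H, L, D), as the old code did, mixes tokens and heads
wrong = x.reshape(N, H, L, D)
print(np.array_equal(wrong[0, 0, 1], x[0, 1, :D]))   # False
```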

src/tiny_llm_week1_ref/multi_head_attention.py

-65
This file was deleted.

tests/test_attention.py

+29 -30
@@ -8,7 +8,7 @@
 
 @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
 @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
-def test_attention_simple(stream: mx.Stream, precision: np.dtype):
+def test_attention_week_1_day_1_task_1(stream: mx.Stream, precision: np.dtype):
     with mx.stream(stream):
         BATCH_SIZE = 3
         DIM_N = 4
@@ -35,18 +35,18 @@ def test_attention_simple(stream: mx.Stream, precision: np.dtype):
 @pytest.mark.parametrize(
     "qkv_shape", [True, False], ids=["with_seq_len", "without_seq_len"]
 )
-def test_attention_with_mask(stream: mx.Stream, precision: np.dtype, qkv_shape: bool):
+def test_attention_with_mask_week_1_day_1_task_1(stream: mx.Stream, precision: np.dtype, qkv_shape: bool):
     with mx.stream(stream):
         BATCH_SIZE = 3
         SEQ_LEN = 10
-        DIM_N = 4
-        DIM_M = 5
+        H = 4
+        D = 5
         if qkv_shape:
-            qkv_shape = (BATCH_SIZE, SEQ_LEN, DIM_N, DIM_M)
-            mask_shape = (BATCH_SIZE, SEQ_LEN, DIM_N, DIM_N)
+            qkv_shape = (BATCH_SIZE, H, SEQ_LEN, D)
+            mask_shape = (BATCH_SIZE, H, SEQ_LEN, SEQ_LEN)
         else:
-            qkv_shape = (BATCH_SIZE, DIM_N, DIM_M)
-            mask_shape = (BATCH_SIZE, DIM_N, DIM_N)
+            qkv_shape = (BATCH_SIZE, H, SEQ_LEN, D)
+            mask_shape = (BATCH_SIZE, H, SEQ_LEN, SEQ_LEN)
         for _ in range(100):
             query = np.random.rand(*qkv_shape).astype(precision)
             key = np.random.rand(*qkv_shape).astype(precision)
@@ -72,33 +72,31 @@ def test_attention_with_mask(stream: mx.Stream, precision: np.dtype, qkv_shape:
 
 @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
 @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
-def test_multi_head_attention(stream: mx.Stream, precision: np.dtype):
+def test_multi_head_attention_week_1_day_1_task_2(stream: mx.Stream, precision: np.dtype):
     with mx.stream(stream):
-        BATCH_SIZE = 7
-        DIM_N = 11
-        DIM_M = 9
-        NUM_HEADS = 3
+        SEQ_LEN = 11
+        D = 9
+        H = 3
+        BATCH_SIZE = 10
         for _ in range(100):
-            query = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
-            key = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
-            value = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
-            q_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
-            k_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
-            v_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
-            out_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
-            mask = np.random.rand(DIM_N * NUM_HEADS, BATCH_SIZE, BATCH_SIZE).astype(
-                precision
-            )
+            query = np.random.rand(BATCH_SIZE, SEQ_LEN, H * D).astype(precision)
+            key = np.random.rand(BATCH_SIZE, SEQ_LEN, H * D).astype(precision)
+            value = np.random.rand(BATCH_SIZE, SEQ_LEN, H * D).astype(precision)
+            q_proj_weight = np.random.rand(H * D, H * D).astype(precision)
+            k_proj_weight = np.random.rand(H * D, H * D).astype(precision)
+            v_proj_weight = np.random.rand(H * D, H * D).astype(precision)
+            out_proj_weight = np.random.rand(H * D, H * D).astype(precision)
+            mask = np.random.rand(SEQ_LEN, SEQ_LEN).astype(precision)
             reference_output, _ = torch.nn.functional.multi_head_attention_forward(
-                torch.tensor(query, device=TORCH_DEVICE),
-                torch.tensor(key, device=TORCH_DEVICE),
-                torch.tensor(value, device=TORCH_DEVICE),
-                num_heads=NUM_HEADS,
+                torch.tensor(query, device=TORCH_DEVICE).transpose(0, 1),
+                torch.tensor(key, device=TORCH_DEVICE).transpose(0, 1),
+                torch.tensor(value, device=TORCH_DEVICE).transpose(0, 1),
+                num_heads=H,
                 q_proj_weight=torch.tensor(q_proj_weight, device=TORCH_DEVICE),
                 k_proj_weight=torch.tensor(k_proj_weight, device=TORCH_DEVICE),
                 v_proj_weight=torch.tensor(v_proj_weight, device=TORCH_DEVICE),
                 out_proj_weight=torch.tensor(out_proj_weight, device=TORCH_DEVICE),
-                embed_dim_to_check=DIM_M,
+                embed_dim_to_check=H * D,
                 in_proj_weight=None,
                 in_proj_bias=None,
                 bias_k=None,
@@ -109,9 +107,10 @@ def test_multi_head_attention(stream: mx.Stream, precision: np.dtype):
                 use_separate_proj_weight=True,
                 attn_mask=torch.tensor(mask, device=TORCH_DEVICE),
            )
+            reference_output = reference_output.transpose(0, 1)
             user_output = MultiHeadAttention(
-                DIM_M,
-                NUM_HEADS,
+                H * D,
+                H,
                 mx.array(q_proj_weight),
                 mx.array(k_proj_weight),
                 mx.array(v_proj_weight),