Skip to content

Commit 136ad7f

Browse files
authored
Day 6, task 1 tests - RoPE with multiple offsets (#68)
This test requires the latest version of mlx 0.29.1, since they just merged support for this in mlx a week ago: ml-explore/mlx#2564 I verified that the other tests still pass with the version upgrade.
1 parent 308388e commit 136ad7f

File tree

5 files changed

+92
-41
lines changed

5 files changed

+92
-41
lines changed

book/src/week2-06-prefill-and-batch.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,13 @@ src/tiny_llm/positional_encoding.py
5050
src/tiny_llm/attention.py::causal_mask
5151
```
5252

53-
Ensure your RoPE implementation accepts a list of offsets. Also, make sure your mask implementation correctly handles the case where `L != S`.
53+
Ensure your RoPE implementation accepts a `list[slice]` of offsets (one slice per sequence in the batch). Also, make sure your mask implementation correctly handles the case where `L != S`.
54+
55+
You can verify multi-offset RoPE, and that masking works for attention and flash attention with:
56+
57+
```bash
58+
pdm run test --week 2 --day 6 -- -k task_1
59+
```
5460

5561
## Task 2: Batch KV Cache
5662

pdm.lock

Lines changed: 23 additions & 23 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ version = "0.1.0"
88
requires-python = ">=3.10, <3.13"
99
readme = "README.md"
1010
dependencies = [
11-
"mlx>=0.27.0",
11+
"mlx>=0.29.1",
1212
"torch>=2.6.0",
1313
"torchtune>=0.6.1",
1414
"torchao>=0.10.0",
15-
"mlx-lm>=0.26.0",
15+
"mlx-lm>=0.27.1",
1616
"numpy>=2.2.4",
1717
"pytest>=8.3.5",
1818
"ruff>=0.11.6",

tests_refsol/test_week_2_day_6.py

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,54 @@
1-
import pytest
21
import mlx.core as mx
2+
import numpy as np
3+
import pytest
34
from .tiny_llm_base import *
45
from .utils import *
56

67

8+
def rope_helper(stream: mx.Stream, traditional: bool, precision: mx.Dtype):
9+
BATCH_SIZE = 16
10+
NUM_HEADS = 8
11+
HEAD_DIM = 4
12+
MAX_SEQ_LEN = 14
13+
SEQ_LEN = 9
14+
BASE = 10000
15+
with mx.stream(stream):
16+
for _ in range(100):
17+
user_layer = RoPE(HEAD_DIM, MAX_SEQ_LEN, BASE, traditional=traditional)
18+
x = mx.random.uniform(
19+
shape=(BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM), dtype=precision
20+
)
21+
22+
input_pos = np.random.randint(0, MAX_SEQ_LEN - SEQ_LEN, size=BATCH_SIZE)
23+
input_pos_mx = mx.array(input_pos, dtype=mx.int32)
24+
input_pos_user = [slice(i, i + SEQ_LEN) for i in input_pos]
25+
26+
reference_output = mx.fast.rope(
27+
x.transpose(0, 2, 1, 3),
28+
dims=HEAD_DIM,
29+
traditional=traditional,
30+
base=BASE,
31+
scale=1.0,
32+
offset=input_pos_mx,
33+
).transpose(0, 2, 1, 3)
34+
user_output = user_layer(x, input_pos_user)
35+
assert_allclose(
36+
user_output,
37+
reference_output,
38+
precision,
39+
atol=5e-6 if precision == mx.float32 else 1e-3,
40+
)
41+
42+
43+
@pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
44+
@pytest.mark.parametrize("traditional", [False, True], ids=["default", "traditional"])
45+
@pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
46+
def test_task_1_rope_multiple_offsets(
47+
stream: mx.Stream, traditional: bool, precision: mx.Dtype
48+
):
49+
rope_helper(stream, traditional, precision)
50+
51+
752
def attention_helper(
853
stream: mx.Stream, H_q, H, L, E, S, BATCH, use_flash_attention: bool = False
954
):
@@ -75,57 +120,57 @@ def attention_helper(
75120
)
76121

77122

78-
def test_flash_attention_with_mask_cpu_small():
123+
def test_task_1_flash_attention_with_mask_cpu_small():
79124
attention_helper(mx.cpu, 6, 3, 2, 5, 3, 1, use_flash_attention=True)
80125

81126

82-
def test_flash_attention_with_mask_cpu():
127+
def test_task_1_flash_attention_with_mask_cpu():
83128
attention_helper(mx.cpu, 18, 6, 7, 5, 3, 10, use_flash_attention=True)
84129

85130

86-
def test_flash_attention_with_mask_cpu_large():
131+
def test_task_1_flash_attention_with_mask_cpu_large():
87132
attention_helper(mx.cpu, 28, 4, 16, 128, 16, 3, use_flash_attention=True)
88133

89134

90-
def test_flash_attention_with_mask_gpu_extra_small():
135+
def test_task_1_flash_attention_with_mask_gpu_extra_small():
91136
attention_helper(mx.gpu, 1, 1, 5, 7, 4, 1, use_flash_attention=True)
92137

93138

94-
def test_flash_attention_with_mask_gpu_small():
139+
def test_task_1_flash_attention_with_mask_gpu_small():
95140
attention_helper(mx.gpu, 6, 3, 2, 5, 3, 1, use_flash_attention=True)
96141

97142

98-
def test_flash_attention_with_mask_gpu():
143+
def test_task_1_flash_attention_with_mask_gpu():
99144
attention_helper(mx.gpu, 18, 6, 7, 5, 3, 10, use_flash_attention=True)
100145

101146

102-
def test_flash_attention_with_mask_gpu_large():
147+
def test_task_1_flash_attention_with_mask_gpu_large():
103148
attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, use_flash_attention=True)
104149

105150

106-
def test_attention_with_mask_cpu_small():
151+
def test_task_1_attention_with_mask_cpu_small():
107152
attention_helper(mx.cpu, 6, 3, 2, 5, 3, 1, use_flash_attention=False)
108153

109154

110-
def test_attention_with_mask_cpu():
155+
def test_task_1_attention_with_mask_cpu():
111156
attention_helper(mx.cpu, 18, 6, 7, 5, 3, 10, use_flash_attention=False)
112157

113158

114-
def test_attention_with_mask_cpu_large():
159+
def test_task_1_attention_with_mask_cpu_large():
115160
attention_helper(mx.cpu, 28, 4, 16, 128, 16, 3, use_flash_attention=False)
116161

117162

118-
def test_attention_with_mask_gpu_extra_small():
163+
def test_task_1_attention_with_mask_gpu_extra_small():
119164
attention_helper(mx.gpu, 1, 1, 5, 7, 4, 1, use_flash_attention=False)
120165

121166

122-
def test_attention_with_mask_gpu_small():
167+
def test_task_1_attention_with_mask_gpu_small():
123168
attention_helper(mx.gpu, 6, 3, 2, 5, 3, 1, use_flash_attention=False)
124169

125170

126-
def test_attention_with_mask_gpu():
171+
def test_task_1_attention_with_mask_gpu():
127172
attention_helper(mx.gpu, 18, 6, 7, 5, 3, 10, use_flash_attention=False)
128173

129174

130-
def test_attention_with_mask_gpu_large():
175+
def test_task_1_attention_with_mask_gpu_large():
131176
attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, use_flash_attention=False)

0 commit comments

Comments
 (0)