Skip to content

Commit 3acf26c

Browse files
committed
Add tests for week 2, day 6 - continuous batching
1 parent 136ad7f commit 3acf26c

File tree

5 files changed

+107
-17
lines changed

5 files changed

+107
-17
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ Week 1 is complete. Week 2 is in progress.
4242
| 2.3 | Quantized Matmul and Linear - GPU ||| 🚧 |
4343
| 2.4 | Flash Attention 2 - CPU ||| 🚧 |
4444
| 2.5 | Flash Attention 2 - GPU ||| 🚧 |
45-
| 2.6 | Continuous Batching || 🚧 ||
46-
| 2.7 | Chunked Prefill || 🚧 ||
45+
| 2.6 | Continuous Batching || ||
46+
| 2.7 | Chunked Prefill || ||
4747
| 3.1 | Paged Attention - Part 1 | 🚧 | 🚧 | 🚧 |
4848
| 3.2 | Paged Attention - Part 2 | 🚧 | 🚧 | 🚧 |
4949
| 3.3 | MoE (Mixture of Experts) | 🚧 | 🚧 | 🚧 |
5050
| 3.4 | Speculative Decoding | 🚧 | 🚧 | 🚧 |
5151
| 3.5 | RAG Pipeline | 🚧 | 🚧 | 🚧 |
5252
| 3.6 | AI Agent / Tool Calling | 🚧 | 🚧 | 🚧 |
53-
| 3.7 | Long Context | 🚧 | 🚧 | 🚧 |
53+
| 3.7 | Long Context | 🚧 | 🚧 | 🚧 |
5454

5555
Other topics not covered: quantized/compressed KV cache, prefix/prompt cache; sampling, fine-tuning; smaller kernels (softmax, SiLU, etc.)

book/src/week2-06-prefill-and-batch.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ src/tiny_llm/qwen2_week2.py
9292

9393
Ensure your model can handle multiple requests simultaneously. You should also use the masks returned by the batch KV cache.
9494

95+
You should pass all of the tests by running:
96+
97+
```bash
98+
pdm run test --week 2 --day 6 -- -k task_3
99+
```
100+
95101
## Task 4: Batch Generate
96102

97103
```

src/tiny_llm/kv_cache.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from abc import ABC, abstractmethod
12
from typing import Optional
23

34
import mlx.core as mx
45

56

6-
class TinyKvCache:
7+
class TinyKvCache(ABC):
8+
@abstractmethod
79
def update_and_fetch(
810
self,
911
key: mx.array,
@@ -26,7 +28,6 @@ def update_and_fetch(
2628
In week 2 day 6/7, we need to return the updated key-value cache, the updated value, the sequence length, and the mask.
2729
so that the batching kv cache can use this information to generate the mask.
2830
"""
29-
pass
3031

3132

3233
class BatchingKvCache(TinyKvCache):

src/tiny_llm_ref/kv_cache.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from abc import ABC, abstractmethod
12
from typing import Optional
23

34
from .attention import causal_mask
45
import mlx.core as mx
56

67

7-
class TinyKvCache:
8+
class TinyKvCache(ABC):
9+
@abstractmethod
810
def update_and_fetch(
911
self,
1012
key: mx.array,
@@ -24,7 +26,6 @@ def update_and_fetch(
2426
Returns:
2527
A tuple of the updated key-value cache, the updated value, the sequence length, and the mask.
2628
"""
27-
pass
2829

2930

3031
class BatchingKvCache(TinyKvCache):
@@ -44,7 +45,10 @@ def update_and_fetch(
4445
B, H, S, D = keys.shape
4546
assert keys.shape == values.shape
4647
assert S <= self.max_seq_len
47-
assert self.HD == (H, D), f"expect {self.HD} but got {H, D}"
48+
if self.HD is None:
49+
self.HD = (H, D)
50+
else:
51+
assert self.HD == (H, D), f"expect {self.HD} but got {H, D}"
4852
assert B == self.max_active_requests
4953
# Step 1: append the result to the cache
5054
data = []
@@ -88,19 +92,20 @@ def get_seq_len(data):
8892
elif isinstance(mask, mx.array):
8993
masks[b, :, seq_len - S : seq_len] = mask
9094
else:
91-
raise NotImplemented
95+
raise NotImplementedError
9296
return keys, values, None, masks.reshape(B, 1, mask_length, seq_len)
9397

9498
def add_request(self, prefilled: TinyKvCache, id: int):
9599
if id >= self.max_active_requests:
96100
raise ValueError(f"Request id {id} is out of range")
97-
keys, _ = prefilled.key_values
98-
B, H, _, D = keys.shape
99-
assert B == 1
100-
if self.HD is None:
101-
self.HD = (H, D)
102-
else:
103-
assert self.HD == (H, D)
101+
if getattr(prefilled, "key_values", None) is not None:
102+
keys, _ = prefilled.key_values
103+
B, H, _, D = keys.shape
104+
assert B == 1
105+
if self.HD is None:
106+
self.HD = (H, D)
107+
else:
108+
assert self.HD == (H, D)
104109
self.kv_caches[id] = prefilled
105110

106111
def remove_request(self, id: int):
@@ -126,7 +131,7 @@ def update_and_fetch(
126131
self.key_values = (key, value)
127132
B, H, S, D = key.shape
128133
self.offset = S
129-
return key, value, 0, mask
134+
return key, value, self.offset, mask
130135
else:
131136
B, H, S, D = key.shape
132137
assert key.shape == value.shape

tests_refsol/test_week_2_day_6.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import mlx.core as mx
22
import numpy as np
33
import pytest
4+
from mlx_lm import load
5+
46
from .tiny_llm_base import *
57
from .utils import *
68

@@ -174,3 +176,79 @@ def test_task_1_attention_with_mask_gpu():
174176

175177
def test_task_1_attention_with_mask_gpu_large():
176178
attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, use_flash_attention=False)
179+
180+
181+
def helper_test_task_3(model_name: str, seq_len: int, iters: int = 1):
182+
"""Tests for continuous batching of decode requests."""
183+
requests = 4
184+
max_seq_len = seq_len
185+
186+
mlx_model, tokenizer = load(model_name)
187+
model = Qwen2ModelWeek2(mlx_model)
188+
for _ in range(iters):
189+
cache = [
190+
BatchingKvCache(requests, max_seq_len)
191+
for _ in range(model.num_hidden_layers)
192+
]
193+
# Start each request at a staggered token index.
194+
staggered_start = [seq_len * i // requests for i in range(requests)]
195+
inputs = mx.random.randint(0, tokenizer.vocab_size, (requests, seq_len))
196+
ref_outputs = mlx_model(inputs)
197+
for offset in range(seq_len + staggered_start[-1]):
198+
seq_idx = [offset - start for start in staggered_start]
199+
200+
# Requests join at the staggered start, and leave when they reach seq_len.
201+
for request_id, sidx in enumerate(seq_idx):
202+
if sidx == 0:
203+
for c in cache:
204+
c.add_request(TinyKvFullCache(), request_id)
205+
elif sidx == seq_len:
206+
for c in cache:
207+
c.remove_request(request_id)
208+
209+
next_tokens = []
210+
next_offsets = []
211+
for request_id, sidx in enumerate(seq_idx):
212+
if 0 <= sidx < seq_len:
213+
next_tokens.append(inputs[request_id, sidx].item())
214+
next_offsets.append(sidx)
215+
else:
216+
next_tokens.append(0)
217+
next_offsets.append(0)
218+
219+
user_out = model(
220+
inputs=mx.array(next_tokens, dtype=mx.int32).reshape(-1, 1),
221+
offset=mx.array(next_offsets, dtype=mx.int32),
222+
cache=cache,
223+
)
224+
225+
for request_id, sidx in enumerate(seq_idx):
226+
if 0 <= sidx < seq_len:
227+
user_out_r = user_out[request_id, 0, :]
228+
ref_out_r = ref_outputs[request_id, sidx, :]
229+
user_out_r = user_out_r - mx.logsumexp(user_out_r, keepdims=True)
230+
ref_out_r = ref_out_r - mx.logsumexp(ref_out_r, keepdims=True)
231+
assert_allclose(
232+
user_out_r, ref_out_r, precision=mx.float16, rtol=1e-1
233+
)
234+
235+
236+
@pytest.mark.skipif(
237+
not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found"
238+
)
239+
def test_task_3_qwen_2_05b():
240+
helper_test_task_3("Qwen/Qwen2-0.5B-Instruct-MLX", seq_len=3)
241+
242+
243+
@pytest.mark.skipif(
244+
not qwen_2_7b_model_exists(), reason="Qwen2-7B-Instruct-MLX model not found"
245+
)
246+
def test_task_3_qwen_2_7b():
247+
helper_test_task_3("Qwen/Qwen2-7B-Instruct-MLX", seq_len=3)
248+
249+
250+
@pytest.mark.skipif(
251+
not qwen_2_15b_model_exists(), reason="Qwen2-1.5B-Instruct-MLX model not found"
252+
)
253+
def test_task_3_qwen_2_15b():
254+
helper_test_task_3("Qwen/Qwen2-1.5B-Instruct-MLX", seq_len=3)

0 commit comments

Comments
 (0)