|
1 | 1 | import mlx.core as mx |
2 | 2 | import numpy as np |
3 | 3 | import pytest |
| 4 | +from mlx_lm import load |
| 5 | + |
4 | 6 | from .tiny_llm_base import * |
5 | 7 | from .utils import * |
6 | 8 |
|
@@ -174,3 +176,79 @@ def test_task_1_attention_with_mask_gpu(): |
174 | 176 |
|
def test_task_1_attention_with_mask_gpu_large():
    """Large-shape variant of the masked-attention GPU test.

    Delegates to ``attention_helper`` with flash attention disabled; the
    positional arguments are presumably layer/head/sequence dimensions —
    TODO confirm against ``attention_helper``'s signature in utils.
    """
    attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, use_flash_attention=False)
| 179 | + |
| 180 | + |
def helper_test_task_3(model_name: str, seq_len: int, iters: int = 1):
    """Exercise continuous batching of decode requests.

    Four requests are decoded together. Each request joins the batch at a
    staggered token position and leaves once it has consumed ``seq_len``
    tokens. At every step, the batched model's logits for each active slot
    are normalized to log-probabilities and compared against a single full
    forward pass of the reference MLX model.
    """
    num_slots = 4  # batch capacity == number of concurrent requests
    mlx_model, tokenizer = load(model_name)
    model = Qwen2ModelWeek2(mlx_model)
    for _ in range(iters):
        # One batching KV cache per transformer layer.
        kv_caches = [
            BatchingKvCache(num_slots, seq_len)
            for _ in range(model.num_hidden_layers)
        ]
        # Request i joins once the step counter reaches seq_len * i / num_slots.
        starts = [seq_len * i // num_slots for i in range(num_slots)]
        inputs = mx.random.randint(0, tokenizer.vocab_size, (num_slots, seq_len))
        ref_outputs = mlx_model(inputs)
        # The last request begins at starts[-1] and needs seq_len more steps.
        for clock in range(seq_len + starts[-1]):
            positions = [clock - start for start in starts]

            # Admit a request on the step it starts; evict it on the step
            # it has produced all seq_len tokens.
            for rid, pos in enumerate(positions):
                if pos == 0:
                    for kv in kv_caches:
                        kv.add_request(TinyKvFullCache(), rid)
                elif pos == seq_len:
                    for kv in kv_caches:
                        kv.remove_request(rid)

            # Inactive slots are padded with token 0 at offset 0.
            step = [
                (inputs[rid, pos].item(), pos) if 0 <= pos < seq_len else (0, 0)
                for rid, pos in enumerate(positions)
            ]
            tokens = [tok for tok, _ in step]
            offsets = [off for _, off in step]

            user_out = model(
                inputs=mx.array(tokens, dtype=mx.int32).reshape(-1, 1),
                offset=mx.array(offsets, dtype=mx.int32),
                cache=kv_caches,
            )

            # Compare normalized log-probabilities for every active slot.
            for rid, pos in enumerate(positions):
                if not 0 <= pos < seq_len:
                    continue
                got = user_out[rid, 0, :]
                want = ref_outputs[rid, pos, :]
                got = got - mx.logsumexp(got, keepdims=True)
                want = want - mx.logsumexp(want, keepdims=True)
                assert_allclose(got, want, precision=mx.float16, rtol=1e-1)
| 234 | + |
| 235 | + |
@pytest.mark.skipif(
    not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found"
)
def test_task_3_qwen_2_05b():
    """Continuous-batching decode test against Qwen2-0.5B-Instruct-MLX."""
    helper_test_task_3("Qwen/Qwen2-0.5B-Instruct-MLX", seq_len=3)
| 241 | + |
| 242 | + |
@pytest.mark.skipif(
    not qwen_2_7b_model_exists(), reason="Qwen2-7B-Instruct-MLX model not found"
)
def test_task_3_qwen_2_7b():
    """Continuous-batching decode test against Qwen2-7B-Instruct-MLX."""
    helper_test_task_3("Qwen/Qwen2-7B-Instruct-MLX", seq_len=3)
| 248 | + |
| 249 | + |
@pytest.mark.skipif(
    not qwen_2_15b_model_exists(), reason="Qwen2-1.5B-Instruct-MLX model not found"
)
def test_task_3_qwen_2_15b():
    """Continuous-batching decode test against Qwen2-1.5B-Instruct-MLX."""
    helper_test_task_3("Qwen/Qwen2-1.5B-Instruct-MLX", seq_len=3)
0 commit comments