|
1 | | -import pytest |
2 | 1 | import mlx.core as mx |
| 2 | +import numpy as np |
| 3 | +import pytest |
3 | 4 | from .tiny_llm_base import * |
4 | 5 | from .utils import * |
5 | 6 |
|
6 | 7 |
|
def rope_helper(stream: mx.Stream, traditional: bool, precision: mx.Dtype) -> None:
    """Compare the user RoPE layer against ``mx.fast.rope`` with per-batch offsets.

    Runs 100 randomized trials on the given stream. Each trial draws a random
    (batch, seq, heads, head_dim) input and a distinct random start position
    per batch entry, then checks that the user layer (which receives the
    positions as slices) matches the reference (which receives them as an
    int32 offset array).

    Args:
        stream: MLX stream (CPU or GPU) the computation runs under.
        traditional: whether to use the "traditional" RoPE rotation layout.
        precision: input dtype; the tolerance is loosened for non-float32.
    """
    BATCH_SIZE = 16
    NUM_HEADS = 8
    HEAD_DIM = 4
    MAX_SEQ_LEN = 14
    SEQ_LEN = 9
    BASE = 10000
    with mx.stream(stream):
        # The layer's constructor arguments are loop-invariant and its
        # construction is deterministic, so build it once instead of
        # re-creating it on every one of the 100 trials.
        user_layer = RoPE(HEAD_DIM, MAX_SEQ_LEN, BASE, traditional=traditional)
        for _ in range(100):
            x = mx.random.uniform(
                shape=(BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM), dtype=precision
            )

            # One random start offset per batch entry; the exclusive high
            # bound keeps every window [i, i + SEQ_LEN) inside MAX_SEQ_LEN.
            input_pos = np.random.randint(0, MAX_SEQ_LEN - SEQ_LEN, size=BATCH_SIZE)
            input_pos_mx = mx.array(input_pos, dtype=mx.int32)
            input_pos_user = [slice(i, i + SEQ_LEN) for i in input_pos]

            # mx.fast.rope expects (batch, heads, seq, head_dim), hence the
            # transpose going in and coming back out.
            reference_output = mx.fast.rope(
                x.transpose(0, 2, 1, 3),
                dims=HEAD_DIM,
                traditional=traditional,
                base=BASE,
                scale=1.0,
                offset=input_pos_mx,
            ).transpose(0, 2, 1, 3)
            user_output = user_layer(x, input_pos_user)
            assert_allclose(
                user_output,
                reference_output,
                precision,
                atol=5e-6 if precision == mx.float32 else 1e-3,
            )
| 41 | + |
| 42 | + |
@pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
@pytest.mark.parametrize("traditional", [False, True], ids=["default", "traditional"])
@pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
def test_task_1_rope_multiple_offsets(
    stream: mx.Stream, traditional: bool, precision: mx.Dtype
):
    """RoPE with a distinct position offset per batch entry, over all streams, layouts, and dtypes."""
    rope_helper(stream=stream, traditional=traditional, precision=precision)
| 50 | + |
| 51 | + |
7 | 52 | def attention_helper( |
8 | 53 | stream: mx.Stream, H_q, H, L, E, S, BATCH, use_flash_attention: bool = False |
9 | 54 | ): |
@@ -75,57 +120,57 @@ def attention_helper( |
75 | 120 | ) |
76 | 121 |
|
77 | 122 |
|
def test_task_1_flash_attention_with_mask_cpu_small():
    """attention_helper on CPU, small shapes, flash attention enabled."""
    attention_helper(
        mx.cpu, H_q=6, H=3, L=2, E=5, S=3, BATCH=1, use_flash_attention=True
    )
80 | 125 |
|
81 | 126 |
|
def test_task_1_flash_attention_with_mask_cpu():
    """attention_helper on CPU, medium shapes, flash attention enabled."""
    attention_helper(
        mx.cpu, H_q=18, H=6, L=7, E=5, S=3, BATCH=10, use_flash_attention=True
    )
84 | 129 |
|
85 | 130 |
|
def test_task_1_flash_attention_with_mask_cpu_large():
    """attention_helper on CPU, large shapes, flash attention enabled."""
    attention_helper(
        mx.cpu, H_q=28, H=4, L=16, E=128, S=16, BATCH=3, use_flash_attention=True
    )
88 | 133 |
|
89 | 134 |
|
def test_task_1_flash_attention_with_mask_gpu_extra_small():
    """attention_helper on GPU, single-head minimal shapes, flash attention enabled."""
    attention_helper(
        mx.gpu, H_q=1, H=1, L=5, E=7, S=4, BATCH=1, use_flash_attention=True
    )
92 | 137 |
|
93 | 138 |
|
def test_task_1_flash_attention_with_mask_gpu_small():
    """attention_helper on GPU, small shapes, flash attention enabled."""
    attention_helper(
        mx.gpu, H_q=6, H=3, L=2, E=5, S=3, BATCH=1, use_flash_attention=True
    )
96 | 141 |
|
97 | 142 |
|
def test_task_1_flash_attention_with_mask_gpu():
    """attention_helper on GPU, medium shapes, flash attention enabled."""
    attention_helper(
        mx.gpu, H_q=18, H=6, L=7, E=5, S=3, BATCH=10, use_flash_attention=True
    )
100 | 145 |
|
101 | 146 |
|
def test_task_1_flash_attention_with_mask_gpu_large():
    """attention_helper on GPU, large shapes, flash attention enabled."""
    attention_helper(
        mx.gpu, H_q=28, H=4, L=16, E=128, S=16, BATCH=3, use_flash_attention=True
    )
104 | 149 |
|
105 | 150 |
|
def test_task_1_attention_with_mask_cpu_small():
    """attention_helper on CPU, small shapes, flash attention disabled."""
    attention_helper(
        mx.cpu, H_q=6, H=3, L=2, E=5, S=3, BATCH=1, use_flash_attention=False
    )
108 | 153 |
|
109 | 154 |
|
def test_task_1_attention_with_mask_cpu():
    """attention_helper on CPU, medium shapes, flash attention disabled."""
    attention_helper(
        mx.cpu, H_q=18, H=6, L=7, E=5, S=3, BATCH=10, use_flash_attention=False
    )
112 | 157 |
|
113 | 158 |
|
def test_task_1_attention_with_mask_cpu_large():
    """attention_helper on CPU, large shapes, flash attention disabled."""
    attention_helper(
        mx.cpu, H_q=28, H=4, L=16, E=128, S=16, BATCH=3, use_flash_attention=False
    )
116 | 161 |
|
117 | 162 |
|
def test_task_1_attention_with_mask_gpu_extra_small():
    """attention_helper on GPU, single-head minimal shapes, flash attention disabled."""
    attention_helper(
        mx.gpu, H_q=1, H=1, L=5, E=7, S=4, BATCH=1, use_flash_attention=False
    )
120 | 165 |
|
121 | 166 |
|
def test_task_1_attention_with_mask_gpu_small():
    """attention_helper on GPU, small shapes, flash attention disabled."""
    attention_helper(
        mx.gpu, H_q=6, H=3, L=2, E=5, S=3, BATCH=1, use_flash_attention=False
    )
124 | 169 |
|
125 | 170 |
|
def test_task_1_attention_with_mask_gpu():
    """attention_helper on GPU, medium shapes, flash attention disabled."""
    attention_helper(
        mx.gpu, H_q=18, H=6, L=7, E=5, S=3, BATCH=10, use_flash_attention=False
    )
128 | 173 |
|
129 | 174 |
|
def test_task_1_attention_with_mask_gpu_large():
    """attention_helper on GPU, large shapes, flash attention disabled."""
    attention_helper(
        mx.gpu, H_q=28, H=4, L=16, E=128, S=16, BATCH=3, use_flash_attention=False
    )
0 commit comments