Commit ed2cefa
[varlen Kernel] Extend paged attention v2 to varlen [4/n] (#166)
## Summary
- Add `find_seq_idx` binary search to the v2 Metal kernel so each
threadgroup discovers its sequence from a flat `cu_seqlens_q` array,
enabling variable-length queries (prefill + decode in one launch)
- **This PR has no effect in production yet.** Production still uses
`mx.sdpa` for prefill and the v2 kernel from this PR for decode. With
some parameters frozen, kernels_v2 behaves identically to the previous
v1.
- All triangle tests pass, so it is safe to move on to stage 3
(continuous batching).
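The sequence lookup described in the first bullet can be sketched in Python. This is a simplified illustration of the idea, not the Metal kernel itself: it assumes `cu_seqlens_q` holds cumulative query lengths with a leading 0, and that the lookup key is a flat token index (the real kernel may search on a block index instead).

```python
def find_seq_idx(cu_seqlens_q, token_idx):
    """Return the index of the sequence that owns flat token `token_idx`.

    Assumption: cu_seqlens_q = [0, len0, len0 + len1, ...], so sequence i
    owns tokens in the half-open range [cu_seqlens_q[i], cu_seqlens_q[i + 1]).
    """
    num_seqs = len(cu_seqlens_q) - 1
    left, right = 0, num_seqs
    while left < right:
        mid = (left + right) // 2
        if cu_seqlens_q[mid] <= token_idx:
            left = mid + 1   # sequence `mid` starts at or before the token
        else:
            right = mid      # sequence `mid` starts after the token
    return left - 1          # last sequence whose start is <= token_idx


# Four sequences with query lengths 5, 1, 1, 3.
cu_seqlens_q = [0, 5, 6, 7, 10]
owners = [find_seq_idx(cu_seqlens_q, t) for t in range(10)]
print(owners)
```

Each lookup takes O(log num_seqs) steps and needs only the `num_seqs + 1` cumulative offsets as input.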
Notes:
- Vendor features from upstream vLLM: add sliding window support and
soft capping to the v2 kernel.
- Update the production decode path to match the new function signature
(default params, no behavior change).
- **These features are NOT TESTED in end-to-end production usage. They
are expected to be tied to specific models, such as early versions of
Mistral models.**
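For readers unfamiliar with the two vendored features, here is a minimal scalar sketch. Soft capping is shown in its usual `cap * tanh(score / cap)` form. The sliding-window convention (a query sees itself plus the previous `window - 1` keys) is an assumption for illustration, and the function names are hypothetical rather than taken from the kernel.

```python
import math


def soft_cap(score, cap):
    """Squash a raw attention logit smoothly into the range [-cap, cap]."""
    return cap * math.tanh(score / cap)


def in_sliding_window(q_pos, k_pos, window):
    """Causal sliding-window mask (assumed convention).

    A query at absolute position q_pos may attend to key k_pos only if the
    key is not in the future and lies within the last `window` positions.
    """
    return 0 <= q_pos - k_pos < window


# Values matching the test parameters listed in this PR: cap 50.0, window 128.
print(soft_cap(1.0, 50.0))              # small logits are almost unchanged
print(soft_cap(1000.0, 50.0))           # large logits saturate near the cap
print(in_sliding_window(200, 73, 128))  # distance 127, inside the window
print(in_sliding_window(200, 72, 128))  # distance 128, outside the window
```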
## Triangle Test Status
```
        ref (pure-MLX naive)
         /            \
      edge 1        edge 3
       /                \
     v1 ── edge 2 ── v2
```
- **Edge 1** (v1 == ref): 6 pass (unchanged)
- **Edge 2** (v2 == v1): 6 pass (unchanged)
- **Edge 3** (v2 == ref): **24 pass** (was 3 pass + 21 xfail)
  - varlen (q_len > 1): now passing
  - sliding window (128): now passing
  - soft capping (50.0): now passing
**Before:** 15 passed + 21 xfail → **After:** 36 passed + 0 xfail
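The triangle structure amounts to three pairwise comparisons over the same inputs. The sketch below shows that shape on plain lists of floats. The helper names and the tolerance are hypothetical, and the actual tests compare MLX arrays rather than Python lists.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two flat outputs."""
    return max(abs(x - y) for x, y in zip(a, b))


def triangle_check(ref_out, v1_out, v2_out, atol=1e-3):
    """Compare three implementations pairwise; atol is an assumed tolerance."""
    return {
        "edge1_v1_vs_ref": max_abs_diff(v1_out, ref_out) <= atol,
        "edge2_v2_vs_v1": max_abs_diff(v2_out, v1_out) <= atol,
        "edge3_v2_vs_ref": max_abs_diff(v2_out, ref_out) <= atol,
    }


# Made-up outputs purely to exercise the helper.
result = triangle_check([1.0, 2.0], [1.0, 2.0], [1.0, 2.0005])
print(result)
```

Checking all three edges helps localize a failure: if edge 1 passes while edges 2 and 3 fail, the regression is in v2 rather than in the reference.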
## What's NOT in this PR
The kernel now supports unified prefill+decode, but production still
uses the split path (MLX SDPA for prefill, v2 kernel for decode). Wiring
`metal_unified_attention()` into `model_runner.py` is a follow-up.
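To make the unified path concrete, the following sketch shows how a mixed batch could be flattened into a `cu_seqlens_q` array. It assumes prefill requests contribute their full prompt length and decode requests contribute a single query token. `build_cu_seqlens_q` is a hypothetical helper, not a function from this repository.

```python
from itertools import accumulate


def build_cu_seqlens_q(query_lens):
    """Cumulative query lengths with a leading 0 (assumed layout)."""
    return [0, *accumulate(query_lens)]


# One prefill of 5 tokens, two decodes (1 token each), one prefill of 3 tokens.
query_lens = [5, 1, 1, 3]
cu_seqlens_q = build_cu_seqlens_q(query_lens)
print(cu_seqlens_q)  # [0, 5, 6, 7, 10]
```

With this layout, sequence `i` owns flat query tokens `cu_seqlens_q[i]` up to but not including `cu_seqlens_q[i + 1]`, so prefill and decode requests can share one launch.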
## Numeric Stability
```
python -m pytest tests/test_paged_deterministic.py -v -s
```
* Before this PR: 5/6 match mlx_lm path
* After this PR: 6/6 match mlx_lm path
For now I am leaving the test unchanged, because its result is expected
to flip back and forth as the follow-up PRs land.
## Benchmark
Run with the same benchmark script as #136.
<details>
<summary>This PR: </summary>
```
============ Serving Benchmark Result ============
Successful requests: 100
Failed requests: 0
Maximum request concurrency: 32
Request rate configured (RPS): 10.00
Benchmark duration (s): 107.24
Total input tokens: 23260
Total generated tokens: 22061
Request throughput (req/s): 0.93
Output token throughput (tok/s): 205.71
Peak output token throughput (tok/s): 319.00
Peak concurrent requests: 35.00
Total token throughput (tok/s): 422.60
---------------Time to First Token----------------
Mean TTFT (ms): 593.33
Median TTFT (ms): 386.44
P99 TTFT (ms): 2147.57
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 134.30
Median TPOT (ms): 127.79
P99 TPOT (ms): 477.35
---------------Inter-token Latency----------------
Mean ITL (ms): 117.58
Median ITL (ms): 104.05
P99 ITL (ms): 473.33
==================================================
```
</details>
<details>
<summary>before this PR:</summary>
```
============ Serving Benchmark Result ============
Successful requests: 100
Failed requests: 0
Maximum request concurrency: 32
Request rate configured (RPS): 10.00
Benchmark duration (s): 106.42
Total input tokens: 23260
Total generated tokens: 22061
Request throughput (req/s): 0.94
Output token throughput (tok/s): 207.30
Peak output token throughput (tok/s): 320.00
Peak concurrent requests: 35.00
Total token throughput (tok/s): 425.87
---------------Time to First Token----------------
Mean TTFT (ms): 982.74
Median TTFT (ms): 452.35
P99 TTFT (ms): 3030.71
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 132.63
Median TPOT (ms): 124.33
P99 TPOT (ms): 442.78
---------------Inter-token Latency----------------
Mean ITL (ms): 115.19
Median ITL (ms): 101.69
P99 ITL (ms): 440.37
==================================================
```
</details>
This PR has no material effect on throughput (205.71 vs. 207.30 output
tok/s). It paves the way for continuous batching.
## Possible Limitation
* The binary search is translated from the Triton kernel, but it may not
be necessary. Triton uses it to avoid a CPU-GPU data copy, whereas we are
on unified memory, so we could prebuild the reverse map instead. For the
data sizes involved, though, O(log(n)) is effectively the same as O(1),
and the binary search takes less space.
* Did not check behavior with partitioning on versus off.
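The reverse-map alternative mentioned in the first bullet could look like the sketch below. It is an illustration only: `build_token_to_seq` is a hypothetical helper, and the layout of `cu_seqlens_q` (cumulative lengths with a leading 0) is assumed.

```python
def build_token_to_seq(cu_seqlens_q):
    """Precompute the owning sequence index for every flat query token."""
    token_to_seq = []
    for seq_idx in range(len(cu_seqlens_q) - 1):
        q_len = cu_seqlens_q[seq_idx + 1] - cu_seqlens_q[seq_idx]
        token_to_seq.extend([seq_idx] * q_len)
    return token_to_seq


cu_seqlens_q = [0, 5, 6, 7, 10]
token_to_seq = build_token_to_seq(cu_seqlens_q)
print(token_to_seq)                         # [0, 0, 0, 0, 0, 1, 2, 3, 3, 3]
print(len(token_to_seq), len(cu_seqlens_q))  # 10 entries vs. 5 entries
```

The map gives O(1) lookup but stores one entry per query token, while the binary search needs only the `num_seqs + 1` offsets. That is the space trade-off the bullet refers to.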
---------
Signed-off-by: ran <hzz5361@psu.edu>

1 parent 95ad433 commit ed2cefa
5 files changed
Lines changed: 136 additions & 51 deletions
File tree
- tests
- vllm_metal
- metal_kernel_backend
- metal
- kernels_v2