Commit 43bc7cb
committed
feat(attention): add Flash Attention VJP for vector path (L≤8)
Implement fused backward pass for scaled_dot_product_attention on
short sequences (L≤8) using the vector kernel approach. This eliminates
the O(N²) memory requirement of unfused attention by recomputing the
attention matrix on-the-fly during backpropagation.
Key changes:
- Add sdpa_vector_vjp.h with GPU kernels for computing dQ, dK, dV
- Extend forward pass to output logsumexp (LSE) when needed for VJP
- Add comprehensive Python tests for gradient correctness
- Fix CUDA cuDNN backward to handle masks via set_bias() (removes
unnecessary fallback)
Performance (M3 Max, L≤8):
- 1.1-1.4x faster than unfused attention for backward pass
- Memory: O(N) instead of O(N²) for attention matrix
The STEEL VJP for longer sequences (L>8) will be added in a follow-up PR.1 parent 1650c49 commit 43bc7cb
File tree
12 files changed
+1697
-50
lines changed- mlx
- backend
- cuda
- metal
- kernels
- steel/attn
- kernels
- no_gpu
- python/tests
12 files changed
+1697
-50
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
402 | 402 | | |
403 | 403 | | |
404 | 404 | | |
405 | | - | |
406 | 405 | | |
407 | 406 | | |
408 | 407 | | |
| |||
460 | 459 | | |
461 | 460 | | |
462 | 461 | | |
463 | | - | |
| 462 | + | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
| 467 | + | |
464 | 468 | | |
465 | 469 | | |
466 | 470 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
53 | 53 | | |
54 | 54 | | |
55 | 55 | | |
56 | | - | |
| 56 | + | |
57 | 57 | | |
58 | 58 | | |
59 | 59 | | |
| |||
Lines changed: 39 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3 | 3 | | |
4 | 4 | | |
5 | 5 | | |
| 6 | + | |
6 | 7 | | |
7 | 8 | | |
8 | 9 | | |
| |||
41 | 42 | | |
42 | 43 | | |
43 | 44 | | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
44 | 83 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
80 | 80 | | |
81 | 81 | | |
82 | 82 | | |
| 83 | + | |
| 84 | + | |
83 | 85 | | |
84 | | - | |
| 86 | + | |
85 | 87 | | |
86 | 88 | | |
87 | 89 | | |
| |||
90 | 92 | | |
91 | 93 | | |
92 | 94 | | |
93 | | - | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
94 | 98 | | |
95 | 99 | | |
96 | 100 | | |
| |||
117 | 121 | | |
118 | 122 | | |
119 | 123 | | |
120 | | - | |
| 124 | + | |
| 125 | + | |
121 | 126 | | |
122 | 127 | | |
123 | | - | |
| 128 | + | |
124 | 129 | | |
125 | | - | |
126 | | - | |
| 130 | + | |
| 131 | + | |
127 | 132 | | |
128 | 133 | | |
129 | 134 | | |
| |||
155 | 160 | | |
156 | 161 | | |
157 | 162 | | |
158 | | - | |
| 163 | + | |
159 | 164 | | |
160 | 165 | | |
161 | 166 | | |
| |||
252 | 257 | | |
253 | 258 | | |
254 | 259 | | |
| 260 | + | |
| 261 | + | |
255 | 262 | | |
256 | | - | |
| 263 | + | |
257 | 264 | | |
258 | 265 | | |
259 | 266 | | |
| |||
263 | 270 | | |
264 | 271 | | |
265 | 272 | | |
266 | | - | |
| 273 | + | |
| 274 | + | |
267 | 275 | | |
268 | 276 | | |
269 | 277 | | |
| |||
291 | 299 | | |
292 | 300 | | |
293 | 301 | | |
294 | | - | |
| 302 | + | |
| 303 | + | |
295 | 304 | | |
296 | 305 | | |
297 | | - | |
| 306 | + | |
298 | 307 | | |
299 | | - | |
300 | | - | |
| 308 | + | |
| 309 | + | |
301 | 310 | | |
302 | 311 | | |
303 | 312 | | |
| |||
329 | 338 | | |
330 | 339 | | |
331 | 340 | | |
332 | | - | |
| 341 | + | |
333 | 342 | | |
334 | 343 | | |
335 | 344 | | |
| |||
342 | 351 | | |
343 | 352 | | |
344 | 353 | | |
345 | | - | |
| 354 | + | |
346 | 355 | | |
347 | 356 | | |
348 | 357 | | |
| |||
390 | 399 | | |
391 | 400 | | |
392 | 401 | | |
393 | | - | |
| 402 | + | |
394 | 403 | | |
395 | 404 | | |
396 | 405 | | |
| |||
0 commit comments