Commit d226a28
feat(attention): add Flash Attention VJP for vector path (L≤8)
Implement fused backward pass for scaled_dot_product_attention on
short sequences (L≤8) using the vector kernel approach. This eliminates
the O(N²) memory requirement of unfused attention by recomputing the
attention matrix on the fly during backpropagation.
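For reference, the math the fused VJP computes can be written down outside the kernel. The following is a minimal single-head NumPy sketch, not the Metal/CUDA kernel itself: function names and shapes are illustrative, and masking, multi-head batching, and GQA are omitted. The point it illustrates is that saving only the per-row logsumexp (LSE) from the forward pass is enough to rebuild the attention matrix during the backward pass, so no O(N²) buffer is stored between passes.

```python
import numpy as np

def sdpa_forward(q, k, v, scale):
    """q, k: (L, D); v: (L, Dv). Returns the output and the per-row logsumexp."""
    s = scale * (q @ k.T)                                   # attention scores
    m = s.max(axis=-1, keepdims=True)
    lse = m + np.log(np.exp(s - m).sum(axis=-1, keepdims=True))
    p = np.exp(s - lse)                                     # softmax probabilities
    o = p @ v
    return o, lse

def sdpa_vjp(q, k, v, o, lse, do, scale):
    """Gradients w.r.t. q, k, v given the output cotangent `do`."""
    # Recompute the attention matrix from q, k, and the saved logsumexp
    # instead of storing it, so only O(L) extra state is kept between passes.
    p = np.exp(scale * (q @ k.T) - lse)
    dv = p.T @ do
    dp = do @ v.T
    delta = (do * o).sum(axis=-1, keepdims=True)            # row-wise <dO, O>
    ds = p * (dp - delta)                                    # softmax backward
    dq = scale * (ds @ k)
    dk = scale * (ds.T @ q)
    return dq, dk, dv
```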
Key changes:
- Add sdpa_vector_vjp.h with GPU kernels for computing dQ, dK, dV
- Extend forward pass to output logsumexp (LSE) when needed for VJP
- Add comprehensive Python tests for gradient correctness (a sketch of this kind of check follows this list)
- Fix CUDA cuDNN backward to handle masks via set_bias() (removes an unnecessary fallback)
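A rough sketch of the kind of gradient-correctness check the tests perform is below: compare gradients through the fused `mx.fast.scaled_dot_product_attention` path against an unfused softmax reference. The shapes, seeds, and tolerance here are illustrative and not taken from the actual test file.

```python
import mlx.core as mx

def fused_loss(q, k, v, scale):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale).sum()

def reference_loss(q, k, v, scale):
    # Unfused reference: materializes the full (L, L) attention matrix.
    p = mx.softmax(scale * (q @ mx.swapaxes(k, -1, -2)), axis=-1)
    return (p @ v).sum()

B, H, L, D = 1, 8, 8, 64          # L <= 8 exercises the vector path
scale = D ** -0.5
q, k, v = (mx.random.normal((B, H, L, D)) for _ in range(3))

fused_grads = mx.grad(fused_loss, argnums=(0, 1, 2))(q, k, v, scale)
ref_grads = mx.grad(reference_loss, argnums=(0, 1, 2))(q, k, v, scale)
for g, g_ref in zip(fused_grads, ref_grads):
    assert mx.allclose(g, g_ref, atol=1e-4)
```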
Performance (M3 Max, L≤8):
- 1.1-1.4x faster than unfused attention for backward pass
- Memory: O(N) instead of O(N²) for attention matrix
The STEEL VJP for longer sequences (L>8) will be added in a follow-up PR.
File tree (12 files changed, +1695 −48 lines)
- mlx
  - backend
    - cuda
    - metal
      - kernels
        - steel/attn
          - kernels
    - no_gpu
- python/tests