Commit 9db03f9
feat(attention): add Flash Attention VJP for vector path (L≤8)
Implement fused backward pass for scaled_dot_product_attention on
short sequences (L≤8) using the vector kernel approach. This eliminates
the O(N²) memory requirement of unfused attention by recomputing the
attention matrix on-the-fly during backpropagation.
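The recompute-based VJP follows the standard flash-attention backward recurrence: the probabilities P are rebuilt from Q, K, and the saved logsumexp, so no O(N²) buffer is kept between the forward and backward passes. A minimal NumPy sketch of the math for a single head (an illustrative reference, not the kernel code):

```python
import numpy as np

def sdpa_vjp_reference(q, k, v, d_out, scale):
    """VJP of softmax(scale * q @ k.T) @ v for one head; shapes (L, D).
    The attention matrix is recomputed from q, k and the logsumexp the
    forward pass saves, rather than stored. All names are illustrative."""
    s = scale * (q @ k.T)
    m = s.max(axis=-1, keepdims=True)
    lse = m + np.log(np.exp(s - m).sum(axis=-1, keepdims=True))  # saved by forward
    p = np.exp(s - lse)                              # recomputed attention probs
    o = p @ v                                        # forward output

    dv = p.T @ d_out                                 # gradient w.r.t. values
    dp = d_out @ v.T                                 # gradient w.r.t. probabilities
    delta = (d_out * o).sum(axis=-1, keepdims=True)  # rowwise <dO, O> correction
    ds = p * (dp - delta)                            # softmax backward
    dq = scale * (ds @ k)                            # gradient w.r.t. queries
    dk = scale * (ds.T @ q)                          # gradient w.r.t. keys
    return dq, dk, dv
```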
Key changes:
- Add sdpa_vector_vjp.h with GPU kernels for computing dQ, dK, dV
- Extend forward pass to output logsumexp (LSE) when needed for VJP
- Add comprehensive Python tests for gradient correctness (a sketch of such a check follows this list)
- Fix CUDA cuDNN backward to handle masks via set_bias() (removes
unnecessary fallback)
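As an illustration of the kind of gradient-correctness check involved, the sketch below compares VJPs of the fused kernel against an unfused reference via mx.vjp. Shapes, tolerances, and the helper name are assumptions, not the actual test code.

```python
import mlx.core as mx

def unfused_attn(q, k, v, scale):
    # Reference path: materializes the full (L, L) attention matrix.
    s = (q * scale) @ k.transpose(0, 1, 3, 2)
    return mx.softmax(s, axis=-1) @ v

# Illustrative shapes on the vector path (sequence length L <= 8).
B, H, L, D = 1, 8, 8, 64
q, k, v = (mx.random.normal((B, H, L, D)) for _ in range(3))
scale = D ** -0.5
cotan = mx.random.normal((B, H, L, D))

fused = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
_, fused_grads = mx.vjp(fused, [q, k, v], [cotan])
_, ref_grads = mx.vjp(lambda q, k, v: unfused_attn(q, k, v, scale),
                      [q, k, v], [cotan])

for g_fused, g_ref in zip(fused_grads, ref_grads):
    assert mx.allclose(g_fused, g_ref, atol=1e-4).item()
```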
Performance (M3 Max, L≤8):
- 1.1-1.4x faster than unfused attention for backward pass
- Memory: O(N) instead of O(N²) for attention matrix
The STEEL VJP for longer sequences (L>8) will be added in a follow-up PR.
File tree (11 files changed: +1310, -48 lines)
- mlx
  - backend
    - cuda
    - metal
      - kernels
        - steel/attn
          - kernels
    - no_gpu