Commit 5c78507
feat(attention): add Flash Attention VJP for Metal GPU
Implement fused backward pass (VJP) for scaled_dot_product_attention
on Metal GPU, enabling efficient training without falling back to
unfused attention.
- **dQ Kernel** (steel_attention_vjp_dq.h): Computes query gradients
- Outer loop over KV blocks, inner accumulation for dQ
- Uses log2 domain for numerical stability
- **dK/dV Kernel** (steel_attention_vjp_dkv.h): Computes key/value gradients
- K-row ownership model eliminates atomic operations
- Each simdgroup owns exclusive K rows to prevent races
- Optimized path for short sequences (L ≤ 8)
- Uses shared memory for efficient reduction
- Float32 accumulators for half/bfloat16 precision
- Logsumexp caching from forward pass
- Proper GQA (grouped query attention) support
- Causal mask support
- Comprehensive test coverage for all code paths
- No gradient support for mask or attention sinks (falls back to unfused)
- Requires logsumexp from forward pass (training mode only)
- Head dimension D=256 not supported in vector VJP (threadgroup memory)
Co-Authored-By: Claude <noreply@anthropic.com>1 parent d2bef3c commit 5c78507
File tree
16 files changed
+3613
-67
lines changed- mlx
- backend
- cuda
- metal
- kernels
- steel/attn
- kernels
- no_gpu
- python/tests
16 files changed
+3613
-67
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
402 | 402 | | |
403 | 403 | | |
404 | 404 | | |
405 | | - | |
406 | 405 | | |
407 | 406 | | |
408 | 407 | | |
| |||
460 | 459 | | |
461 | 460 | | |
462 | 461 | | |
463 | | - | |
| 462 | + | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
464 | 472 | | |
465 | 473 | | |
466 | 474 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
53 | 53 | | |
54 | 54 | | |
55 | 55 | | |
56 | | - | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
57 | 79 | | |
58 | 80 | | |
59 | 81 | | |
| |||
81 | 103 | | |
82 | 104 | | |
83 | 105 | | |
84 | | - | |
85 | | - | |
86 | | - | |
87 | | - | |
88 | | - | |
89 | | - | |
90 | | - | |
91 | | - | |
92 | | - | |
93 | | - | |
94 | | - | |
95 | | - | |
96 | | - | |
97 | | - | |
98 | | - | |
99 | | - | |
100 | 106 | | |
101 | 107 | | |
102 | 108 | | |
| |||
Lines changed: 39 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3 | 3 | | |
4 | 4 | | |
5 | 5 | | |
| 6 | + | |
6 | 7 | | |
7 | 8 | | |
8 | 9 | | |
| |||
41 | 42 | | |
42 | 43 | | |
43 | 44 | | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
44 | 83 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
80 | 80 | | |
81 | 81 | | |
82 | 82 | | |
| 83 | + | |
| 84 | + | |
83 | 85 | | |
84 | | - | |
| 86 | + | |
85 | 87 | | |
86 | 88 | | |
87 | 89 | | |
| |||
90 | 92 | | |
91 | 93 | | |
92 | 94 | | |
93 | | - | |
| 95 | + | |
| 96 | + | |
94 | 97 | | |
95 | 98 | | |
96 | 99 | | |
| |||
117 | 120 | | |
118 | 121 | | |
119 | 122 | | |
120 | | - | |
| 123 | + | |
| 124 | + | |
121 | 125 | | |
122 | 126 | | |
123 | | - | |
| 127 | + | |
124 | 128 | | |
125 | | - | |
126 | | - | |
| 129 | + | |
| 130 | + | |
127 | 131 | | |
128 | 132 | | |
129 | 133 | | |
| |||
155 | 159 | | |
156 | 160 | | |
157 | 161 | | |
158 | | - | |
| 162 | + | |
159 | 163 | | |
160 | 164 | | |
161 | 165 | | |
| |||
252 | 256 | | |
253 | 257 | | |
254 | 258 | | |
| 259 | + | |
| 260 | + | |
255 | 261 | | |
256 | | - | |
| 262 | + | |
257 | 263 | | |
258 | 264 | | |
259 | 265 | | |
| |||
263 | 269 | | |
264 | 270 | | |
265 | 271 | | |
266 | | - | |
| 272 | + | |
| 273 | + | |
267 | 274 | | |
268 | 275 | | |
269 | 276 | | |
| |||
291 | 298 | | |
292 | 299 | | |
293 | 300 | | |
294 | | - | |
| 301 | + | |
| 302 | + | |
295 | 303 | | |
296 | 304 | | |
297 | | - | |
| 305 | + | |
298 | 306 | | |
299 | | - | |
300 | | - | |
| 307 | + | |
| 308 | + | |
301 | 309 | | |
302 | 310 | | |
303 | 311 | | |
| |||
329 | 337 | | |
330 | 338 | | |
331 | 339 | | |
332 | | - | |
| 340 | + | |
333 | 341 | | |
334 | 342 | | |
335 | 343 | | |
| |||
342 | 350 | | |
343 | 351 | | |
344 | 352 | | |
345 | | - | |
| 353 | + | |
346 | 354 | | |
347 | 355 | | |
348 | 356 | | |
| |||
390 | 398 | | |
391 | 399 | | |
392 | 400 | | |
393 | | - | |
| 401 | + | |
394 | 402 | | |
395 | 403 | | |
396 | 404 | | |
| |||
0 commit comments