Commit 568ff36
feat(attention): add Flash Attention VJP for Metal GPU
Implement fused backward pass (VJP) for scaled_dot_product_attention
on Metal GPU, enabling efficient training without falling back to
unfused attention.
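
The motivating use case is taking gradients through the fused op during training. A minimal usage sketch (hypothetical shapes; not part of this diff):

```python
import mlx.core as mx

# Hypothetical shapes: batch 2, 8 heads, sequence length 128, head dim 64.
q = mx.random.normal((2, 8, 128, 64))
k = mx.random.normal((2, 8, 128, 64))
v = mx.random.normal((2, 8, 128, 64))

def loss_fn(q, k, v):
    o = mx.fast.scaled_dot_product_attention(q, k, v, scale=64 ** -0.5)
    return o.sum()

# With this change the q/k/v gradients run through the fused Metal VJP
# kernels (no mask or attention sinks here, so no unfused fallback).
dq, dk, dv = mx.grad(loss_fn, argnums=(0, 1, 2))(q, k, v)
```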
- **dQ Kernel** (`steel_attention_vjp_dq.h`): computes query gradients (see the reference sketch after this list)
  - Outer loop over KV blocks, inner accumulation for dQ
  - Uses the log2 domain for numerical stability
- **dK/dV Kernel** (`steel_attention_vjp_dkv.h`): computes key and value gradients
  - K-row ownership model eliminates atomic operations: each simdgroup owns an exclusive set of K rows, so there are no write races
  - Optimized path for short sequences (L ≤ 8)
  - Uses shared memory for efficient reduction
- Float32 accumulators for half/bfloat16 inputs
- Reuses the logsumexp cached by the forward pass
- Proper GQA (grouped query attention) support (head-group reduction sketched after this list)
- Causal mask support
- Comprehensive test coverage for all code paths
- No gradient support for masks or attention sinks (falls back to the unfused path)
- Requires the logsumexp cached by the forward pass (training mode only)
- Head dimension D = 256 not supported in the vector VJP (threadgroup memory limit)
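
For context, here is a single-head NumPy reference of the VJP math that the dQ and dK/dV kernels tile over blocks (a standard FlashAttention-style backward; an illustrative sketch, not the Metal code in this commit):

```python
import numpy as np

def sdpa_vjp_reference(q, k, v, d_out, out, lse, scale):
    """Reference VJP of out = softmax(scale * q @ k.T) @ v for one head.

    q, d_out, out: (L_q, D); k, v: (L_kv, D); lse: (L_q,) row-wise logsumexp
    cached by the forward pass, so no extra softmax reduction is needed.
    """
    s = scale * (q @ k.T)                 # recompute raw scores
    p = np.exp(s - lse[:, None])          # softmax rows from cached logsumexp
    d_v = p.T @ d_out                     # dV accumulation (dK/dV kernel)
    d_p = d_out @ v.T
    delta = np.sum(d_out * out, axis=-1)  # row-wise correction term
    d_s = p * (d_p - delta[:, None])      # gradient through softmax
    d_q = scale * (d_s @ k)               # dQ accumulation (dQ kernel)
    d_k = scale * (d_s.T @ q)             # dK accumulation (dK/dV kernel)
    return d_q, d_k, d_v
```

The Metal kernels evaluate the same quantities blockwise, in the log2 domain (exp2/log2) and with float32 accumulators, as listed above.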
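
Under GQA, a group of query heads shares one KV head, so the per-query-head dK/dV contributions must be reduced over each group. A minimal NumPy sketch of that head mapping and reduction (hypothetical head counts; layout (heads, L, D)):

```python
import numpy as np

n_q_heads, n_kv_heads, L, D = 8, 2, 128, 64
group = n_q_heads // n_kv_heads                 # query heads per KV head

# Per-query-head dK contributions, e.g. produced by the reference above.
dk_per_q_head = np.random.randn(n_q_heads, L, D)

# Query head h attends to KV head h // group, so its contribution is
# summed into that KV head's gradient.
dk = dk_per_q_head.reshape(n_kv_heads, group, L, D).sum(axis=1)
assert dk.shape == (n_kv_heads, L, D)
```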
Co-Authored-By: Claude <noreply@anthropic.com>

1 parent ac26a4c · commit 568ff36
File tree
15 files changed: +3426 −49 lines changed
- mlx
  - backend
    - cuda
    - metal
      - kernels
        - steel/attn
          - kernels
    - no_gpu
- python/tests