
Commit 568ff36

Brooooooklyn and claude committed
feat(attention): add Flash Attention VJP for Metal GPU
Implement fused backward pass (VJP) for scaled_dot_product_attention on Metal GPU, enabling efficient training without falling back to unfused attention.

- **dQ Kernel** (steel_attention_vjp_dq.h): Computes query gradients
  - Outer loop over KV blocks, inner accumulation for dQ
  - Uses log2 domain for numerical stability
- **dK/dV Kernel** (steel_attention_vjp_dkv.h): Computes key/value gradients
  - K-row ownership model eliminates atomic operations
  - Each simdgroup owns exclusive K rows to prevent races
  - Optimized path for short sequences (L ≤ 8)
  - Uses shared memory for efficient reduction
- Float32 accumulators for half/bfloat16 precision
- Logsumexp caching from forward pass
- Proper GQA (grouped query attention) support
- Causal mask support
- Comprehensive test coverage for all code paths

Limitations:
- No gradient support for mask or attention sinks (falls back to unfused)
- Requires logsumexp from forward pass (training mode only)
- Head dimension D=256 not supported in vector VJP (threadgroup memory)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent ac26a4c commit 568ff36

15 files changed: +3426 −49 lines
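For reference, the fused dQ and dK/dV kernels described in the commit message compute the standard scaled-dot-product-attention backward from the logsumexp cached in the forward pass. Below is a minimal single-head, float32 sketch of those gradients (natural-log domain for clarity; the Metal kernels work in the log2 domain, tile over KV blocks, and additionally handle causal masking and GQA broadcasting, all omitted here):

```cpp
#include <cmath>
#include <vector>

// Reference sketch only, not the Metal kernels: single-head SDPA backward using
// the per-row logsumexp cached by the forward pass.
// Q, dO, dQ are [L x D]; K, V, dK, dV are [S x D]; lse is [L].
void sdpa_vjp_reference(
    const std::vector<float>& Q,
    const std::vector<float>& K,
    const std::vector<float>& V,
    const std::vector<float>& dO,
    const std::vector<float>& lse,
    float scale,
    int L,
    int S,
    int D,
    std::vector<float>& dQ,
    std::vector<float>& dK,
    std::vector<float>& dV) {
  dQ.assign(L * D, 0.f);
  dK.assign(S * D, 0.f);
  dV.assign(S * D, 0.f);
  for (int i = 0; i < L; ++i) {
    std::vector<float> p(S), dP(S);
    float Di = 0.f; // Di = sum_j p_ij * (dO_i . v_j) = dO_i . O_i
    for (int j = 0; j < S; ++j) {
      float s_ij = 0.f, dP_ij = 0.f;
      for (int d = 0; d < D; ++d) {
        s_ij += Q[i * D + d] * K[j * D + d];
        dP_ij += dO[i * D + d] * V[j * D + d];
      }
      p[j] = std::exp(scale * s_ij - lse[i]); // softmax prob from cached logsumexp
      dP[j] = dP_ij;
      Di += p[j] * dP_ij;
    }
    for (int j = 0; j < S; ++j) {
      float dS_ij = p[j] * (dP[j] - Di); // softmax backward
      for (int d = 0; d < D; ++d) {
        dQ[i * D + d] += scale * dS_ij * K[j * D + d]; // accumulated by the dQ kernel
        dK[j * D + d] += scale * dS_ij * Q[i * D + d]; // accumulated by the dK/dV kernel
        dV[j * D + d] += p[j] * dO[i * D + d];
      }
    }
  }
}
```

In this loop nest the `dK`/`dV` updates are the cross-row hazard the commit message refers to: the fused dK/dV kernel avoids it by giving each simdgroup exclusive ownership of a slice of K rows rather than using atomics.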

mlx/backend/cuda/scaled_dot_product_attention.cpp

Lines changed: 9 additions & 2 deletions
@@ -402,7 +402,6 @@ bool ScaledDotProductAttention::use_fallback(
     bool has_mask,
     bool has_arr_mask,
     bool do_causal,
-    bool is_training,
     bool output_logsumexp,
     Stream s) {
   if (s.device == Device::cpu) {
@@ -460,7 +459,15 @@ void ScaledDotProductAttention::eval_gpu(
   }
 }

-bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
+bool ScaledDotProductAttentionVJP::use_fallback(
+    const array& q,
+    Stream s,
+    bool has_mask,
+    bool has_sinks) {
+  // Force unfused attention when masks/sinks present
+  if (has_mask || has_sinks) {
+    return true;
+  }
   // The frontend adds a padding mask when sequence length is not a multiple of
   // tile size.
   if (q.shape(2) % 128 != 0) {
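The new overload above forces the unfused path whenever a mask or attention sinks are present, and the pre-existing check below it handles sequence lengths that are not a multiple of the 128-row tile. A standalone sketch of that combined predicate (illustrative only; `seq_len` stands in for `q.shape(2)`, and the sequence-length branch is assumed to return true, as its comment in the diff suggests):

```cpp
#include <cstdio>

// Illustrative sketch of the fallback decision for the fused attention VJP.
bool vjp_needs_fallback(int seq_len, bool has_mask, bool has_sinks) {
  if (has_mask || has_sinks) {
    return true; // no fused gradient for masks or attention sinks
  }
  if (seq_len % 128 != 0) {
    return true; // frontend would add a padding mask for partial tiles
  }
  return false;
}

int main() {
  std::printf("%d\n", vjp_needs_fallback(384, false, false)); // 0 -> fused VJP kernels
  std::printf("%d\n", vjp_needs_fallback(200, false, false)); // 1 -> unfused fallback
  std::printf("%d\n", vjp_needs_fallback(384, true, false));  // 1 -> unfused fallback
}
```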

mlx/backend/metal/kernels/CMakeLists.txt

Lines changed: 23 additions & 17 deletions
@@ -53,7 +53,29 @@ build_kernel(layer_norm)
 build_kernel(random)
 build_kernel(rms_norm)
 build_kernel(rope)
-build_kernel(scaled_dot_product_attention sdpa_vector.h)
+build_kernel(scaled_dot_product_attention sdpa_vector.h sdpa_vector_vjp.h)
+
+set(STEEL_ATTN_HEADERS
+    steel/defines.h
+    steel/utils.h
+    steel/gemm/gemm.h
+    steel/gemm/mma.h
+    steel/gemm/loader.h
+    steel/gemm/transforms.h
+    steel/utils/type_traits.h
+    steel/utils/integral_constant.h
+    steel/attn/attn.h
+    steel/attn/loader.h
+    steel/attn/mma.h
+    steel/attn/params.h
+    steel/attn/transforms.h
+    steel/attn/kernels/steel_attention.h
+    steel/attn/kernels/steel_attention_vjp_dq.h
+    steel/attn/kernels/steel_attention_vjp_dkv.h)
+
+build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
+build_kernel(steel/attn/kernels/steel_attention_vjp_dq ${STEEL_ATTN_HEADERS})
+build_kernel(steel/attn/kernels/steel_attention_vjp_dkv ${STEEL_ATTN_HEADERS})
 if(MLX_METAL_VERSION GREATER_EQUAL 320)
   build_kernel(fence)
 endif()
@@ -81,22 +103,6 @@ set(STEEL_HEADERS
     steel/utils/type_traits.h
     steel/utils/integral_constant.h)

-set(STEEL_ATTN_HEADERS
-    steel/defines.h
-    steel/utils.h
-    steel/gemm/gemm.h
-    steel/gemm/mma.h
-    steel/gemm/loader.h
-    steel/gemm/transforms.h
-    steel/utils/type_traits.h
-    steel/utils/integral_constant.h
-    steel/attn/attn.h
-    steel/attn/loader.h
-    steel/attn/mma.h
-    steel/attn/params.h
-    steel/attn/transforms.h
-    steel/attn/kernels/steel_attention.h)
-
 set(STEEL_NAX_HEADERS
     steel/defines.h
     steel/utils.h

mlx/backend/metal/kernels/scaled_dot_product_attention.metal

Lines changed: 39 additions & 0 deletions
@@ -3,6 +3,7 @@
 // clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/sdpa_vector.h"
+#include "mlx/backend/metal/kernels/sdpa_vector_vjp.h"

 using namespace metal;

@@ -41,4 +42,42 @@ using namespace metal;
 instantiate_sdpa_vector_heads(float)
 instantiate_sdpa_vector_heads(bfloat16_t)
 instantiate_sdpa_vector_heads(float16_t)
+
+// SDPA vector VJP instantiations
+#define instantiate_sdpa_vector_vjp(type, qk_dim, value_dim) \
+  instantiate_kernel( \
+      "sdpa_vector_vjp_" #type "_" #qk_dim "_" #value_dim, \
+      sdpa_vector_vjp, \
+      type, \
+      qk_dim, \
+      value_dim)
+
+// Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP kernel
+#define instantiate_sdpa_vector_vjp_heads(type) \
+  instantiate_sdpa_vector_vjp(type, 64, 64) \
+  instantiate_sdpa_vector_vjp(type, 96, 96) \
+  instantiate_sdpa_vector_vjp(type, 128, 128)
+
+instantiate_sdpa_vector_vjp_heads(float)
+instantiate_sdpa_vector_vjp_heads(bfloat16_t)
+instantiate_sdpa_vector_vjp_heads(float16_t)
+
+// SDPA vector VJP accumulate instantiations (for half/bfloat16 with float32 accumulators)
+#define instantiate_sdpa_vector_vjp_accumulate(type, qk_dim, value_dim) \
+  instantiate_kernel( \
+      "sdpa_vector_vjp_accumulate_" #type "_" #qk_dim "_" #value_dim, \
+      sdpa_vector_vjp_accumulate, \
+      type, \
+      qk_dim, \
+      value_dim)
+
+// Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP kernel
+#define instantiate_sdpa_vector_vjp_accumulate_heads(type) \
+  instantiate_sdpa_vector_vjp_accumulate(type, 64, 64) \
+  instantiate_sdpa_vector_vjp_accumulate(type, 96, 96) \
+  instantiate_sdpa_vector_vjp_accumulate(type, 128, 128)
+
+// Note: Only instantiate for half/bfloat16 since float32 doesn't need accumulate variant
+instantiate_sdpa_vector_vjp_accumulate_heads(bfloat16_t)
+instantiate_sdpa_vector_vjp_accumulate_heads(float16_t)
 // clang-format on
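The instantiation macros above stringify their arguments into kernel names such as `sdpa_vector_vjp_float16_t_128_128` and `sdpa_vector_vjp_accumulate_bfloat16_t_96_96`. A small host-side sketch of that naming scheme (illustrative only; it mirrors the macros' string concatenation and is not MLX's actual kernel-lookup code):

```cpp
#include <iostream>
#include <string>

// Builds a kernel name the way the instantiation macros above stringify it:
// "sdpa_vector_vjp[_accumulate]_<type>_<qk_dim>_<value_dim>".
std::string sdpa_vjp_kernel_name(
    const std::string& type, // "float", "float16_t", or "bfloat16_t"
    int qk_dim,
    int value_dim,
    bool accumulate) { // float32-accumulator variant (half/bfloat16 only)
  std::string name =
      accumulate ? "sdpa_vector_vjp_accumulate_" : "sdpa_vector_vjp_";
  return name + type + "_" + std::to_string(qk_dim) + "_" +
      std::to_string(value_dim);
}

int main() {
  // Prints: sdpa_vector_vjp_accumulate_float16_t_128_128
  std::cout << sdpa_vjp_kernel_name("float16_t", 128, 128, true) << "\n";
}
```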
