
Commit 5c78507

Brooooooklyn and claude committed
feat(attention): add Flash Attention VJP for Metal GPU
Implement fused backward pass (VJP) for scaled_dot_product_attention on Metal GPU, enabling efficient training without falling back to unfused attention.

- **dQ Kernel** (steel_attention_vjp_dq.h): Computes query gradients
  - Outer loop over KV blocks, inner accumulation for dQ
  - Uses log2 domain for numerical stability
- **dK/dV Kernel** (steel_attention_vjp_dkv.h): Computes key/value gradients
  - K-row ownership model eliminates atomic operations
  - Each simdgroup owns exclusive K rows to prevent races
  - Optimized path for short sequences (L ≤ 8)
  - Uses shared memory for efficient reduction
- Float32 accumulators for half/bfloat16 precision
- Logsumexp caching from forward pass
- Proper GQA (grouped query attention) support
- Causal mask support
- Comprehensive test coverage for all code paths
- No gradient support for mask or attention sinks (falls back to unfused)
- Requires logsumexp from forward pass (training mode only)
- Head dimension D=256 not supported in vector VJP (threadgroup memory)

Co-Authored-By: Claude <noreply@anthropic.com>
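For orientation, here is a sketch of the standard scaled-dot-product-attention backward pass that these kernels fuse (the notation below is mine, not taken from the kernel sources): with softmax scale τ and the per-row logsumexp L cached from the forward pass, the probabilities can be recomputed blockwise instead of stored, and the gradients follow as

```latex
% Standard SDPA VJP; \tau is the softmax scale, L the cached per-row logsumexp.
\begin{aligned}
S &= \tau\, Q K^\top, \qquad P = \exp(S - L) \quad \text{(recomputed per block)} \\
dV &= P^\top dO, \qquad dP = dO\, V^\top \\
\Delta &= \operatorname{rowsum}(dO \odot O) \\
dS &= P \odot (dP - \Delta) \\
dQ &= \tau\, dS\, K, \qquad dK = \tau\, dS^\top Q
\end{aligned}
```

The dQ kernel accumulates dQ while streaming over KV blocks, while the dK/dV kernel gives each simdgroup exclusive ownership of a set of K rows so dK and dV can be accumulated without atomics.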
1 parent d2bef3c commit 5c78507

16 files changed: +3613 -67 lines

mlx/backend/cuda/scaled_dot_product_attention.cpp

Lines changed: 10 additions & 2 deletions
@@ -402,7 +402,6 @@ bool ScaledDotProductAttention::use_fallback(
     bool has_mask,
     bool has_arr_mask,
     bool do_causal,
-    bool is_training,
     bool output_logsumexp,
     Stream s) {
   if (s.device == Device::cpu) {
@@ -460,7 +459,16 @@ void ScaledDotProductAttention::eval_gpu(
   }
 }
 
-bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
+bool ScaledDotProductAttentionVJP::use_fallback(
+    const array& q,
+    Stream s,
+    bool has_mask,
+    bool has_sinks,
+    int /* n_kv_heads */) {
+  // Force unfused attention when masks/sinks present
+  if (has_mask || has_sinks) {
+    return true;
+  }
   // The frontend adds a padding mask when sequence length is not a multiple of
   // tile size.
   if (q.shape(2) % 128 != 0) {
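
For illustration, a minimal standalone sketch of the fallback predicate added above (hypothetical types and names, not MLX's actual call site, and the real method performs further checks): masks and sinks force the unfused gradient, and so do sequence lengths that are not a multiple of the 128-wide tile, since the frontend then injects a padding mask.

```cpp
// Hypothetical sketch of the VJP fallback decision; `Shape` and
// `sdpa_vjp_uses_fallback` are stand-ins, not MLX symbols.
#include <cassert>

struct Shape {
  int seq_len;  // query sequence length (q.shape(2) in the diff above)
};

bool sdpa_vjp_uses_fallback(const Shape& q, bool has_mask, bool has_sinks) {
  // Masks and attention sinks have no fused-gradient support yet.
  if (has_mask || has_sinks) {
    return true;
  }
  // Non tile-aligned lengths get a frontend padding mask, so they fall back too.
  return q.seq_len % 128 != 0;
}

int main() {
  assert(!sdpa_vjp_uses_fallback({256}, /*has_mask=*/false, /*has_sinks=*/false));
  assert(sdpa_vjp_uses_fallback({256}, /*has_mask=*/true, /*has_sinks=*/false));
  assert(sdpa_vjp_uses_fallback({300}, /*has_mask=*/false, /*has_sinks=*/false));
  return 0;
}
```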

mlx/backend/metal/kernels/CMakeLists.txt

Lines changed: 23 additions & 17 deletions
@@ -53,7 +53,29 @@ build_kernel(layer_norm)
 build_kernel(random)
 build_kernel(rms_norm)
 build_kernel(rope)
-build_kernel(scaled_dot_product_attention sdpa_vector.h)
+build_kernel(scaled_dot_product_attention sdpa_vector.h sdpa_vector_vjp.h)
+
+set(STEEL_ATTN_HEADERS
+    steel/defines.h
+    steel/utils.h
+    steel/gemm/gemm.h
+    steel/gemm/mma.h
+    steel/gemm/loader.h
+    steel/gemm/transforms.h
+    steel/utils/type_traits.h
+    steel/utils/integral_constant.h
+    steel/attn/attn.h
+    steel/attn/loader.h
+    steel/attn/mma.h
+    steel/attn/params.h
+    steel/attn/transforms.h
+    steel/attn/kernels/steel_attention.h
+    steel/attn/kernels/steel_attention_vjp_dq.h
+    steel/attn/kernels/steel_attention_vjp_dkv.h)
+
+build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
+build_kernel(steel/attn/kernels/steel_attention_vjp_dq ${STEEL_ATTN_HEADERS})
+build_kernel(steel/attn/kernels/steel_attention_vjp_dkv ${STEEL_ATTN_HEADERS})
 if(MLX_METAL_VERSION GREATER_EQUAL 320)
   build_kernel(fence)
 endif()
@@ -81,22 +103,6 @@ set(STEEL_HEADERS
     steel/utils/type_traits.h
     steel/utils/integral_constant.h)
 
-set(STEEL_ATTN_HEADERS
-    steel/defines.h
-    steel/utils.h
-    steel/gemm/gemm.h
-    steel/gemm/mma.h
-    steel/gemm/loader.h
-    steel/gemm/transforms.h
-    steel/utils/type_traits.h
-    steel/utils/integral_constant.h
-    steel/attn/attn.h
-    steel/attn/loader.h
-    steel/attn/mma.h
-    steel/attn/params.h
-    steel/attn/transforms.h
-    steel/attn/kernels/steel_attention.h)
-
 set(STEEL_NAX_HEADERS
     steel/defines.h
     steel/utils.h

mlx/backend/metal/kernels/scaled_dot_product_attention.metal

Lines changed: 39 additions & 0 deletions
@@ -3,6 +3,7 @@
 // clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/sdpa_vector.h"
+#include "mlx/backend/metal/kernels/sdpa_vector_vjp.h"
 
 using namespace metal;
 
@@ -41,4 +42,42 @@ using namespace metal;
 instantiate_sdpa_vector_heads(float)
 instantiate_sdpa_vector_heads(bfloat16_t)
 instantiate_sdpa_vector_heads(float16_t)
+
+// SDPA vector VJP instantiations
+#define instantiate_sdpa_vector_vjp(type, qk_dim, value_dim) \
+  instantiate_kernel( \
+      "sdpa_vector_vjp_" #type "_" #qk_dim "_" #value_dim, \
+      sdpa_vector_vjp, \
+      type, \
+      qk_dim, \
+      value_dim)
+
+// Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP kernel
+#define instantiate_sdpa_vector_vjp_heads(type) \
+  instantiate_sdpa_vector_vjp(type, 64, 64) \
+  instantiate_sdpa_vector_vjp(type, 96, 96) \
+  instantiate_sdpa_vector_vjp(type, 128, 128)
+
+instantiate_sdpa_vector_vjp_heads(float)
+instantiate_sdpa_vector_vjp_heads(bfloat16_t)
+instantiate_sdpa_vector_vjp_heads(float16_t)
+
+// SDPA vector VJP accumulate instantiations (for half/bfloat16 with float32 accumulators)
+#define instantiate_sdpa_vector_vjp_accumulate(type, qk_dim, value_dim) \
+  instantiate_kernel( \
+      "sdpa_vector_vjp_accumulate_" #type "_" #qk_dim "_" #value_dim, \
+      sdpa_vector_vjp_accumulate, \
+      type, \
+      qk_dim, \
+      value_dim)
+
+// Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP kernel
+#define instantiate_sdpa_vector_vjp_accumulate_heads(type) \
+  instantiate_sdpa_vector_vjp_accumulate(type, 64, 64) \
+  instantiate_sdpa_vector_vjp_accumulate(type, 96, 96) \
+  instantiate_sdpa_vector_vjp_accumulate(type, 128, 128)
+
+// Note: Only instantiate for half/bfloat16 since float32 doesn't need accumulate variant
+instantiate_sdpa_vector_vjp_accumulate_heads(bfloat16_t)
+instantiate_sdpa_vector_vjp_accumulate_heads(float16_t)
 // clang-format on
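
A rough back-of-the-envelope for the D=256 exclusion noted in the comments above (my estimate; BN and the buffer layout below are assumptions rather than the kernel's actual constants): a single float32 staging buffer of BN × D elements already saturates the 32 KB threadgroup budget at D = 256, leaving no room for the rest of the kernel's shared memory.

```cpp
// Hypothetical sizing estimate for the vector VJP D=256 exclusion.
#include <cstdio>

int main() {
  constexpr int BN = 32;        // assumed rows staged per threadgroup
  constexpr int D = 256;        // head dimension
  constexpr int limit_kb = 32;  // Metal threadgroup memory budget
  constexpr int buffer_bytes = BN * D * static_cast<int>(sizeof(float));
  std::printf("one %dx%d float buffer: %d KB of the %d KB limit\n",
              BN, D, buffer_bytes / 1024, limit_kb);  // prints 32 KB of 32 KB
  return 0;
}
```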

mlx/backend/metal/kernels/sdpa_vector.h

Lines changed: 24 additions & 16 deletions
@@ -80,8 +80,10 @@ template <typename T, int D, int V = D>
   out += o_offset * V + simd_gid * v_per_thread;
 
   // Read the query and 0 the output accumulator
+  // Scale by M_LOG2E_F to match STEEL attention domain (exp2 instead of exp)
+  const U log2e_scale = static_cast<U>(scale * M_LOG2E_F);
   for (int i = 0; i < qk_per_thread; i++) {
-    q[i] = static_cast<U>(scale) * queries[i];
+    q[i] = log2e_scale * queries[i];
   }
   for (int i = 0; i < v_per_thread; i++) {
     o[i] = 0;
@@ -90,7 +92,8 @@ template <typename T, int D, int V = D>
   U max_score = Limits<U>::finite_min;
   U sum_exp_score = 0;
   if (has_sinks && simd_gid == 0) {
-    max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
+    // Scale sink by M_LOG2E_F to match log2 domain
+    max_score = static_cast<U>(M_LOG2E_F) * static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
     sum_exp_score = 1;
   }
 
@@ -117,13 +120,14 @@ template <typename T, int D, int V = D>
     }
     score = simd_sum(score);
     if (float_mask) {
-      score += static_cast<U>(fmask[0]);
+      // Scale float mask by M_LOG2E_F to match log2 domain
+      score += static_cast<U>(M_LOG2E_F) * static_cast<U>(fmask[0]);
     }
 
-    // Update the accumulators
+    // Update the accumulators (using exp2 to match STEEL attention)
     U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+    U factor = fast::exp2(max_score - new_max);
+    U exp_score = fast::exp2(score - new_max);
 
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
@@ -155,7 +159,7 @@ template <typename T, int D, int V = D>
   threadgroup_barrier(mem_flags::mem_threadgroup);
   max_score = max_scores[simd_lid];
   U new_max = simd_max(max_score);
-  U factor = fast::exp(max_score - new_max);
+  U factor = fast::exp2(max_score - new_max);
   sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
 
   // Now we need to aggregate all the outputs
@@ -252,8 +256,10 @@ template <typename T, int D, int V = D>
   maxs += o_offset * blocks + block_idx;
 
   // Read the query and 0 the output accumulator
+  // Scale by M_LOG2E_F to match STEEL attention domain (exp2 instead of exp)
+  const U log2e_scale = static_cast<U>(scale * M_LOG2E_F);
   for (int i = 0; i < qk_per_thread; i++) {
-    q[i] = static_cast<U>(scale) * queries[i];
+    q[i] = log2e_scale * queries[i];
   }
   for (int i = 0; i < v_per_thread; i++) {
     o[i] = 0;
@@ -263,7 +269,8 @@ template <typename T, int D, int V = D>
   U sum_exp_score = 0;
   if (has_sinks && block_idx == 0 && simd_gid == 0) {
     int q_head_idx = q_batch_head_idx % num_q_heads;
-    max_score = static_cast<U>(sinks[q_head_idx]);
+    // Scale sink by M_LOG2E_F to match log2 domain
+    max_score = static_cast<U>(M_LOG2E_F) * static_cast<U>(sinks[q_head_idx]);
     sum_exp_score = 1;
   }
 
@@ -291,13 +298,14 @@ template <typename T, int D, int V = D>
     score = simd_sum(score);
 
     if (float_mask) {
-      score += fmask[0];
+      // Scale float mask by M_LOG2E_F to match log2 domain
+      score += static_cast<U>(M_LOG2E_F) * static_cast<U>(fmask[0]);
     }
 
-    // Update the accumulators
+    // Update the accumulators (using exp2 to match STEEL attention)
     U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+    U factor = fast::exp2(max_score - new_max);
+    U exp_score = fast::exp2(score - new_max);
 
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
@@ -329,7 +337,7 @@ template <typename T, int D, int V = D>
   threadgroup_barrier(mem_flags::mem_threadgroup);
   max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
   U new_max = simd_max(max_score);
-  U factor = fast::exp(max_score - new_max);
+  U factor = fast::exp2(max_score - new_max);
   sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
   sum_exp_score = simd_sum(sum_exp_score * factor);
 
@@ -342,7 +350,7 @@ template <typename T, int D, int V = D>
   // Now we need to aggregate all the outputs
   for (int i = 0; i < v_per_thread; i++) {
     outputs[simd_lid * BN + simd_gid] =
-        o[i] * fast::exp(max_scores[simd_gid] - new_max);
+        o[i] * fast::exp2(max_scores[simd_gid] - new_max);
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // And write the output
@@ -390,7 +398,7 @@ template <typename T, int D>
   // First everybody reads the max and sum_exp
   U max_score = maxs[simd_lid];
   U new_max = simd_max(max_score);
-  U factor = fast::exp(max_score - new_max);
+  U factor = fast::exp2(max_score - new_max);
   U sum_exp_score = simd_sum(sums[simd_lid] * factor);
 
   // Now read the block into registers and then use shared memory to transpose
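
The edits above move the vector kernel's online softmax from the natural-exp domain into the exp2/log2 domain used by the STEEL attention kernels, so the cached logsumexp is directly reusable by the VJP. Below is a standalone sketch of the idea (toy inputs and variable names are mine): pre-scaling scores by log2(e) makes exp2 of the shifted score equal to exp of the original one, so the running max/sum recurrence is unchanged apart from swapping exp for exp2, and the sink and float-mask terms get the same M_LOG2E_F scaling so they land in the same domain.

```cpp
// Standalone illustration of the log2-domain online softmax; not kernel code.
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  const float scale = 0.125f;                               // example softmax scale
  const std::vector<float> qk = {1.3f, -0.7f, 2.1f, 0.4f};  // raw q.k dot products
  const float log2e = 1.4426950408889634f;                  // M_LOG2E_F in Metal

  // Log2-domain accumulation (mirrors the kernel after the M_LOG2E_F change).
  float m = -std::numeric_limits<float>::infinity();  // running max (log2 domain)
  float l = 0.0f;                                     // running sum of exp2 terms
  for (float s : qk) {
    float score = scale * log2e * s;  // query pre-scaled by scale * log2(e)
    float new_m = std::fmax(m, score);
    l = l * std::exp2(m - new_m) + std::exp2(score - new_m);
    m = new_m;
  }

  // Reference: plain exp-domain softmax denominator with the same scale.
  float m_ref = -std::numeric_limits<float>::infinity();
  for (float s : qk) m_ref = std::fmax(m_ref, scale * s);
  float l_ref = 0.0f;
  for (float s : qk) l_ref += std::exp(scale * s - m_ref);

  // The normalized attention weights agree between the two domains.
  std::printf("exp2 domain: %f  exp domain: %f\n",
              std::exp2(scale * log2e * qk[0] - m) / l,
              std::exp(scale * qk[0] - m_ref) / l_ref);
  return 0;
}
```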
