
Commit 43bc7cb

feat(attention): add Flash Attention VJP for vector path (L≤8)
Implement fused backward pass for scaled_dot_product_attention on short sequences (L≤8) using the vector kernel approach. This eliminates the O(N²) memory requirement of unfused attention by recomputing the attention matrix on-the-fly during backpropagation.

Key changes:
- Add sdpa_vector_vjp.h with GPU kernels for computing dQ, dK, dV
- Extend forward pass to output logsumexp (LSE) when needed for VJP
- Add comprehensive Python tests for gradient correctness
- Fix CUDA cuDNN backward to handle masks via set_bias() (removes unnecessary fallback)

Performance (M3 Max, L≤8):
- 1.1-1.4x faster than unfused attention for backward pass
- Memory: O(N) instead of O(N²) for attention matrix

The STEEL VJP for longer sequences (L>8) will be added in a follow-up PR.
1 parent 1650c49 commit 43bc7cb
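For context, the fused VJP described in the commit message follows the standard flash-attention backward recomputation. The equations below are a sketch in my own notation, not lifted from the kernels (which may tile and order the work differently). With softmax scale \(\tau\) and the per-row logsumexp \(L\) saved by the forward pass, the backward pass rebuilds the attention weights on the fly instead of storing the N×N matrix:

\begin{aligned}
S &= \tau\, Q K^{\top}, \qquad P = \exp(S - L) \quad (\text{rowwise, } L = \operatorname{logsumexp}(S))\\
dV &= P^{\top}\, dO, \qquad dP = dO\, V^{\top}\\
dS &= P \odot \bigl(dP - \operatorname{rowsum}(dO \odot O)\bigr)\\
dQ &= \tau\, dS\, K, \qquad dK = \tau\, dS^{\top} Q
\end{aligned}

Only S and P are recomputed per tile from Q, K, and the saved L, which is what brings the memory for the attention matrix down from O(N²) to O(N).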

File tree

12 files changed (+1697, -50 lines)


mlx/backend/cuda/scaled_dot_product_attention.cpp

Lines changed: 6 additions & 2 deletions
@@ -402,7 +402,6 @@ bool ScaledDotProductAttention::use_fallback(
     bool has_mask,
     bool has_arr_mask,
     bool do_causal,
-    bool is_training,
     bool output_logsumexp,
     Stream s) {
   if (s.device == Device::cpu) {
@@ -460,7 +459,12 @@ void ScaledDotProductAttention::eval_gpu(
   }
 }

-bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
+bool ScaledDotProductAttentionVJP::use_fallback(
+    const array& q,
+    Stream s,
+    bool has_mask,
+    bool has_sinks,
+    int /* n_kv_heads */) {
   // The frontend adds a padding mask when sequence length is not a multiple of
   // tile size.
   if (q.shape(2) % 128 != 0) {

mlx/backend/metal/kernels/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ build_kernel(layer_norm)
 build_kernel(random)
 build_kernel(rms_norm)
 build_kernel(rope)
-build_kernel(scaled_dot_product_attention sdpa_vector.h)
+build_kernel(scaled_dot_product_attention sdpa_vector.h sdpa_vector_vjp.h)
 if(MLX_METAL_VERSION GREATER_EQUAL 320)
   build_kernel(fence)
 endif()

mlx/backend/metal/kernels/scaled_dot_product_attention.metal

Lines changed: 39 additions & 0 deletions
@@ -3,6 +3,7 @@
 // clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/sdpa_vector.h"
+#include "mlx/backend/metal/kernels/sdpa_vector_vjp.h"

 using namespace metal;

@@ -41,4 +42,42 @@ using namespace metal;
 instantiate_sdpa_vector_heads(float)
 instantiate_sdpa_vector_heads(bfloat16_t)
 instantiate_sdpa_vector_heads(float16_t)
+
+// SDPA vector VJP instantiations
+#define instantiate_sdpa_vector_vjp(type, qk_dim, value_dim) \
+  instantiate_kernel( \
+      "sdpa_vector_vjp_" #type "_" #qk_dim "_" #value_dim, \
+      sdpa_vector_vjp, \
+      type, \
+      qk_dim, \
+      value_dim)
+
+// Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP kernel
+#define instantiate_sdpa_vector_vjp_heads(type) \
+  instantiate_sdpa_vector_vjp(type, 64, 64) \
+  instantiate_sdpa_vector_vjp(type, 96, 96) \
+  instantiate_sdpa_vector_vjp(type, 128, 128)
+
+instantiate_sdpa_vector_vjp_heads(float)
+instantiate_sdpa_vector_vjp_heads(bfloat16_t)
+instantiate_sdpa_vector_vjp_heads(float16_t)
+
+// SDPA vector VJP accumulate instantiations (for half/bfloat16 with float32 accumulators)
+#define instantiate_sdpa_vector_vjp_accumulate(type, qk_dim, value_dim) \
+  instantiate_kernel( \
+      "sdpa_vector_vjp_accumulate_" #type "_" #qk_dim "_" #value_dim, \
+      sdpa_vector_vjp_accumulate, \
+      type, \
+      qk_dim, \
+      value_dim)
+
+// Note: D=256 exceeds Metal's 32KB threadgroup memory limit for vector VJP kernel
+#define instantiate_sdpa_vector_vjp_accumulate_heads(type) \
+  instantiate_sdpa_vector_vjp_accumulate(type, 64, 64) \
+  instantiate_sdpa_vector_vjp_accumulate(type, 96, 96) \
+  instantiate_sdpa_vector_vjp_accumulate(type, 128, 128)
+
+// Note: Only instantiate for half/bfloat16 since float32 doesn't need accumulate variant
+instantiate_sdpa_vector_vjp_accumulate_heads(bfloat16_t)
+instantiate_sdpa_vector_vjp_accumulate_heads(float16_t)
 // clang-format on
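To make the "gradient correctness" checks mentioned in the commit message concrete, here is a minimal sketch of the kind of test one could write against the public mx.fast.scaled_dot_product_attention API. It is illustrative only: shapes, tolerances, and helper names are mine, not taken from the tests added by this commit.

# Hypothetical gradient check for fused SDPA on a short sequence (L <= 8).
import mlx.core as mx

def sdpa_loss(q, k, v, scale):
    # Fused path; for short sequences this may dispatch to the vector VJP kernels.
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
    return out.sum()

def reference_loss(q, k, v, scale):
    # Unfused reference: materializes the L x L attention matrix.
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    attn = mx.softmax(scores, axis=-1)
    return (attn @ v).sum()

B, n_heads, L, D = 1, 8, 8, 64  # L <= 8 targets the vector path
q = mx.random.normal((B, n_heads, L, D))
k = mx.random.normal((B, n_heads, L, D))
v = mx.random.normal((B, n_heads, L, D))
scale = D ** -0.5

grads_fused = mx.grad(sdpa_loss, argnums=(0, 1, 2))(q, k, v, scale)
grads_ref = mx.grad(reference_loss, argnums=(0, 1, 2))(q, k, v, scale)
for g_fused, g_ref in zip(grads_fused, grads_ref):
    assert mx.allclose(g_fused, g_ref, atol=1e-4, rtol=1e-4)

Whether the fused VJP kernels actually run for a given configuration depends on the use_fallback checks shown earlier (head dims of 64/96/128, masks, sinks, etc.); otherwise the gradient falls back to the unfused path.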

mlx/backend/metal/kernels/sdpa_vector.h

Lines changed: 25 additions & 16 deletions
@@ -80,8 +80,10 @@ template <typename T, int D, int V = D>
   out += o_offset * V + simd_gid * v_per_thread;

   // Read the query and 0 the output accumulator
+  // Scale by M_LOG2E_F to match STEEL attention domain (exp2 instead of exp)
+  const U log2e_scale = static_cast<U>(scale * M_LOG2E_F);
   for (int i = 0; i < qk_per_thread; i++) {
-    q[i] = static_cast<U>(scale) * queries[i];
+    q[i] = log2e_scale * queries[i];
   }
   for (int i = 0; i < v_per_thread; i++) {
     o[i] = 0;
@@ -90,7 +92,9 @@ template <typename T, int D, int V = D>
   U max_score = Limits<U>::finite_min;
   U sum_exp_score = 0;
   if (has_sinks && simd_gid == 0) {
-    max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
+    // Scale sink by M_LOG2E_F to match log2 domain
+    max_score = static_cast<U>(M_LOG2E_F) *
+        static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
     sum_exp_score = 1;
   }

@@ -117,13 +121,14 @@ template <typename T, int D, int V = D>
   }
   score = simd_sum(score);
   if (float_mask) {
-    score += static_cast<U>(fmask[0]);
+    // Scale float mask by M_LOG2E_F to match log2 domain
+    score += static_cast<U>(M_LOG2E_F) * static_cast<U>(fmask[0]);
   }

-  // Update the accumulators
+  // Update the accumulators (using exp2 to match STEEL attention)
   U new_max = max(max_score, score);
-  U factor = fast::exp(max_score - new_max);
-  U exp_score = fast::exp(score - new_max);
+  U factor = fast::exp2(max_score - new_max);
+  U exp_score = fast::exp2(score - new_max);

   max_score = new_max;
   sum_exp_score = sum_exp_score * factor + exp_score;
@@ -155,7 +160,7 @@ template <typename T, int D, int V = D>
   threadgroup_barrier(mem_flags::mem_threadgroup);
   max_score = max_scores[simd_lid];
   U new_max = simd_max(max_score);
-  U factor = fast::exp(max_score - new_max);
+  U factor = fast::exp2(max_score - new_max);
   sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

   // Now we need to aggregate all the outputs
@@ -252,8 +257,10 @@ template <typename T, int D, int V = D>
   maxs += o_offset * blocks + block_idx;

   // Read the query and 0 the output accumulator
+  // Scale by M_LOG2E_F to match STEEL attention domain (exp2 instead of exp)
+  const U log2e_scale = static_cast<U>(scale * M_LOG2E_F);
   for (int i = 0; i < qk_per_thread; i++) {
-    q[i] = static_cast<U>(scale) * queries[i];
+    q[i] = log2e_scale * queries[i];
   }
   for (int i = 0; i < v_per_thread; i++) {
     o[i] = 0;
@@ -263,7 +270,8 @@ template <typename T, int D, int V = D>
   U sum_exp_score = 0;
   if (has_sinks && block_idx == 0 && simd_gid == 0) {
     int q_head_idx = q_batch_head_idx % num_q_heads;
-    max_score = static_cast<U>(sinks[q_head_idx]);
+    // Scale sink by M_LOG2E_F to match log2 domain
+    max_score = static_cast<U>(M_LOG2E_F) * static_cast<U>(sinks[q_head_idx]);
     sum_exp_score = 1;
   }

@@ -291,13 +299,14 @@ template <typename T, int D, int V = D>
   score = simd_sum(score);

   if (float_mask) {
-    score += fmask[0];
+    // Scale float mask by M_LOG2E_F to match log2 domain
+    score += static_cast<U>(M_LOG2E_F) * static_cast<U>(fmask[0]);
   }

-  // Update the accumulators
+  // Update the accumulators (using exp2 to match STEEL attention)
   U new_max = max(max_score, score);
-  U factor = fast::exp(max_score - new_max);
-  U exp_score = fast::exp(score - new_max);
+  U factor = fast::exp2(max_score - new_max);
+  U exp_score = fast::exp2(score - new_max);

   max_score = new_max;
   sum_exp_score = sum_exp_score * factor + exp_score;
@@ -329,7 +338,7 @@ template <typename T, int D, int V = D>
   threadgroup_barrier(mem_flags::mem_threadgroup);
   max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
   U new_max = simd_max(max_score);
-  U factor = fast::exp(max_score - new_max);
+  U factor = fast::exp2(max_score - new_max);
   sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
   sum_exp_score = simd_sum(sum_exp_score * factor);

@@ -342,7 +351,7 @@ template <typename T, int D, int V = D>
   // Now we need to aggregate all the outputs
   for (int i = 0; i < v_per_thread; i++) {
     outputs[simd_lid * BN + simd_gid] =
-        o[i] * fast::exp(max_scores[simd_gid] - new_max);
+        o[i] * fast::exp2(max_scores[simd_gid] - new_max);
     threadgroup_barrier(mem_flags::mem_threadgroup);

     // And write the output
@@ -390,7 +399,7 @@ template <typename T, int D>
   // First everybody reads the max and sum_exp
   U max_score = maxs[simd_lid];
   U new_max = simd_max(max_score);
-  U factor = fast::exp(max_score - new_max);
+  U factor = fast::exp2(max_score - new_max);
   U sum_exp_score = simd_sum(sums[simd_lid] * factor);

   // Now read the block into registers and then use shared memory to transpose
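A note on the exp → exp2 rewrite in this file: it is a pure change of base and does not alter the softmax the kernel computes. The identity below is my own sketch of the reasoning behind the M_LOG2E_F factors. With scores pre-scaled by \(\log_2 e\),

e^{s_j} = 2^{s_j \log_2 e} = 2^{\tilde s_j}, \qquad \tilde s = (\log_2 e)\, s

\frac{e^{s_j - m}}{\sum_k e^{s_k - m}} = \frac{2^{\tilde s_j - \tilde m}}{\sum_k 2^{\tilde s_k - \tilde m}}

so the running max/sum accumulation is identical in the base-2 domain. This is also why every additive contribution to the scores (float masks and sinks) is multiplied by M_LOG2E_F in the hunks above: it must live in the same base-2 domain as the scaled query. The base-2 logsumexp is \(\log_2 e\) times the natural one, matching the exp2 domain the STEEL attention kernels already use, per the comments in the diff.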
