Commit d6977f2

Add sdpa with sinks (#2558)
* add sdpa with sinks
* fix 2 pass
* fix matrix sdpa
* fix perf regression
* add to cuda (#2580)
1 parent db5443e commit d6977f2

9 files changed: +351 additions, -116 deletions


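This commit threads optional per-head "sink" logits through the vector SDPA kernels on both the CUDA and Metal backends. A sink acts as one extra softmax logit per query head: it absorbs probability mass in the normalization but contributes no value vector. As a point of reference, here is a minimal host-side C++ sketch of what the kernels compute for a single query (illustrative only; the names and layouts are not taken from this commit):

#include <algorithm>
#include <cmath>
#include <vector>

// Reference SDPA for one query vector with an optional sink logit.
// q: [D], k/v: [N][D], returns o: [D].
std::vector<float> sdpa_with_sink(
    const std::vector<float>& q,
    const std::vector<std::vector<float>>& k,
    const std::vector<std::vector<float>>& v,
    float scale,
    float sink) {
  const int N = static_cast<int>(k.size());
  const int D = static_cast<int>(q.size());
  std::vector<float> scores(N);
  float max_score = sink; // the sink participates in the max/normalization
  for (int i = 0; i < N; ++i) {
    float s = 0.f;
    for (int d = 0; d < D; ++d) s += q[d] * k[i][d];
    scores[i] = s * scale;
    max_score = std::max(max_score, scores[i]);
  }
  float denom = std::exp(sink - max_score); // sink adds mass but no value
  std::vector<float> o(D, 0.f);
  for (int i = 0; i < N; ++i) {
    float w = std::exp(scores[i] - max_score);
    denom += w;
    for (int d = 0; d < D; ++d) o[d] += w * v[i][d];
  }
  for (int d = 0; d < D; ++d) o[d] /= denom;
  return o;
}

The kernel changes below implement the same computation in streaming form, so the sink only needs to seed the running maximum and the running denominator.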
mlx/backend/cuda/scaled_dot_product_attention.cu

Lines changed: 56 additions & 22 deletions
@@ -46,6 +46,7 @@ __global__ void kernel_sdpav_1pass(
     const T* K,
     const T* V,
     T* O,
+    const T* sinks,
     __grid_constant__ const AttnParams params) {
   constexpr int BN = 32;
   constexpr int BD = 32;
@@ -65,7 +66,7 @@ __global__ void kernel_sdpav_1pass(
   __shared__ U max_scores[BN];
   __shared__ U sum_exp_scores[BN];
 
-  const U scale_log2 = params.scale * 1.44269504089f;
+  const U scale_log2 = params.scale * M_LOG2E;
 
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<32>(block);
@@ -110,6 +111,10 @@ __global__ void kernel_sdpav_1pass(
 
   U max_score = -INFINITY;
   U sum_exp_score = 0.f;
+  if (sinks && warp_idx == 0) {
+    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
+    sum_exp_score = 1.f;
+  }
 
   // For each key
   for (int i = kv_seq_idx; i < params.kL; i += BN) {
@@ -137,8 +142,9 @@ __global__ void kernel_sdpav_1pass(
 
     // Update the accumulators
     U new_max = max(max_score, score);
-    U factor = exp2f(max_score - new_max);
-    U exp_score = exp2f(score - new_max);
+    bool is_neg_inf = new_max == -INFINITY;
+    U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
+    U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
 
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
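Two details in the hunks above: the accumulation runs in base 2 (exp2f), so both the scale factor and the sink logit are pre-multiplied by log2(e) (M_LOG2E), and the new is_neg_inf guard keeps a row whose running maximum is still -INFINITY (every key masked so far) from producing exp2f(-inf - (-inf)) = NaN. A rough host-side C++ sketch of the same online update, seeded with a sink (the constant and names are illustrative, not taken from the kernel):

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Online (streaming) softmax denominator in base 2, seeded with a sink logit.
// `logits` are assumed to already be scaled by scale * log2(e).
float online_denominator(const std::vector<float>& logits, float sink) {
  const float kLog2e = 1.4426950408889634f;  // same value as M_LOG2E
  float max_score = kLog2e * sink;           // sink seeds the running max
  float sum_exp = 1.f;                       // exp2(sink - sink) == 1
  for (float s : logits) {
    float new_max = std::max(max_score, s);
    bool is_neg_inf = new_max == -std::numeric_limits<float>::infinity();
    float factor = is_neg_inf ? 1.f : std::exp2(max_score - new_max);
    float exp_s = is_neg_inf ? 0.f : std::exp2(s - new_max);
    max_score = new_max;
    sum_exp = sum_exp * factor + exp_s;      // running denominator
  }
  return sum_exp;
}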
@@ -193,6 +199,7 @@ __global__ void kernel_sdpav_2pass_1(
     const T* Q,
     const T* K,
     const T* V,
+    const T* sinks,
     float* partials,
     float* sums,
     float* maxs,
@@ -268,8 +275,12 @@ __global__ void kernel_sdpav_2pass_1(
     o[i] = 0.f;
   }
 
-  U max_score = -1e9;
+  U max_score = -INFINITY;
   U sum_exp_score = 0.f;
+  if (sinks && warp_idx == 0 && block_idx == 0) {
+    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
+    sum_exp_score = 1.f;
+  }
 
   // For each key
   for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
@@ -297,8 +308,9 @@ __global__ void kernel_sdpav_2pass_1(
 
     // Update the accumulators
     U new_max = max(max_score, score);
-    U factor = exp2f(max_score - new_max);
-    U exp_score = exp2f(score - new_max);
+    bool is_neg_inf = new_max == -INFINITY;
+    U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
+    U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
 
     max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;
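In the two-pass variant only the first block (block_idx == 0, warp 0) folds the sink into its accumulator, so its mass is counted exactly once across blocks. Each block emits a partial output along with its running max and sum, and the second-pass kernel, which this commit does not touch and the diff does not show, combines them. A hedged C++ sketch of that standard log-sum-exp merge, assuming each block produced a (max, sum, partial) triple:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

struct BlockPartial {
  float max_score;             // running max from one block (base-2 logits)
  float sum_exp;               // running softmax denominator from that block
  std::vector<float> partial;  // un-normalized value-weighted sum from that block
};

// Merge per-block partials into one normalized output, rescaling each block
// to the global maximum before combining.
std::vector<float> merge_blocks(const std::vector<BlockPartial>& blocks) {
  float global_max = -std::numeric_limits<float>::infinity();
  for (const auto& b : blocks) {
    global_max = std::max(global_max, b.max_score);
  }
  std::vector<float> out(blocks.front().partial.size(), 0.f);
  float denom = 0.f;
  for (const auto& b : blocks) {
    float w = std::exp2(b.max_score - global_max);  // rescale factor
    denom += w * b.sum_exp;
    for (std::size_t i = 0; i < out.size(); ++i) {
      out[i] += w * b.partial[i];
    }
  }
  for (auto& v : out) {
    v /= denom;
  }
  return out;
}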
@@ -463,10 +475,14 @@ void sdpa_vector_1pass_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   encoder.set_input_array(q);
   encoder.set_input_array(k);
   encoder.set_input_array(v);
+  if (sinks) {
+    encoder.set_input_array(*sinks);
+  }
   encoder.set_output_array(o);
 
   cu::AttnParams params{
@@ -489,7 +505,7 @@ void sdpa_vector_1pass_fallback(
   dim3 block_dim(1024, 1, 1);
 
   dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
-    dispatch_bool(do_causal_, [&](auto do_causal) {
+    dispatch_bool(do_causal, [&](auto do_causal) {
      dispatch_headdim(params.D, [&](auto headdim) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 
@@ -504,6 +520,7 @@ void sdpa_vector_1pass_fallback(
            k.data<DataType>(),
            v.data<DataType>(),
            o.data<DataType>(),
+           sinks ? (*sinks).data<DataType>() : nullptr,
            params);
      });
    });
@@ -518,7 +535,8 @@ void sdpa_vector_2pass_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   cu::AttnParams params{
       /* int B = */ q.shape(0),
       /* int H = */ q.shape(1),
@@ -559,7 +577,7 @@ void sdpa_vector_2pass_fallback(
   encoder.add_temporary(maxs);
 
   dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
-    dispatch_bool(do_causal_, [&](auto do_causal) {
+    dispatch_bool(do_causal, [&](auto do_causal) {
      dispatch_headdim(params.D, [&](auto headdim) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 
@@ -570,6 +588,10 @@ void sdpa_vector_2pass_fallback(
        encoder.set_input_array(q);
        encoder.set_input_array(k);
        encoder.set_input_array(v);
+       if (sinks) {
+         encoder.set_input_array(*sinks);
+       }
+
        encoder.set_output_array(intermediate);
        encoder.set_output_array(sums);
        encoder.set_output_array(maxs);
@@ -585,6 +607,7 @@ void sdpa_vector_2pass_fallback(
            q.data<DataType>(),
            k.data<DataType>(),
            v.data<DataType>(),
+           sinks ? (*sinks).data<DataType>() : nullptr,
            intermediate.data<float>(),
            sums.data<float>(),
            maxs.data<float>(),
@@ -627,15 +650,16 @@ void sdpa_vector_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   int kL = k.shape(2);
 
   if (kL > 1024) {
     return sdpa_vector_2pass_fallback(
-        s, encoder, q, k, v, scale, o, do_causal_);
+        s, encoder, q, k, v, scale, o, do_causal, sinks);
   } else {
     return sdpa_vector_1pass_fallback(
-        s, encoder, q, k, v, scale, o, do_causal_);
+        s, encoder, q, k, v, scale, o, do_causal, sinks);
   }
 }
 
@@ -691,7 +715,7 @@ void ScaledDotProductAttention::eval_gpu(
 
   // Define some copy functions to ensure the layout of the inputs is as
   // expected.
-  copies.reserve(3);
+  copies.reserve(inputs.size());
   auto copy_unless = [&copies, &s](
                          auto predicate, const array& arr) -> const array& {
     if (!predicate(arr)) {
@@ -703,6 +727,16 @@ void ScaledDotProductAttention::eval_gpu(
     }
   };
 
+  // Checks that the headdim dimension has stride 1.
+  auto is_matrix_contiguous = [](const array& arr) {
+    return arr.strides(-1) == 1;
+  };
+
+  std::optional<array> sinks = std::nullopt;
+  if (has_sinks_) {
+    sinks = copy_unless(is_matrix_contiguous, inputs.back());
+  }
+
   // We are in vector mode ie single query
   if (q_pre.shape(2) < 4) {
     auto q_copy_unless = [](const array& arr) {
@@ -740,10 +774,6 @@ void ScaledDotProductAttention::eval_gpu(
     const auto& k = copy_unless(kv_copy_unless, k_pre);
     const auto& v = copy_unless(kv_copy_unless, v_pre);
 
-    for (const auto& cp : copies) {
-      encoder.add_temporary(cp);
-    }
-
     // Donate the query if possible
     if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
       o.copy_shared_buffer(q);
@@ -752,22 +782,26 @@ void ScaledDotProductAttention::eval_gpu(
       int64_t str_oH = o.shape(3);
      int64_t str_oL = o.shape(1) * str_oH;
      int64_t str_oB = o.shape(2) * str_oL;
-      size_t data_size = o.shape(0) * str_oB;
 
      array::Flags flags{
          /* bool contiguous = */ 1,
          /* bool row_contiguous = */ o.shape(2) == 1,
-          /* bool col_contiguous = */ 0,
+          /* bool col_contiguous = */ o.size() == o.shape(3),
      };
 
      o.set_data(
          allocator::malloc(o.nbytes()),
-          data_size,
+          o.size(),
          {str_oB, str_oH, str_oL, str_oD},
          flags);
    }
 
-    return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
+    for (const auto& cp : copies) {
+      encoder.add_temporary(cp);
+    }
+
+    return sdpa_vector_fallback(
+        s, encoder, q, k, v, scale_, o, do_causal_, sinks);
   }
 
  // Full attention mode should never reach here

mlx/backend/metal/kernels/sdpa_vector.h

Lines changed: 43 additions & 18 deletions
@@ -9,6 +9,7 @@ constant bool query_transposed [[function_constant(21)]];
 constant bool do_causal [[function_constant(22)]];
 constant bool bool_mask [[function_constant(23)]];
 constant bool float_mask [[function_constant(24)]];
+constant bool has_sinks [[function_constant(25)]];
 
 template <typename T, int D, int V = D>
 [[kernel]] void sdpa_vector(
@@ -31,6 +32,9 @@ template <typename T, int D, int V = D>
         [[buffer(14), function_constant(has_mask)]],
     const constant int& mask_head_stride
         [[buffer(15), function_constant(has_mask)]],
+    const device T* sinks [[buffer(16), function_constant(has_sinks)]],
+    const constant int& num_q_heads
+        [[buffer(17), function_constant(has_sinks)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -53,24 +57,24 @@ template <typename T, int D, int V = D>
   threadgroup U sum_exp_scores[BN];
 
   // Adjust positions
-  const int head_idx = tid.x;
+  const int q_batch_head_idx = tid.x;
   const int q_seq_idx = tid.y;
-  const int kv_head_idx = head_idx / gqa_factor;
-  const int o_offset = head_idx * tpg.y + q_seq_idx;
+  const int kv_head_idx = q_batch_head_idx / gqa_factor;
+  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
   const int q_offset =
-      query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
+      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
   queries += q_offset * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
       simd_lid * qk_per_thread;
   values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
       simd_lid * v_per_thread;
   if (bool_mask) {
-    bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
-        q_seq_idx * mask_q_seq_stride;
+    bmask += q_batch_head_idx * mask_head_stride +
+        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
   }
   if (float_mask) {
-    fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
-        q_seq_idx * mask_q_seq_stride;
+    fmask += q_batch_head_idx * mask_head_stride +
+        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
   }
 
   out += o_offset * V + simd_gid * v_per_thread;
@@ -85,6 +89,10 @@ template <typename T, int D, int V = D>
 
   U max_score = -INFINITY;
   U sum_exp_score = 0;
+  if (has_sinks && simd_gid == 0) {
+    max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
+    sum_exp_score = 1;
+  }
 
   // For each key
   for (int i = simd_gid; i < N; i += BN) {
@@ -93,6 +101,8 @@ template <typename T, int D, int V = D>
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
     } else if (bool_mask) {
      use_key = bmask[0];
+    } else if (float_mask) {
+      use_key = (fmask[0] >= Limits<T>::finite_min);
    }
    if (use_key) {
      // Read the key
@@ -107,13 +117,14 @@ template <typename T, int D, int V = D>
      }
      score = simd_sum(score);
      if (float_mask) {
-        score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
+        score += static_cast<U>(fmask[0]);
      }
 
      // Update the accumulators
      U new_max = max(max_score, score);
-      U factor = fast::exp(max_score - new_max);
-      U exp_score = fast::exp(score - new_max);
+      bool is_neg_inf = new_max == -INFINITY;
+      U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max);
+      U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max);
 
      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;
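Note a small asymmetry with the CUDA kernels above: the Metal kernels accumulate with fast::exp, so the sink logit is used as-is, while the CUDA kernels work in base 2 and pre-scale it by M_LOG2E. The two forms give the same probabilities, since exp(x) = exp2(x * log2(e)); a tiny C++ check of that identity (illustrative only):

#include <cassert>
#include <cmath>

int main() {
  const float kLog2e = 1.4426950408889634f;  // same value as M_LOG2E
  float sink = 0.37f, score = -1.2f;
  // Natural-exponent form, as in the Metal kernels (fast::exp).
  float p_nat = std::exp(sink) / (std::exp(sink) + std::exp(score));
  // Base-2 form, as in the CUDA kernels (exp2f of log2(e)-scaled logits).
  float p_b2 = std::exp2(kLog2e * sink) /
      (std::exp2(kLog2e * sink) + std::exp2(kLog2e * score));
  assert(std::fabs(p_nat - p_b2) < 1e-6f);
  return 0;
}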
@@ -187,6 +198,9 @@ template <typename T, int D, int V = D>
         [[buffer(16), function_constant(has_mask)]],
     const constant int& mask_head_stride
         [[buffer(17), function_constant(has_mask)]],
+    const device T* sinks [[buffer(18), function_constant(has_sinks)]],
+    const constant int& num_q_heads
+        [[buffer(19), function_constant(has_sinks)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -211,12 +225,12 @@ template <typename T, int D, int V = D>
 
   // Adjust positions
   const int block_idx = tid.z;
-  const int head_idx = tid.x;
+  const int q_batch_head_idx = tid.x;
   const int q_seq_idx = tid.y;
-  const int o_offset = head_idx * tpg.y + q_seq_idx;
+  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
   const int q_offset =
-      query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
-  const int kv_head_idx = head_idx / gqa_factor;
+      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
+  const int kv_head_idx = q_batch_head_idx / gqa_factor;
 
   queries += q_offset * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_head_stride +
@@ -225,12 +239,12 @@ template <typename T, int D, int V = D>
       (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
   out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
   if (bool_mask) {
-    bmask += head_idx * mask_head_stride +
+    bmask += q_batch_head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
-    fmask += head_idx * mask_head_stride +
+    fmask += q_batch_head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
@@ -245,8 +259,13 @@ template <typename T, int D, int V = D>
     o[i] = 0;
   }
 
-  U max_score = -1e9;
+  U max_score = -INFINITY;
   U sum_exp_score = 0;
+  if (has_sinks && block_idx == 0 && simd_gid == 0) {
+    int q_head_idx = q_batch_head_idx % num_q_heads;
+    max_score = static_cast<U>(sinks[q_head_idx]);
+    sum_exp_score = 1;
+  }
 
   // For each key
   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
@@ -255,6 +274,8 @@ template <typename T, int D, int V = D>
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
     } else if (bool_mask) {
      use_key = bmask[0];
+    } else if (float_mask) {
+      use_key = (fmask[0] >= Limits<T>::finite_min);
    }
    if (use_key) {
      // Read the key
@@ -268,6 +289,10 @@ template <typename T, int D, int V = D>
         score += q[i] * k[i];
       }
       score = simd_sum(score);
+      if (score < Limits<T>::finite_min) {
+        continue;
+      }
+
      if (float_mask) {
        score += fmask[0];
      }
0 commit comments

Comments
 (0)