@@ -46,6 +46,7 @@ __global__ void kernel_sdpav_1pass(
     const T* K,
     const T* V,
     T* O,
+    const T* sinks,
     __grid_constant__ const AttnParams params) {
   constexpr int BN = 32;
   constexpr int BD = 32;
@@ -65,7 +66,7 @@ __global__ void kernel_sdpav_1pass(
   __shared__ U max_scores[BN];
   __shared__ U sum_exp_scores[BN];
 
-  const U scale_log2 = params.scale * 1.44269504089f;
+  const U scale_log2 = params.scale * M_LOG2E;
 
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<32>(block);
@@ -110,6 +111,10 @@ __global__ void kernel_sdpav_1pass(
 
   U max_score = -INFINITY;
   U sum_exp_score = 0.f;
+  if (sinks && warp_idx == 0) {
+    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
+    sum_exp_score = 1.f;
+  }
 
   // For each key
   for (int i = kv_seq_idx; i < params.kL; i += BN) {
@@ -137,8 +142,9 @@ __global__ void kernel_sdpav_1pass(
 
     // Update the accumulators
     U new_max = max(max_score, score);
-    U factor = exp2f(max_score - new_max);
-    U exp_score = exp2f(score - new_max);
+    bool is_neg_inf = new_max == -INFINITY;
+    U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
+    U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
 
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
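Note on the sink hunks above: an attention sink is one extra logit folded into the streaming softmax's initial state. Seeding max_score with log2(e) * sink and sum_exp_score with 1 is exactly what processing the sink as a "first key" would produce, since exp2(log2(e) * sink - max) == 1 when the running max is the sink score itself; the warp_idx == 0 condition makes sure the sink is counted exactly once when the per-warp partials are reduced. A minimal host-side sketch of the same recurrence (a hypothetical helper, not the kernel; scores are assumed pre-multiplied by scale * log2(e) as in the kernel):

#include <cmath>
#include <vector>

float softmax_denominator(const std::vector<float>& scores, const float* sink) {
  constexpr float log2e = 1.44269504089f;
  // With a sink, start as if one logit had already been processed.
  float max_score = sink ? log2e * (*sink) : -INFINITY;
  float sum_exp_score = sink ? 1.f : 0.f;
  for (float s : scores) {
    float new_max = fmaxf(max_score, s);
    // exp2f(-inf - -inf) is NaN; guard while everything so far is masked.
    bool is_neg_inf = new_max == -INFINITY;
    float factor = is_neg_inf ? 1.f : exp2f(max_score - new_max);
    float exp_score = is_neg_inf ? 0.f : exp2f(s - new_max);
    max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;
  }
  return sum_exp_score; // true denominator == exp2(max_score) * sum_exp_score
}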
@@ -193,6 +199,7 @@ __global__ void kernel_sdpav_2pass_1(
     const T* Q,
     const T* K,
     const T* V,
+    const T* sinks,
     float* partials,
     float* sums,
     float* maxs,
@@ -268,8 +275,12 @@ __global__ void kernel_sdpav_2pass_1(
     o[i] = 0.f;
   }
 
-  U max_score = -1e9;
+  U max_score = -INFINITY;
   U sum_exp_score = 0.f;
+  if (sinks && warp_idx == 0 && block_idx == 0) {
+    max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
+    sum_exp_score = 1.f;
+  }
 
   // For each key
   for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
@@ -297,8 +308,9 @@ __global__ void kernel_sdpav_2pass_1(
 
     // Update the accumulators
     U new_max = max(max_score, score);
-    U factor = exp2f(max_score - new_max);
-    U exp_score = exp2f(score - new_max);
+    bool is_neg_inf = new_max == -INFINITY;
+    U factor = is_neg_inf ? 1 : exp2f(max_score - new_max);
+    U exp_score = is_neg_inf ? 0 : exp2f(score - new_max);
 
     max_score = new_max;
     sum_exp_score = sum_exp_score * factor + exp_score;
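The -1e9 to -INFINITY switch and the new guard work together: a partial block whose keys are all masked now reports max == -INFINITY and sum == 0 instead of a spurious finite maximum, and the guard keeps exp2f(-INFINITY - -INFINITY), which is NaN, out of the accumulators (block_idx == 0 likewise ensures the sink enters exactly one partial). The second pass can then merge a fully masked block as a no-op. An illustrative base-2 log-sum-exp merge under those conventions (a sketch, not the literal pass-2 kernel):

#include <cmath>

// Merge block b's softmax partial (max, sum) into block a's, base-2 domain.
void merge_partials(float& max_a, float& sum_a, float max_b, float sum_b) {
  float m = fmaxf(max_a, max_b);
  if (m == -INFINITY) { // both blocks fully masked: keep the zero-weight state
    sum_a = 0.f;
    return;
  }
  sum_a = sum_a * exp2f(max_a - m) + sum_b * exp2f(max_b - m);
  max_a = m;
}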
@@ -463,10 +475,14 @@ void sdpa_vector_1pass_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   encoder.set_input_array(q);
   encoder.set_input_array(k);
   encoder.set_input_array(v);
+  if (sinks) {
+    encoder.set_input_array(*sinks);
+  }
   encoder.set_output_array(o);
 
   cu::AttnParams params{
@@ -489,7 +505,7 @@ void sdpa_vector_1pass_fallback(
   dim3 block_dim(1024, 1, 1);
 
   dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
-    dispatch_bool(do_causal_, [&](auto do_causal) {
+    dispatch_bool(do_causal, [&](auto do_causal) {
       dispatch_headdim(params.D, [&](auto headdim) {
         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 
@@ -504,6 +520,7 @@ void sdpa_vector_1pass_fallback(
             k.data<DataType>(),
             v.data<DataType>(),
             o.data<DataType>(),
+            sinks ? (*sinks).data<DataType>() : nullptr,
             params);
       });
     });
@@ -518,7 +535,8 @@ void sdpa_vector_2pass_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   cu::AttnParams params{
       /* int B = */ q.shape(0),
       /* int H = */ q.shape(1),
@@ -559,7 +577,7 @@ void sdpa_vector_2pass_fallback(
   encoder.add_temporary(maxs);
 
   dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
-    dispatch_bool(do_causal_, [&](auto do_causal) {
+    dispatch_bool(do_causal, [&](auto do_causal) {
       dispatch_headdim(params.D, [&](auto headdim) {
         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 
@@ -570,6 +588,10 @@ void sdpa_vector_2pass_fallback(
         encoder.set_input_array(q);
         encoder.set_input_array(k);
         encoder.set_input_array(v);
+        if (sinks) {
+          encoder.set_input_array(*sinks);
+        }
+
         encoder.set_output_array(intermediate);
         encoder.set_output_array(sums);
         encoder.set_output_array(maxs);
@@ -585,6 +607,7 @@ void sdpa_vector_2pass_fallback(
             q.data<DataType>(),
             k.data<DataType>(),
             v.data<DataType>(),
+            sinks ? (*sinks).data<DataType>() : nullptr,
             intermediate.data<float>(),
             sums.data<float>(),
             maxs.data<float>(),
@@ -627,15 +650,16 @@ void sdpa_vector_fallback(
     const array& v,
     const float scale,
     array& o,
-    bool do_causal_ = false) {
+    bool do_causal,
+    const std::optional<array>& sinks) {
   int kL = k.shape(2);
 
   if (kL > 1024) {
     return sdpa_vector_2pass_fallback(
-        s, encoder, q, k, v, scale, o, do_causal_);
+        s, encoder, q, k, v, scale, o, do_causal, sinks);
   } else {
     return sdpa_vector_1pass_fallback(
-        s, encoder, q, k, v, scale, o, do_causal_);
+        s, encoder, q, k, v, scale, o, do_causal, sinks);
   }
 }
 
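Passing sinks as std::optional<array> keeps a single signature for both launch paths: callers without sinks hand in std::nullopt, the kernels receive a null pointer, and the if (sinks && ...) seeding is skipped. Note also that the defaulted do_causal_ = false parameters are gone, so every call site now spells out both trailing arguments. A hypothetical call site (names assumed, not taken from the source):

// Without sinks (pre-existing callers):
sdpa_vector_fallback(s, encoder, q, k, v, scale, o, /*do_causal=*/false, std::nullopt);
// With sinks:
sdpa_vector_fallback(s, encoder, q, k, v, scale, o, /*do_causal=*/false, sinks);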
@@ -691,7 +715,7 @@ void ScaledDotProductAttention::eval_gpu(
 
   // Define some copy functions to ensure the layout of the inputs is as
   // expected.
-  copies.reserve(3);
+  copies.reserve(inputs.size());
   auto copy_unless = [&copies, &s](
                          auto predicate, const array& arr) -> const array& {
     if (!predicate(arr)) {
@@ -703,6 +727,16 @@ void ScaledDotProductAttention::eval_gpu(
     }
   };
 
+  // Checks that the headdim dimension has stride 1.
+  auto is_matrix_contiguous = [](const array& arr) {
+    return arr.strides(-1) == 1;
+  };
+
+  std::optional<array> sinks = std::nullopt;
+  if (has_sinks_) {
+    sinks = copy_unless(is_matrix_contiguous, inputs.back());
+  }
+
   // We are in vector mode ie single query
   if (q_pre.shape(2) < 4) {
     auto q_copy_unless = [](const array& arr) {
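The sink logits arrive as the trailing entry of inputs, and the only layout requirement is a stride-1 innermost axis, since the kernels read sinks[head_idx] directly; copy_unless materializes a contiguous copy, tracked in copies, only when that predicate fails. The reserve bump from 3 to inputs.size() is load-bearing: copy_unless hands out references into copies, so the vector must never reallocate once a reference has been returned. A self-contained sketch of the pattern with a stand-in Array type (hypothetical; the real code uses MLX arrays and registers the copies as encoder temporaries):

#include <functional>
#include <vector>

struct Array { bool innermost_stride_is_1; };

Array contiguous_copy(const Array&) { return Array{true}; }

const Array& copy_unless(
    std::vector<Array>& copies, // must be reserved up front: see below
    const std::function<bool(const Array&)>& predicate,
    const Array& arr) {
  if (predicate(arr)) {
    return arr; // layout already acceptable, no copy needed
  }
  copies.push_back(contiguous_copy(arr)); // kept alive as a temporary
  return copies.back(); // reference stays valid only if no reallocation
}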
@@ -740,10 +774,6 @@ void ScaledDotProductAttention::eval_gpu(
     const auto& k = copy_unless(kv_copy_unless, k_pre);
     const auto& v = copy_unless(kv_copy_unless, v_pre);
 
-    for (const auto& cp : copies) {
-      encoder.add_temporary(cp);
-    }
-
     // Donate the query if possible
     if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
       o.copy_shared_buffer(q);
@@ -752,22 +782,26 @@ void ScaledDotProductAttention::eval_gpu(
       int64_t str_oH = o.shape(3);
       int64_t str_oL = o.shape(1) * str_oH;
       int64_t str_oB = o.shape(2) * str_oL;
-      size_t data_size = o.shape(0) * str_oB;
 
       array::Flags flags{
          /* bool contiguous = */ 1,
          /* bool row_contiguous = */ o.shape(2) == 1,
-         /* bool col_contiguous = */ 0,
+         /* bool col_contiguous = */ o.size() == o.shape(3),
      };
 
       o.set_data(
           allocator::malloc(o.nbytes()),
-          data_size,
+          o.size(),
          {str_oB, str_oH, str_oL, str_oD},
          flags);
    }
 
-    return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
+    for (const auto& cp : copies) {
+      encoder.add_temporary(cp);
+    }
+
+    return sdpa_vector_fallback(
+        s, encoder, q, k, v, scale_, o, do_causal_, sinks);
   }
 
   // Full attention mode should never reach here
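A note on the allocation changes above: the vector-mode output has logical shape (B, H, L, D) but is laid out sequence-major, so o.shape(0) * str_oB and o.size() count the same B * H * L * D elements and the old data_size local was redundant. The new col_contiguous expression is true only for a 1 x 1 x 1 x D output, which is trivially contiguous in every order. A small check of the stride arithmetic under those assumptions (hypothetical values):

#include <cassert>
#include <cstdint>

int main() {
  // Logical shape (B, H, L, D); physical layout (B, L, H, D).
  const int64_t B = 2, H = 8, L = 1, D = 64;
  const int64_t str_oD = 1;
  const int64_t str_oH = D;          // next head is one headdim away
  const int64_t str_oL = H * str_oH; // one sequence step spans all heads
  const int64_t str_oB = L * str_oL; // one batch spans the whole sequence
  assert(B * str_oB == B * H * L * D); // old data_size == o.size()
  (void)str_oD;
  return 0;
}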