// Licensed under the MIT License.

#include <vector>
+#include <algorithm>
#include "core/providers/cuda/cuda_common.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
@@ -39,8 +40,17 @@ REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE";
-constexpr const char* kDisableFusedKv = "ORT_DISABLE_FUSED_KV";

+// Group Query Attention (GQA) Operator
+//
+// This operator implements Group Query Attention, a variation of Multi-Head Attention (MHA)
+// where the number of key/value heads is smaller than the number of query heads.
+// It supports:
+// - Rotary Positional Embeddings (RoPE)
+// - KV Cache (past/present key/value)
+// - Quantized KV Cache (Int8/Int4) via GroupQueryAttentionData
+// - Flash Attention and Memory Efficient Attention backends
+//
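+// Illustrative sketch (hypothetical values, not taken from this change): with num_heads = 8 query heads
+// and kv_num_heads = 2 key/value heads, each group of 4 query heads attends to one shared KV head:
+//   int kv_head_index = query_head_index / (num_heads / kv_num_heads);  // 8 / 2 = 4 query heads per group
+//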
template <typename T>
GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
    : CudaKernel(info) {
@@ -63,7 +73,7 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)

  disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();

-  // Memory efficient attention supports float and float16. BFloat16 support is added for SM80+ via cutlass kernels.
+  // Memory efficient attention supports float and float16. BFloat16 support added for SM80+.
  disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();

  if (!disable_flash_attention_) {
@@ -72,9 +82,23 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
  }

  disable_flash_decode_ = ParseEnvironmentVariableWithDefault<bool>(kDisableFlashDecode, false);
-  disable_fused_kv_ = ParseEnvironmentVariableWithDefault<bool>(kDisableFusedKv, false);
}

+// ComputeInternal executes the GQA kernel.
+//
+// Inputs:
+// 0. query (Tensor) [batch, sequence_length, hidden_size]
+// 1. key (Tensor) [batch, sequence_length, kv_hidden_size] (Optional)
+// 2. value (Tensor) [batch, sequence_length, kv_hidden_size] (Optional)
+// 3. past_key (Tensor) [batch, num_kv_heads, max_seq_len, head_size] (Optional)
+// 4. past_value (Tensor) [batch, num_kv_heads, max_seq_len, head_size] (Optional)
+// 5. seqlens_k (Tensor) [batch] - Total sequence length minus 1 (for historical compatibility)
+// 6. total_seqlen (Tensor) - Max total sequence length
+// 7. cos_cache (Tensor) - Precomputed cosine table for RoPE
+// 8. sin_cache (Tensor) - Precomputed sine table for RoPE
+// 9. position_ids (Tensor) - Position indices for RoPE
+// 10. attention_bias (Tensor) - Not supported in this kernel
+// 11. head_sink (Tensor) - Attention sink for GPT-OSS
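+//
+// Note on the seqlens_k convention above (input 5), using illustrative values: if a batch entry has
+// 9 cached tokens plus 1 new token, the total length is 10 and seqlens_k holds 9 (total - 1). When
+// sequence_length == 1 this value equals past_seq_len, which is what the flash decode fast path relies on.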
template <typename T>
Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* query = context->Input<Tensor>(0);
@@ -162,7 +186,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
  IAllocatorUniquePtr<void> k_buffer;
  IAllocatorUniquePtr<void> v_buffer;
  IAllocatorUniquePtr<void> rotary_buffer;
-  IAllocatorUniquePtr<void> position_ids_buffer;
  IAllocatorUniquePtr<void> fmha_buffer;
  IAllocatorUniquePtr<void> unpacked_qkv_buffer;
  IAllocatorUniquePtr<int> seq_lens_buffer;
@@ -185,24 +208,39 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
  data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
  data.present_value = reinterpret_cast<CudaT*>(context->Output<Tensor>(2)->MutableData<T>());

+  // Compute past_present_share_buffer early since it's needed for flash attention path selection.
+  // This compares the final pointer values after quantization handling.
+  parameters.past_present_share_buffer = (data.past_key == data.present_key);
+
#if USE_FLASH_ATTENTION
  bool use_flash_attention = !disable_flash_attention_ &&
                             onnxruntime::flash::is_supported<T>(device_prop,
                                                                 parameters.head_size,
                                                                 parameters.num_heads,
                                                                 parameters.kv_num_heads);
-  data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_share_buffer;
-  if (use_flash_attention) {
-    data.use_flash_attention = true;
-    data.use_memory_efficient_attention = false;

+  data.use_flash_attention = use_flash_attention;
+  data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.past_present_share_buffer;
+
+  if (use_flash_attention) {
    // Allocate Flash specific buffers (Softmax LSE, Accum)
    size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads);
+
+    int num_heads_for_split = data.use_flash_attention_fast_decode ? parameters.kv_num_heads : parameters.num_heads;
    auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
-        parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
+        parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, num_heads_for_split,
        parameters.head_size, device_prop.multiProcessorCount);
+
    parameters.num_splits = static_cast<int>(num_splits);

+    if (data.use_flash_attention_fast_decode && num_splits > 1) {
+      // The heuristic used kv_num_heads to maximize occupancy for the GQA-aware kernel.
+      // However, the LSE and Accum buffers must store results for ALL num_heads.
+      softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length);
+      auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; };
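+      // For example (illustrative value), round_multiple(80, 32) == 96: head_size is rounded up to a
+      // multiple of 32 when sizing the out accumulation buffer below.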
+      out_accum_bytes = onnxruntime::flash::get_out_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, round_multiple(parameters.head_size, 32));
+    }
+
    softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
    softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
    out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
@@ -214,11 +252,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
#endif

  if (data.use_flash_attention_fast_decode && parameters.sequence_length == 1) {
-    // FlashAttentionDecoding Fast Path:
+    // FlashDecoding Fast Path:
    // - Uses Flash Attention's internal KV append logic, so total_seq_lens and padded_seq_lens are not needed.
-    // - Past_seq_lens is passed as seqlens_k to Flash Attention, which uses it to:
-    //   1. Determine where to append new K/V in the cache
-    //   2. Apply correct causal masking (attention only to positions [0, past_seq_len])
    // - The input seqlens_k from ONNX graph is (total_len - 1), which equals past_seq_len when seq_len == 1.
    // - This optimization avoids launching GetSequenceLengths kernel for single-token decoding.
    data.past_seq_lens = const_cast<int*>(total_seq_lens_minus_one->Data<int>());
@@ -239,16 +274,20 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
        parameters.is_first_prompt,
        cuda_stream,
        device_prop.maxThreadsPerBlock));
+    DUMP_TENSOR_INIT();
+    DUMP_TENSOR("total_seq_lens", data.total_seq_lens, parameters.batch_size, 1);
+    DUMP_TENSOR("past_seq_lens", data.past_seq_lens, parameters.batch_size, 1);
+    DUMP_TENSOR("padded_seq_lens", data.padded_seq_lens, parameters.batch_size, 1);
  }

-  if (!use_flash_attention) {
-    // Fall back to memory efficient attention.
#if USE_MEMORY_EFFICIENT_ATTENTION
+  if (!data.use_flash_attention) {
+    // Fall back to memory efficient attention.
    int sm = (device_prop.major * 10) + device_prop.minor;
    bool use_memory_efficient_attention =
-        !use_flash_attention &&
        !disable_memory_efficient_attention_ &&
        has_memory_efficient_attention(sm, std::is_same<T, MLFloat16>::value, std::is_same<T, BFloat16>::value, parameters.head_size, parameters.head_size);
+    data.use_memory_efficient_attention = use_memory_efficient_attention;

    // KV buffer for head expansion (when num_heads != kv_num_heads)
    size_t kv_buffer_bytes = (use_memory_efficient_attention && (parameters.num_heads != parameters.kv_num_heads))
@@ -262,49 +301,30 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
    k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
    v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
    fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
-#else
-    constexpr bool use_memory_efficient_attention = false;
-#endif
-
-    data.use_memory_efficient_attention = use_memory_efficient_attention;
-    data.use_flash_attention = false;

    data.k = reinterpret_cast<CudaT*>(k_buffer.get());
    data.v = reinterpret_cast<CudaT*>(v_buffer.get());
    data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
-    data.disable_fused_kv = disable_fused_kv_;
  }
+#endif

+  // -------------
  // Centralized scratch buffer allocation using GQABufferRequirements
  // This ensures allocation logic stays in sync with kernel usage
  auto buffer_req = GQABufferRequirements::Compute<T>(
      parameters,
-      use_flash_attention,
+      data.use_flash_attention,
      data.use_flash_attention_fast_decode,
      data.use_memory_efficient_attention);

-  if (buffer_req.unpacked_qkv_bytes > 0) {
-    unpacked_qkv_buffer = GetScratchBuffer<void>(buffer_req.unpacked_qkv_bytes, context->GetComputeStream());
-    data.unpacked_qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
-  }
-  if (buffer_req.rotary_buffer_bytes > 0) {
-    rotary_buffer = GetScratchBuffer<void>(buffer_req.rotary_buffer_bytes, context->GetComputeStream());
-    data.rotary_buffer = reinterpret_cast<CudaT*>(rotary_buffer.get());
+  if (buffer_req.qkv_buffer_bytes > 0) {
+    unpacked_qkv_buffer = GetScratchBuffer<void>(buffer_req.qkv_buffer_bytes, context->GetComputeStream());
+    data.qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
  }
-  if (buffer_req.position_ids_bytes > 0) {
-    position_ids_buffer = GetScratchBuffer<void>(buffer_req.position_ids_bytes, context->GetComputeStream());
-    data.position_ids_buffer = reinterpret_cast<int64_t*>(position_ids_buffer.get());
-  }
-#ifndef NDEBUG
-  // Track allocated sizes for validation
-  data.unpacked_qkv_buffer_size = buffer_req.unpacked_qkv_bytes;
-  data.rotary_buffer_size = buffer_req.rotary_buffer_bytes;
-  data.position_ids_buffer_size = buffer_req.position_ids_bytes;
-#endif

  if (kernel_options_->AllowDebugInfo()) {
    AttentionKernelDebugInfo debug_info;
-    debug_info.use_flash_attention = use_flash_attention;
+    debug_info.use_flash_attention = data.use_flash_attention;
    debug_info.use_efficient_attention = data.use_memory_efficient_attention;

    debug_info.Print("GroupQueryAttention",
@@ -313,12 +333,11 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
                     std::is_same<T, BFloat16>::value);
  }

-  if (data.past_key == data.present_key) {
-    parameters.kv_share_buffer = true;
-    ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when kv_share_buffer is true");
+  // Validate past_value pointer consistency (past_present_share_buffer was computed early after pointer setup)
+  if (parameters.past_present_share_buffer) {
+    ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true");
  } else {
-    parameters.kv_share_buffer = false;
-    ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when kv_share_buffer is false");
+    ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false");
  }

  data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
@@ -337,19 +356,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
  ORT_RETURN_IF_ERROR(QkvToContext<CudaT>(
      device_prop, cublas, context->GetComputeStream(), parameters, data));

-#ifndef NDEBUG
-  // Validate buffer usage matches allocation exactly
-  ORT_ENFORCE(data.unpacked_qkv_max_used == data.unpacked_qkv_buffer_size,
-              "unpacked_qkv_buffer: used ", data.unpacked_qkv_max_used,
-              " bytes but allocated ", data.unpacked_qkv_buffer_size);
-  ORT_ENFORCE(data.rotary_max_used == data.rotary_buffer_size,
-              "rotary_buffer: used ", data.rotary_max_used,
-              " bytes but allocated ", data.rotary_buffer_size);
-  ORT_ENFORCE(data.position_ids_max_used == data.position_ids_buffer_size,
-              "position_ids_buffer: used ", data.position_ids_max_used,
-              " bytes but allocated ", data.position_ids_buffer_size);
-#endif
-
  return Status::OK();
}
