
Commit 5a9877a

Authored by Copilot and titaiwangms
Support group query attention in Attention(23) CUDA (#27082)
This pull request introduces improvements and bug fixes to the attention mechanism in ONNX Runtime, focusing on the handling of attention masks and the computation of attention probabilities for the CPU and CUDA providers. The most significant changes are a new CUDA implementation that converts boolean attention masks to sequence lengths with validation, and several bug fixes in the CPU attention kernel so that head indices are handled correctly during computation.

**CUDA Attention Mask Conversion and Validation:**

* Added a new CUDA implementation (`attention_mask_impl.cu` and `attention_mask_impl.h`) that efficiently converts a boolean attention mask to sequence lengths for GQA (Grouped Query Attention) kernels. This includes:
  - A CUDA kernel that processes each batch, validates that the mask starts with True and that padding is contiguous (right-padding only), and computes the correct sequence length per batch (see the sketch after this description).
  - Support for 2D, 3D, and 4D mask shapes with proper broadcasting logic.
  - Error handling for masks that do not start with True or contain non-contiguous True/False values.
  [[1]](diffhunk://#diff-00f7d49ccee44f1573357c07633bd03f21b9c2e1b1617c7a6a878a79ee6a6e49R1-R149) [[2]](diffhunk://#diff-8aa9a15a92d7dc138346dce5de055911895d940ba2183b4ba45bd95ac0e5bfc9R1-R56)

**CPU Attention Kernel Bug Fixes:**

* Fixed bugs in the CPU attention kernel (`attention.cc`) by replacing incorrect uses of `(head_i % parameters.kv_num_heads)` and `head_i` with the correct `head_ki` and `head_vi` indices when accessing the K and V matrices. This ensures correct head alignment, especially in multi-head or grouped attention scenarios. [[1]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L399-R399) [[2]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L416-R416) [[3]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L620-R620) [[4]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L634-R634)

---

NOT supported in this PR:

* Cross attention (q_sequence_length != kv_sequence_length)
* 4D QKV (BNSH format)
* is_causal=False
* fp32
* Softmax precision
* qk_output_mode

---

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Co-authored-by: Ti-Tai Wang <titaiwang@microsoft.com>
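The sketch below illustrates, on the CPU side, the conversion and validation logic described above. It assumes a 2D `[batch, total_seq_len]` bool mask; the actual `attention_mask_impl.cu` is a CUDA kernel that also handles 3D/4D shapes and broadcasting, so the function name and signature here are illustrative only, not the real API.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical CPU reference of the mask-to-sequence-lengths conversion:
// for each batch row, require mask[0] == true and right-padding only
// (no True after the first False), then return the per-batch length.
std::vector<int32_t> MaskToSeqLens(const std::vector<bool>& mask,
                                   int64_t batch_size, int64_t total_seq_len) {
  std::vector<int32_t> seqlens(static_cast<size_t>(batch_size));
  for (int64_t b = 0; b < batch_size; ++b) {
    const int64_t row = b * total_seq_len;
    if (!mask[row]) {
      // The CUDA kernel rejects masks that do not start with True.
      throw std::invalid_argument("attention mask must start with True");
    }
    int32_t len = 1;
    bool seen_false = false;
    for (int64_t i = 1; i < total_seq_len; ++i) {
      if (mask[row + i]) {
        if (seen_false) {
          // True after False means the padding is not contiguous.
          throw std::invalid_argument("attention mask padding must be right-padding only");
        }
        ++len;
      } else {
        seen_false = true;
      }
    }
    seqlens[b] = len;
  }
  return seqlens;
}
```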
1 parent 16667b1 commit 5a9877a

File tree

7 files changed: +2303 additions, −35 deletions


onnxruntime/core/providers/cpu/llm/attention.cc

Lines changed: 12 additions & 12 deletions
```diff
@@ -396,7 +396,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
     parameters.transpose_output
         ? parameters.head_size * parameters.q_num_heads
         : static_cast<int>(parameters.head_size),  // lda
-    transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + (head_i % parameters.kv_num_heads) * parameters.head_size : k,
+    transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k,
     transposed_k
         ? parameters.head_size * parameters.kv_num_heads
         : static_cast<int>(parameters.head_size),  // ldb
@@ -413,7 +413,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
     MLFloat16(alpha),
     Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size,
     parameters.head_size * parameters.q_num_heads,  // lda
-    transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + (head_i % parameters.kv_num_heads) * parameters.head_size : k,
+    transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k,
     transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size,  // ldb
     MLFloat16(beta),
     output,
@@ -617,23 +617,23 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output,  // bu
     total_sequence_length,  // K
     attention_probs + attention_probs_offset,
     total_sequence_length,  // lda
-    transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v,
+    transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v,
     transposed_v ? static_cast<int>(v_head_size * kv_num_heads) : static_cast<int>(v_head_size),  // ldb
     output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size),
     v_head_size * num_heads,  // ldc
     MLFloat16(1.f).val, MLFloat16(0.f).val, nullptr);
   } else {
     math::GemmEx<T, ThreadPool>(CblasNoTrans,
                                 CblasNoTrans,
-                                sequence_length,        // M
-                                v_head_size,            // N
-                                total_sequence_length,  // K
-                                MLFloat16(1.f),         // alpha
-                                attention_probs + attention_probs_offset,  // QK
-                                total_sequence_length,  // lda
-                                transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v,  // V
-                                transposed_v ? v_head_size * kv_num_heads : v_head_size,  // ldb
-                                MLFloat16(0.f),         // beta
+                                sequence_length,        // M
+                                v_head_size,            // N
+                                total_sequence_length,  // K
+                                MLFloat16(1.f),         // alpha
+                                attention_probs + attention_probs_offset,  // QK
+                                total_sequence_length,  // lda
+                                transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v,  // V
+                                transposed_v ? v_head_size * kv_num_heads : v_head_size,  // ldb
+                                MLFloat16(0.f),         // beta
     output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size),
     v_head_size * num_heads,  // ldc
     nullptr);
```
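The diff above replaces modulo-based indexing with the precomputed `head_ki`/`head_vi` indices. As an illustration only, the sketch below assumes the usual GQA grouping, where consecutive query heads share a K/V head via integer division; the exact index computation in `attention.cc` may differ, and all names here are hypothetical.

```cpp
#include <cstdio>

int main() {
  // Illustrative GQA head mapping: q_num_heads query heads share kv_num_heads
  // K/V heads, so each query head maps to its group's K/V head by division.
  const int q_num_heads = 8;
  const int kv_num_heads = 2;                        // group size = 8 / 2 = 4
  const int group_size = q_num_heads / kv_num_heads;
  for (int head_i = 0; head_i < q_num_heads; ++head_i) {
    const int head_ki = head_i / group_size;  // K head used by query head head_i
    const int head_vi = head_i / group_size;  // V head used by query head head_i
    // Offsetting K by (head_i % kv_num_heads) or V by head_i, as the old code
    // did, selects the wrong head whenever q_num_heads != kv_num_heads.
    std::printf("query head %d -> K head %d, V head %d\n", head_i, head_ki, head_vi);
  }
  return 0;
}
```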
