Commit 5a9877a
Support group query attention in Attention(23) CUDA (#27082)
This pull request introduces improvements and bug fixes to the attention
mechanism in ONNX Runtime, particularly focusing on the handling of
attention masks and the computation of attention probabilities for both
CPU and CUDA providers. The most significant changes include the
addition of a new CUDA implementation for converting boolean attention
masks to sequence lengths with validation, and several bug fixes in the
CPU attention kernel to correctly handle head indices during
computation.
**CUDA Attention Mask Conversion and Validation:**
* Added a new CUDA implementation (`attention_mask_impl.cu` and
`attention_mask_impl.h`) that efficiently converts a boolean attention
mask to sequence lengths for GQA (Grouped Query Attention) kernels; a
minimal sketch of the idea follows this list. This includes:
- A CUDA kernel that processes each batch, validates that the mask
starts with True and that padding is contiguous (right-padding only),
and computes the correct sequence length per batch.
- Support for 2D, 3D, and 4D mask shapes with proper broadcasting logic.
- Error handling for masks that do not start with True or whose
True/False values are not contiguous (i.e., anything other than right-padding).
[[1]](diffhunk://#diff-00f7d49ccee44f1573357c07633bd03f21b9c2e1b1617c7a6a878a79ee6a6e49R1-R149)
[[2]](diffhunk://#diff-8aa9a15a92d7dc138346dce5de055911895d940ba2183b4ba45bd95ac0e5bfc9R1-R56)
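To make the validation rules concrete, here is a minimal sketch of the idea, not the actual `attention_mask_impl.cu` kernel: one thread per batch scans a 2D `[batch, total_seq_len]` boolean mask, rejects masks that do not start with True or that contain a True after a False, and otherwise writes the per-batch sequence length. Kernel and variable names are illustrative, and the real implementation also handles 3D/4D broadcasting.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative sketch: one thread per batch row. Writes seq_lens[b] on
// success; sets *error_flag when the mask violates the right-padding rules.
__global__ void MaskToSeqLenKernel(const bool* mask, int* seq_lens,
                                   int* error_flag, int batch_size,
                                   int total_seq_len) {
  int b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b >= batch_size) return;
  const bool* row = mask + static_cast<size_t>(b) * total_seq_len;
  if (!row[0]) {                  // mask must start with True
    atomicExch(error_flag, 1);
    return;
  }
  int len = 1;
  bool in_padding = false;
  for (int i = 1; i < total_seq_len; ++i) {
    if (row[i]) {
      if (in_padding) {           // True after False: padding not contiguous
        atomicExch(error_flag, 2);
        return;
      }
      ++len;
    } else {
      in_padding = true;
    }
  }
  seq_lens[b] = len;
}

int main() {
  // Batch of 2, total sequence length 5: expected lengths 3 and 5.
  bool h_mask[2][5] = {{true, true, true, false, false},
                       {true, true, true, true, true}};
  bool* d_mask; int* d_lens; int* d_err;
  cudaMalloc(&d_mask, sizeof(h_mask));
  cudaMalloc(&d_lens, 2 * sizeof(int));
  cudaMalloc(&d_err, sizeof(int));
  cudaMemcpy(d_mask, h_mask, sizeof(h_mask), cudaMemcpyHostToDevice);
  cudaMemset(d_err, 0, sizeof(int));
  MaskToSeqLenKernel<<<1, 2>>>(d_mask, d_lens, d_err, 2, 5);
  int h_lens[2], h_err;
  cudaMemcpy(h_lens, d_lens, sizeof(h_lens), cudaMemcpyDeviceToHost);
  cudaMemcpy(&h_err, d_err, sizeof(h_err), cudaMemcpyDeviceToHost);
  printf("err=%d lens=[%d, %d]\n", h_err, h_lens[0], h_lens[1]);
  cudaFree(d_mask); cudaFree(d_lens); cudaFree(d_err);
  return 0;
}
```

A sequential scan per batch keeps the contiguity check trivial; since the scan is over a single mask row, the cost is negligible next to the attention matmuls it feeds.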
**CPU Attention Kernel Bug Fixes:**
* Fixed bugs in the CPU attention kernel (`attention.cc`) by replacing
incorrect uses of `(head_i % parameters.kv_num_heads)` and `head_i` with
the correct `head_ki` and `head_vi` indices when accessing the K and V
matrices. This ensures correct head alignment, especially in multi-head
or grouped attention scenarios; a sketch of the mapping follows the links below.
[[1]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L399-R399)
[[2]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L416-R416)
[[3]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L620-R620)
[[4]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L634-R634)
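For intuition about why `head_i % parameters.kv_num_heads` is the wrong mapping, here is a small host-side sketch (compilable as plain C++), assuming the usual GQA grouping where consecutive query heads share one KV head; `group_size` and the loop are illustrative, not the `attention.cc` code:

```cpp
#include <cstdio>

int main() {
  const int q_num_heads = 8;
  const int kv_num_heads = 2;
  // Standard GQA: q_num_heads is divisible by kv_num_heads, and each
  // contiguous group of query heads shares a single KV head.
  const int group_size = q_num_heads / kv_num_heads;  // 4 query heads per KV head
  for (int head_i = 0; head_i < q_num_heads; ++head_i) {
    int head_ki = head_i / group_size;       // correct KV head index
    int buggy = head_i % kv_num_heads;       // the pattern the PR removed
    printf("q head %d -> kv head %d (modulo mapping gave %d)\n",
           head_i, head_ki, buggy);
  }
  return 0;
}
```

Already at `head_i = 1` the two mappings diverge (group mapping gives KV head 0, modulo gives 1), which is why the modulo form misaligned heads whenever `q_num_heads != kv_num_heads`.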
---
NOT supported in this PR
* Cross attention (q_sequence_length != kv_sequence_length)
* 4d QKV (BNSH format)
* is_causal=False
* fp32
* Softmax precision
* qk_output_mode
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Co-authored-by: Ti-Tai Wang <titaiwang@microsoft.com>
File tree (7 files changed: +2303 −35 lines):
- onnxruntime/core/providers/cpu/llm
- onnxruntime/core/providers/cuda/llm
- onnxruntime/test/providers/cpu/llm
- onnxruntime/test/python/transformers
- onnxruntime/test/testdata