Commit 9adf238
[CUDA] GroupQueryAttention with XQA and Quantized KV Cache Support (#27246)
## Summary
This Pull Request introduces significant enhancements to the
`GroupQueryAttention` (GQA) operator, specifically adding support for
**XQA** kernels and **Quantized KV Cache** (INT8 and INT4). These
changes aim to improve inference performance and reduce the memory footprint of large language models.
## Key Features
### 1. XQA Integration for GQA
- Integrated TensorRT-LLM XQA kernels for the GQA operator, allowing for
faster attention computation on supported NVIDIA GPUs.
- Added specialized XQA loaders in
`onnxruntime/contrib_ops/cuda/bert/xqa/` for various precisions and head
sizes.
- Supports head sizes of 64, 128, and 256.
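
The loaders above are split by precision and head size, following the `xqa_loader_<dtype>_<head_size>.cu` naming. A minimal Python sketch of the resulting dispatch decision, with `select_xqa_loader` as a hypothetical name (the real dispatch is C++ code in `group_query_attention_impl.cu`):

```python
# Illustrative sketch only; the actual dispatch is CUDA/C++ code.
# Loader names mirror the xqa_loader_<dtype>_<head_size>.cu layout.
SUPPORTED_HEAD_SIZES = (64, 128, 256)
SUPPORTED_DTYPES = ("fp16", "bf16")

def select_xqa_loader(dtype: str, head_size: int):
    """Pick an XQA loader for this config, or return None to fall back
    to Flash Attention / Memory Efficient Attention."""
    if dtype in SUPPORTED_DTYPES and head_size in SUPPORTED_HEAD_SIZES:
        return f"xqa_loader_{dtype}_{head_size}"
    return None
```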
### 2. Quantized KV Cache Support
- Added support for **INT8** and **INT4** quantized KV cache.
- Implemented both **per-tensor** and **per-channel** quantization
scales for flexibility and accuracy preservation (see the sketch after
this list).
- Added a build flag `onnxruntime_USE_INT4_KV_CACHE` to enable/disable
INT4 support as needed.
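
To make the two scale granularities concrete, here is a minimal NumPy sketch of symmetric INT8 quantization of a KV tensor. The `(batch, kv_heads, seq, head_size)` layout and the per-channel-over-`head_size` convention are assumptions for illustration, not the kernels' actual memory layout:

```python
import numpy as np

def quantize_kv_int8(kv, per_channel):
    """Symmetric INT8 quantization of a KV tensor shaped
    (batch, num_kv_heads, seq_len, head_size).
    per_channel=True: one scale per head_size channel;
    per_channel=False: a single scale for the whole tensor."""
    axes = (0, 1, 2) if per_channel else None
    amax = np.abs(kv).max(axis=axes, keepdims=per_channel)
    scale = np.maximum(amax, 1e-8) / 127.0  # guard against all-zero input
    q = np.clip(np.round(kv / scale), -128, 127).astype(np.int8)
    return q, scale

def dequantize_kv(q, scale):
    return q.astype(np.float32) * scale
```

Per-channel scales track each channel's dynamic range at the cost of storing `head_size` scales; per-tensor keeps a single scale with minimal metadata. INT4 follows the same pattern over a narrower integer range and is compiled only when `onnxruntime_USE_INT4_KV_CACHE` is enabled.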
### 3. Optimized RoPE and Quantization Kernels
- Refactored RoPE (Rotary Positional Embedding) and quantization logic
to share common code paths, reducing kernel launch overhead and code
duplication.
- Improved the efficiency of unpacking and appending to the KV cache
when quantization is enabled.
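
For reference, a NumPy sketch of what a fused rotate-then-quantize step computes for a new key block; the non-interleaved rotation, the shapes, and the per-tensor scale are illustrative assumptions:

```python
import numpy as np

def rope_then_quantize_key(k, cos, sin, scale):
    """k: new key states (batch, kv_heads, seq, head_size);
    cos/sin: (seq, head_size // 2), broadcast over batch and heads.
    Rotate with non-interleaved RoPE, then emit INT8 for the KV cache."""
    half = k.shape[-1] // 2
    k1, k2 = k[..., :half], k[..., half:]
    rotated = np.concatenate([k1 * cos - k2 * sin,
                              k1 * sin + k2 * cos], axis=-1)
    return np.clip(np.round(rotated / scale), -128, 127).astype(np.int8)
```

Fusing the two stages quantizes the rotated values before a single cache write, instead of a RoPE kernel writing an intermediate that a separate quantization kernel must re-read.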
### 4. Consolidated Test & Benchmark Infrastructure
- Introduced `gqa_test_helper.py` to consolidate shared test utilities,
reducing duplication across `test_gqa.py`, `test_sparse_attention.py`,
and benchmarks.
- Updated `benchmark_gqa.py` to include tests for quantized KV cache and
XQA-enabled paths.
## Detailed Changes
### CUDA Kernels
- **New XQA Loaders**: A comprehensive set of loaders for FP16, BF16,
and INT8 quantization (`xqa_loader_fp16_64.cu`,
`xqa_loader_bf16_128.cu`, etc.).
- **`group_query_attention_impl.cu`**: Updated to dispatch to XQA
kernels when applicable.
- **`group_query_attention_qkv.cuh` & `group_query_attention_qdq.cuh`**:
Enhanced RoPE and quantization logic.
### Operator Logic
- **`group_query_attention.cc`**: Updated to handle new attributes for
quantization (bit width, scale types) and manage XQA workspace
allocation.
- **`bert_defs.cc`**: Registered new attributes and updated schema for
the `GroupQueryAttention` operator.
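
For orientation, a sketch of constructing the operator with the ONNX Python helper. `num_heads` and `kv_num_heads` are existing GQA attributes; the quantization attribute names below (`kv_quant_bits`, `kv_quant_per_channel`) are hypothetical placeholders, since this summary does not list the names actually registered in `bert_defs.cc`:

```python
from onnx import helper

# NOTE: kv_quant_bits / kv_quant_per_channel are hypothetical names;
# consult the updated GroupQueryAttention schema in bert_defs.cc for
# the attributes this commit actually registers.
gqa_node = helper.make_node(
    "GroupQueryAttention",
    inputs=["query", "key", "value", "past_key", "past_value",
            "seqlens_k", "total_sequence_length"],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=32,
    kv_num_heads=8,
    kv_quant_bits=8,          # hypothetical: 8 for INT8, 4 for INT4 KV cache
    kv_quant_per_channel=1,   # hypothetical: 1 = per-channel, 0 = per-tensor
)
```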
### Testing
- **`test_gqa.py`**: Added hundreds of test cases covering combinations
of quantization types, XQA flags, and head sizes.
- **`gqa_test_helper.py`**: Provides unified logic for input generation,
reference implementation, and session management.
## Verification Results
### Automated Tests
- Verified that all new GQA test cases pass with both FP16 and BF16.
- Confirmed INT8 and INT4 quantization parity with reference
implementations.
- Ensured XQA results match non-XQA (Flash Attention / Memory Efficient
Attention) implementations.
### Benchmarks
- Observed significant latency reductions when enabling XQA for GQA on
supported hardware.
- Reduced memory usage confirmed when using INT8 KV cache options.
## File tree
70 files changed (+15854, -861 lines)

- cmake
- docs
- onnxruntime
  - contrib_ops
    - cpu
      - bert
      - utils
    - cuda
      - bert
        - flash_attention
        - xqa
    - webgpu/bert
  - core/graph/contrib_ops
  - test/python/transformers