[CUDA] GroupQueryAttention with XQA and Quantized KV Cache Support#27246
Merged
[CUDA] GroupQueryAttention with XQA and Quantized KV Cache Support#27246
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
This pull request introduces significant enhancements to the GroupQueryAttention (GQA) operator for CUDA, adding:
- XQA (Extreme Query Attention) kernel support - High-performance kernels from TensorRT-LLM for token generation phase
- Quantized KV Cache - INT8 and INT4 quantization support with per-tensor and per-channel modes
- Refactored RoPE and quantization - Fused UnpackRoPEAppend kernel that handles quantization on-the-fly
- Enhanced testing infrastructure - Consolidated test helpers in
gqa_test_helper.py
Changes:
- Added XQA loader implementations for FP16/BF16 with head sizes 64, 128, 256
- Implemented INT8/INT4 quantization/dequantization kernels with symmetric quantization
- Extended operator schema with new type constraints (T_CACHE, T_KV_SCALE) and attributes
- Refactored test utilities to reduce code duplication
Reviewed changes
Copilot reviewed 66 out of 68 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
bert_defs.cc |
Updated operator schema with quantization attributes and new type constraints |
group_query_attention.h/cc |
Extended kernel registration for quantized types, added XQA support |
group_query_attention_qkv.cuh |
Enhanced UnpackRoPEAppend kernel with INT4/INT8 quantization |
group_query_attention_qdq.cuh |
New quantization/dequantization kernels for KV cache |
xqa/xqa_loader_*.cu |
New XQA kernel loaders for various configurations |
test_sparse_attention.py |
Refactored to use shared test helper module |
benchmark_gqa.py |
Added quantization benchmarking support |
CMakeLists.txt |
Added INT4_KV_CACHE build option |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
Outdated
Show resolved
Hide resolved
kunal-vaishnavi
approved these changes
Feb 10, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This Pull Request introduces significant enhancements to the
GroupQueryAttention(GQA) operator, specifically adding support for XQA kernels and Quantized KV Cache (INT8 and INT4). These changes aim to improve inference performance and reduce memory footprint for large language models.Key Features
1. XQA Integration for GQA
onnxruntime/contrib_ops/cuda/bert/xqa/for various precisions and head sizes.2. Quantized KV Cache Support
onnxruntime_USE_INT4_KV_CACHEto enable/disable INT4 support as needed.3. Optimized RoPE and Quantization Kernels
4. Consolidated Test & Benchmark Infrastructure
gqa_test_helper.pyto consolidate shared test utilities, reducing duplication acrosstest_gqa.py,test_sparse_attention.py, and benchmarks.benchmark_gqa.pyto include tests for quantized KV cache and XQA-enabled paths.Detailed Changes
CUDA Kernels
xqa_loader_fp16_64.cu,xqa_loader_bf16_128.cu, etc.).group_query_attention_impl.cu: Updated to dispatch to XQA kernels when applicable.group_query_attention_qkv.cuh&group_query_attention_qdq.cuh: Enhanced RoPE and quantization logic.Operator Logic
group_query_attention.cc: Updated to handle new attributes for quantization (bit width, scale types) and manage XQA workspace allocation.bert_defs.cc: Registered new attributes and updated schema for theGroupQueryAttentionoperator.Testing
test_gqa.py: Added hundreds of test cases covering combinations of quantization types, XQA flags, and head sizes.gqa_test_helper.py: Provides unified logic for input generation, reference implementation, and session management.Verification Results
Automated Tests
Benchmarks
GQA Performance Comparison in H200 GPU
Config:
Summary
Fp16 (Flash Attention)
flash::flash_fwd_splitkv_kernel<onnxruntime::flash::Flash_fwd_ke...flash::flash_fwd_splitkv_combine_kernel<onnxruntime::flash::Flas...Fp16 (XQA)
H128::grp4_fp16::kernel_mha(unsigned int, float, Vec<__half, (un...UnpackRoPEAppend<__half, __half, (int)16, (int)128>(const T1 *, ...GetSequenceLengths(const int *, int *, int *, int *, int, int, b...Int8 (XQA)
H128::grp4_int8::kernel_mha(unsigned int, float, Vec<__half, (un...UnpackRoPEAppend<__half, signed char, (int)8, (int)128>(const T1...GetSequenceLengths(const int *, int *, int *, int *, int, int, b...