[CUDA] GroupQueryAttention with XQA and Quantized KV Cache Support #27246

Merged
tianleiwu merged 8 commits into main from tlwu/gqa_xqa_quantized_kv_cache on Feb 11, 2026

Conversation

@tianleiwu
Contributor

@tianleiwu tianleiwu commented Feb 5, 2026

Summary

This pull request adds XQA kernel support and quantized KV cache (INT8 and INT4) to the CUDA GroupQueryAttention (GQA) operator. The changes improve inference performance and reduce the memory footprint of large language models.

Key Features

1. XQA Integration for GQA

  • Integrated TensorRT-LLM XQA kernels for the GQA operator, allowing for faster attention computation on supported NVIDIA GPUs.
  • Added specialized XQA loaders in onnxruntime/contrib_ops/cuda/bert/xqa/ for various precisions and head sizes.
  • Supports head sizes of 64, 128, and 256.
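To make the dispatch conditions concrete, here is a minimal Python sketch of the kind of eligibility check the operator could perform before routing to XQA. The head-size set and the query/KV head divisibility come from this PR's description; the decode-only restriction and the SM threshold are assumptions, and the real logic lives in group_query_attention_impl.cu.

```python
# Illustrative only: approximates when a GQA call might be routed to an XQA kernel.
XQA_HEAD_SIZES = {64, 128, 256}  # head sizes listed in this PR

def could_use_xqa(head_size: int, num_heads: int, kv_num_heads: int,
                  seq_len: int, sm_version: int) -> bool:
    """Hypothetical eligibility predicate for the XQA decode path."""
    if head_size not in XQA_HEAD_SIZES:
        return False
    if num_heads % kv_num_heads != 0:   # GQA needs whole query groups per KV head
        return False
    if seq_len != 1:                    # XQA targets the token-generation (decode) step
        return False
    return sm_version >= 80             # assumption: Ampere or newer

print(could_use_xqa(head_size=128, num_heads=32, kv_num_heads=8,
                    seq_len=1, sm_version=90))  # True for the H200 config below
```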

2. Quantized KV Cache Support

  • Added support for INT8 and INT4 quantized KV cache.
  • Implemented both per-tensor and per-channel quantization scales to balance flexibility and accuracy preservation (see the sketch after this list).
  • Added a build flag onnxruntime_USE_INT4_KV_CACHE to enable/disable INT4 support as needed.
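As a rough reference for what per-tensor versus per-channel scaling means here, a numpy sketch of symmetric INT8 quantization of a KV tensor follows. The actual CUDA kernels live in group_query_attention_qdq.cuh; the tensor layout and the interpretation of "channel" as the head-size dimension are assumptions.

```python
import numpy as np

def quantize_kv_int8(kv: np.ndarray, per_channel: bool):
    """Symmetric INT8 quantization of a KV tensor shaped (batch, kv_heads, seq, head_size)."""
    if per_channel:
        # One scale per head_size channel, shared across batch/heads/sequence.
        amax = np.abs(kv).max(axis=(0, 1, 2), keepdims=True)
    else:
        # A single scale for the whole tensor.
        amax = np.abs(kv).max(keepdims=True)
    scale = np.maximum(amax, 1e-8) / 127.0
    q = np.clip(np.round(kv / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

kv = np.random.randn(1, 8, 2048, 128).astype(np.float32)
q, scale = quantize_kv_int8(kv, per_channel=True)
print(q.dtype, scale.shape)  # int8 (1, 1, 1, 128)
```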

3. Optimized RoPE and Quantization Kernels

  • Refactored RoPE (Rotary Positional Embedding) and quantization logic to share common code paths, reducing kernel launch overhead and code duplication (illustrated after this list).
  • Improved the efficiency of unpacking and appending to the KV cache when quantization is enabled.
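To illustrate what the fused path computes per new token, here is a numpy sketch that applies rotary embedding to an incoming key head and quantizes it before the cache append. It mirrors the intent of the UnpackRoPEAppend kernel rather than its implementation; the non-interleaved rotation and the precomputed per-tensor scale are assumptions.

```python
import numpy as np

def rope(x: np.ndarray, pos: int, base: float = 10000.0) -> np.ndarray:
    """Rotary positional embedding for one token, non-interleaved halves; x: (head_size,)."""
    d = x.shape[-1]
    inv_freq = base ** (-np.arange(0, d, 2) / d)
    cos, sin = np.cos(pos * inv_freq), np.sin(pos * inv_freq)
    x1, x2 = x[: d // 2], x[d // 2:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos])

def rope_then_quantize(k_new: np.ndarray, pos: int, scale: float) -> np.ndarray:
    """Rotate the incoming key, then quantize it for an INT8 cache in one pass."""
    return np.clip(np.round(rope(k_new, pos) / scale), -127, 127).astype(np.int8)

k_new = np.random.randn(128).astype(np.float32)
print(rope_then_quantize(k_new, pos=2048, scale=0.02).dtype)  # int8
```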

4. Consolidated Test & Benchmark Infrastructure

  • Introduced gqa_test_helper.py to consolidate shared test utilities, reducing duplication across test_gqa.py, test_sparse_attention.py, and benchmarks.
  • Updated benchmark_gqa.py to include tests for quantized KV cache and XQA-enabled paths.

Detailed Changes

CUDA Kernels

  • New XQA Loaders: A comprehensive set of loaders for FP16, BF16, and INT8 quantization (xqa_loader_fp16_64.cu, xqa_loader_bf16_128.cu, etc.).
  • group_query_attention_impl.cu: Updated to dispatch to XQA kernels when applicable.
  • group_query_attention_qkv.cuh & group_query_attention_qdq.cuh: Enhanced RoPE and quantization logic.
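The decode side is the inverse mapping. A short numpy sketch, again assuming symmetric scales; the real kernels in group_query_attention_qdq.cuh dequantize inside the attention path rather than materializing a full fp16 tensor.

```python
import numpy as np

def dequantize_kv(q_kv: np.ndarray, scale) -> np.ndarray:
    """Recover an approximate float KV tensor from its INT8 values and scale(s)."""
    return q_kv.astype(np.float32) * scale

print(dequantize_kv(np.array([[-127, 0, 64]], dtype=np.int8), np.float32(0.02)))
# approximately [[-2.54  0.    1.28]]
```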

Operator Logic

  • group_query_attention.cc: Updated to handle new attributes for quantization (bit width, scale types) and manage XQA workspace allocation (a node-construction sketch follows this list).
  • bert_defs.cc: Registered new attributes and updated schema for the GroupQueryAttention operator.
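For orientation, a sketch of constructing a GroupQueryAttention node with onnx.helper. The inputs and the num_heads/kv_num_heads/do_rotary attributes follow the existing com.microsoft contrib-op schema as I understand it; the new quantization attributes added by this PR are not spelled out in the description, so the commented kv_quant_bits name below is a hypothetical placeholder, not the real attribute.

```python
from onnx import helper

# Minimal node-construction sketch; see bert_defs.cc for the authoritative schema.
node = helper.make_node(
    "GroupQueryAttention",
    inputs=["query", "key", "value", "past_key", "past_value",
            "seqlens_k", "total_sequence_length", "cos_cache", "sin_cache"],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=32,
    kv_num_heads=8,
    do_rotary=1,
    # kv_quant_bits=8,  # hypothetical name: enable the INT8 KV cache path
)
print(node.op_type, [a.name for a in node.attribute])
```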

Testing

  • test_gqa.py: Added hundreds of test cases covering combinations of quantization types, XQA flags, and head sizes.
  • gqa_test_helper.py: Provides unified logic for input generation, reference implementation, and session management.
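As a baseline for what such a reference implementation checks against, a compact numpy version of decode-step grouped-query attention is sketched below (batch size 1, no masking or rotary handling); gqa_test_helper.py's actual reference will differ in details.

```python
import numpy as np

def gqa_reference(q, k_cache, v_cache, num_heads, kv_heads):
    """q: (num_heads, head_size); k_cache/v_cache: (kv_heads, total_len, head_size)."""
    head_size = q.shape[-1]
    group = num_heads // kv_heads                 # query heads sharing one KV head
    out = np.empty_like(q)
    for h in range(num_heads):
        kv_h = h // group                         # map query head -> its KV head
        scores = q[h] @ k_cache[kv_h].T / np.sqrt(head_size)
        probs = np.exp(scores - scores.max())
        probs /= probs.sum()
        out[h] = probs @ v_cache[kv_h]
    return out

q = np.random.randn(32, 128).astype(np.float32)
k = np.random.randn(8, 2049, 128).astype(np.float32)
v = np.random.randn(8, 2049, 128).astype(np.float32)
print(gqa_reference(q, k, v, num_heads=32, kv_heads=8).shape)  # (32, 128)
```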

Verification Results

Automated Tests

  • Verified that all new GQA test cases pass with both FP16 and BF16.
  • Confirmed INT8 and INT4 quantization parity with reference implementations.
  • Ensured XQA results match non-XQA (Flash Attention / Memory Efficient Attention) implementations.

Benchmarks

  • Observed significant latency reductions when enabling XQA for GQA on supported hardware.
  • Reduced memory usage confirmed when using INT8 KV cache options (see the footprint estimate after the H200 tables below).

GQA Performance Comparison on an H200 GPU

Config:

  • batch=1
  • seq_len=1
  • past_seq=2048
  • num_heads=32
  • kv_heads=8
  • head_size=128
Summary

| Algorithm | Total Kernel Time (us) |
| --- | --- |
| Fp16 (Flash Attention) | 30.83 |
| Fp16 (XQA) | 16.08 |
| Int8 (XQA) | 18.81 |

Fp16 (Flash Attention)

| Kernel Name | Avg (us) |
| --- | --- |
| flash::flash_fwd_splitkv_kernel<onnxruntime::flash::Flash_fwd_ke... | 25.78 |
| flash::flash_fwd_splitkv_combine_kernel<onnxruntime::flash::Flas... | 5.05 |

Fp16 (XQA)

| Kernel Name | Avg (us) |
| --- | --- |
| H128::grp4_fp16::kernel_mha(unsigned int, float, Vec<__half, (un... | 11.49 |
| UnpackRoPEAppend<__half, __half, (int)16, (int)128>(const T1 *, ... | 3.01 |
| GetSequenceLengths(const int *, int *, int *, int *, int, int, b... | 1.58 |

Int8 (XQA)

| Kernel Name | Avg (us) |
| --- | --- |
| H128::grp4_int8::kernel_mha(unsigned int, float, Vec<__half, (un... | 13.76 |
| UnpackRoPEAppend<__half, signed char, (int)8, (int)128>(const T1... | 3.56 |
| GetSequenceLengths(const int *, int *, int *, int *, int, int, b... | 1.49 |
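Using the configuration above, a back-of-the-envelope estimate makes the KV-cache savings concrete (scale storage is ignored, so the INT8/INT4 numbers are slightly optimistic):

```python
# KV cache bytes = 2 (K and V) * batch * kv_heads * past_seq * head_size * bytes/element
batch, kv_heads, past_seq, head_size = 1, 8, 2048, 128

def kv_cache_mib(bytes_per_element: float) -> float:
    return 2 * batch * kv_heads * past_seq * head_size * bytes_per_element / 2**20

print(f"fp16: {kv_cache_mib(2):.1f} MiB")    # 8.0 MiB
print(f"int8: {kv_cache_mib(1):.1f} MiB")    # 4.0 MiB
print(f"int4: {kv_cache_mib(0.5):.1f} MiB")  # 2.0 MiB
```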

@tianleiwu tianleiwu marked this pull request as draft February 5, 2026 00:40
Contributor

Copilot AI left a comment

Pull request overview

This pull request introduces significant enhancements to the GroupQueryAttention (GQA) operator for CUDA, adding:

  1. XQA (Extreme Query Attention) kernel support - High-performance kernels from TensorRT-LLM for the token generation phase
  2. Quantized KV Cache - INT8 and INT4 quantization support with per-tensor and per-channel modes
  3. Refactored RoPE and quantization - Fused UnpackRoPEAppend kernel that handles quantization on-the-fly
  4. Enhanced testing infrastructure - Consolidated test helpers in gqa_test_helper.py

Changes:

  • Added XQA loader implementations for FP16/BF16 with head sizes 64, 128, 256
  • Implemented INT8/INT4 quantization/dequantization kernels with symmetric quantization
  • Extended operator schema with new type constraints (T_CACHE, T_KV_SCALE) and attributes
  • Refactored test utilities to reduce code duplication

Reviewed changes

Copilot reviewed 66 out of 68 changed files in this pull request and generated no comments.

Show a summary per file
| File | Description |
| --- | --- |
| bert_defs.cc | Updated operator schema with quantization attributes and new type constraints |
| group_query_attention.h/cc | Extended kernel registration for quantized types, added XQA support |
| group_query_attention_qkv.cuh | Enhanced UnpackRoPEAppend kernel with INT4/INT8 quantization |
| group_query_attention_qdq.cuh | New quantization/dequantization kernels for KV cache |
| xqa/xqa_loader_*.cu | New XQA kernel loaders for various configurations |
| test_sparse_attention.py | Refactored to use shared test helper module |
| benchmark_gqa.py | Added quantization benchmarking support |
| CMakeLists.txt | Added INT4_KV_CACHE build option |


@tianleiwu tianleiwu marked this pull request as ready for review February 8, 2026 23:09
@tianleiwu tianleiwu enabled auto-merge (squash) February 10, 2026 06:09
@tianleiwu tianleiwu merged commit 9adf238 into main Feb 11, 2026
125 of 141 checks passed
@tianleiwu tianleiwu deleted the tlwu/gqa_xqa_quantized_kv_cache branch February 11, 2026 18:33