Commit 9adf238

[CUDA] GroupQueryAttention with XQA and Quantized KV Cache Support (#27246)

## Summary

This pull request introduces significant enhancements to the `GroupQueryAttention` (GQA) operator, adding support for **XQA** kernels and a **Quantized KV Cache** (INT8 and INT4). These changes improve inference performance and reduce the memory footprint of large language models.

## Key Features

### 1. XQA Integration for GQA
- Integrated TensorRT-LLM XQA kernels for the GQA operator, enabling faster attention computation on supported NVIDIA GPUs.
- Added specialized XQA loaders in `onnxruntime/contrib_ops/cuda/bert/xqa/` for various precisions and head sizes.
- Supports head sizes of 64, 128, and 256.

### 2. Quantized KV Cache Support
- Added support for **INT8** and **INT4** quantized KV cache.
- Implemented both **per-tensor** and **per-channel** quantization scales to balance flexibility and accuracy (see the declaration sketch after this summary).
- Added a build flag `onnxruntime_USE_INT4_KV_CACHE` to enable or disable INT4 support as needed.

### 3. Optimized RoPE and Quantization Kernels
- Refactored RoPE (rotary position embedding) and quantization logic to share common code paths, reducing kernel launch overhead and code duplication.
- Improved the efficiency of unpacking and appending to the KV cache when quantization is enabled.

### 4. Consolidated Test & Benchmark Infrastructure
- Introduced `gqa_test_helper.py` to consolidate shared test utilities, reducing duplication across `test_gqa.py`, `test_sparse_attention.py`, and the benchmarks.
- Updated `benchmark_gqa.py` to cover quantized KV cache and XQA-enabled paths.

## Detailed Changes

### CUDA Kernels
- **New XQA loaders**: a comprehensive set of loaders for FP16, BF16, and INT8 quantization (`xqa_loader_fp16_64.cu`, `xqa_loader_bf16_128.cu`, etc.).
- **`group_query_attention_impl.cu`**: updated to dispatch to XQA kernels when applicable.
- **`group_query_attention_qkv.cuh` & `group_query_attention_qdq.cuh`**: enhanced RoPE and quantization logic.

### Operator Logic
- **`group_query_attention.cc`**: updated to handle the new quantization attributes (bit width, scale types) and to manage XQA workspace allocation.
- **`bert_defs.cc`**: registered the new attributes and updated the schema of the `GroupQueryAttention` operator.

### Testing
- **`test_gqa.py`**: added hundreds of test cases covering combinations of quantization types, XQA flags, and head sizes.
- **`gqa_test_helper.py`**: provides unified logic for input generation, reference implementation, and session management.

## Verification Results

### Automated Tests
- Verified that all new GQA test cases pass with both FP16 and BF16.
- Confirmed INT8 and INT4 quantization parity with the reference implementations.
- Ensured XQA results match the non-XQA (Flash Attention / Memory Efficient Attention) implementations.

### Benchmarks
- Observed significant latency reductions when enabling XQA for GQA on supported hardware.
- Confirmed reduced memory usage when using the INT8 KV cache options.
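The sketch below shows how the extended operator could be declared from Python with the new attributes. It is a minimal, hypothetical illustration rather than code from this commit: the tensor names are placeholders, while the attribute names (`k_quant_type`, `v_quant_type`, `kv_cache_bit_width`) and the 14-input order follow the updated schema in `docs/ContribOperators.md` shown further down.

```python
# Hypothetical sketch (not part of this commit): declaring GroupQueryAttention
# with an INT8, per-tensor quantized KV cache. Tensor names are placeholders;
# attribute names and input order follow the schema updated in this PR.
from onnx import helper

gqa_node = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "query", "key", "value",
        "past_key", "past_value",    # int8 KV cache tensors in BNSH layout
        "seqlens_k", "total_sequence_length",
        "", "", "", "", "",          # cos_cache, sin_cache, position_ids,
                                     # attention_bias, head_sink (unused here)
        "k_scale", "v_scale",        # float scale tensors for the quantized cache
    ],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=32,
    kv_num_heads=8,
    k_quant_type="PER_TENSOR",
    v_quant_type="PER_TENSOR",
    kv_cache_bit_width=8,
)
```

With per-tensor quantization, `k_scale` and `v_scale` would each carry a single float value; the per-channel variant would instead use a `[1, num_heads_k, 1, head_size]` scale, as documented in the schema below.
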
1 parent 2cf5bbd commit 9adf238

File tree

70 files changed: 15,854 additions and 861 deletions


cmake/CMakeLists.txt

Lines changed: 18 additions & 3 deletions

@@ -103,6 +103,7 @@ cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention ke
 option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF)
 cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
 option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF)
+option(onnxruntime_USE_INT4_KV_CACHE "Build cuda kernels for int4 kv cache" OFF)
 option(onnxruntime_QUICK_BUILD "Speed up build by skipping some kernels for faster development" OFF)
 
 option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
@@ -125,6 +126,7 @@ option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF)
 
 option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF)
 option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF)
+option(onnxruntime_DUMP_TENSOR "Dump tensor inside kernel." OFF)
 cmake_dependent_option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS_ENABLE_DUMP_TO_SQLDB "Build dump debug information about node inputs and outputs with support for sql database." OFF "onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS" OFF)
 
 # When loading a delay loaded DLL, Windows searches the main EXE's folder first.
@@ -627,7 +629,6 @@ else()
   check_cxx_compiler_flag(-Wparentheses HAS_PARENTHESES)
   check_cxx_compiler_flag(-Wshorten-64-to-32 HAS_SHORTEN_64_TO_32)
   check_cxx_compiler_flag(-Wstrict-aliasing HAS_STRICT_ALIASING)
-  check_nvcc_compiler_flag(-Wstrict-aliasing NVCC_HAS_STRICT_ALIASING)
   check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW)
   check_cxx_compiler_flag(-Wtautological-pointer-compare HAS_TAUTOLOGICAL_POINTER_COMPARE)
   check_cxx_compiler_flag(-Wundefined-var-template HAS_UNDEFINED_VAR_TEMPLATE)
@@ -774,8 +775,13 @@ if (onnxruntime_USE_CUDA)
   endif()
 
   if (onnxruntime_QUICK_BUILD)
-    message( STATUS "Quick build mode: Flash attention limited to fp16 only")
-    list(APPEND ORT_PROVIDER_FLAGS -DORT_QUICK_BUILD=1)
+    message( STATUS "Quick build mode: Flash attention limited to head dimension 128 only")
+    list(APPEND ORT_PROVIDER_FLAGS -DORT_QUICK_BUILD=1)
+  endif()
+
+  if (onnxruntime_USE_INT4_KV_CACHE)
+    message( STATUS "Enable int4 kv cache for CUDA EP")
+    list(APPEND ORT_PROVIDER_FLAGS -DUSE_INT4_KV_CACHE=1)
   endif()
 endif()
 
@@ -1433,6 +1439,9 @@ if (Git_FOUND)
   if (onnxruntime_QUICK_BUILD)
     string(APPEND ORT_BUILD_INFO "quick-build=1, ")
   endif()
+  if (onnxruntime_USE_INT4_KV_CACHE)
+    string(APPEND ORT_BUILD_INFO "int4-kv-cache=1, ")
+  endif()
 endif()
 string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}")
 configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h)
@@ -1446,6 +1455,8 @@ if (onnxruntime_USE_CUDA)
   find_package(CUDAToolkit REQUIRED)
 
   if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.8)
+    add_definitions("-DENABLE_BF16")
+    message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_BF16 flag")
     add_definitions("-DENABLE_FP8")
     message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_FP8 flag")
   endif()
@@ -1779,6 +1790,10 @@ if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
   add_compile_definitions(DEBUG_NODE_INPUTS_OUTPUTS)
 endif()
 
+if (onnxruntime_DUMP_TENSOR)
+  add_compile_definitions(DUMP_TENSOR_LEVEL=1)
+endif()
+
 if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS)
   if (NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
     message(FATAL_ERROR "External custom operator schemas feature is only supported on Linux")

cmake/onnxruntime_providers_cpu.cmake

Lines changed: 0 additions & 2 deletions

@@ -28,12 +28,10 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS
 # Quick build mode: Filter flash attention kernels for faster development iteration.
 # - We keep only hdim128 fp16 flash attention kernels in quick build mode.
 # - All other listed head dimensions are excluded (e.g., 32, 64, 96, 192, 256).
-# - This regex matches both `flash_fwd_hdim*` and `flash_fwd_split_hdim*` kernels.
 # If new head dimensions are added or removed, update this list to match the supported set.
 if(onnxruntime_QUICK_BUILD)
   message(STATUS "Quick build mode enabled: Only building hdim128 fp16 flash attention kernels")
   list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|192|256)")
-  list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*_bf16")
 endif()
 
 file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS

docs/ContribOperators.md

Lines changed: 37 additions & 12 deletions

@@ -2520,15 +2520,26 @@ This version of the operator has been available since version 1 of the 'com.micr
 
 ### <a name="com.microsoft.GroupQueryAttention"></a><a name="com.microsoft.groupqueryattention">**com.microsoft.GroupQueryAttention**</a>
 
-Group Query Self/Cross Attention.
+Group Query Self/Cross Attention with KV Cache Quantization Support.
 
-*Highly recommend using k-v cache share buffer for both CPU and CUDA. Enabled through IOBinding past and present kv.
-Supports different number of heads for q and kv for CPU and CUDA.
-Only supports causal and local attention.
-Supports rotary position embedding for CPU and CUDA.
-Supports packed input for CPU and CUDA.
-Supports continuous decoding for batch_size == 1 for CPU and CUDA.
+This operator implements causal grouped-query attention with past state (KV cache) support.
+It also supports optional float8, int8 or int4 quantization for the KV cache to reduce memory footprint.
 
+**Cache Format:**
+The past and present KV cache tensors are expected in a BNSH format: `(batch_size, num_heads, cache_sequence_length, head_size)`, where `cache_sequence_length` is the length of the cached key/value sequences, or the maximum sequence length when past and present buffer sharing is used.
+
+**Quantization:**
+When quantization is enabled, `past_key` and `past_value` inputs can be of type `float8e4m3fn`, `uint8` or `int8`. The corresponding `k_scale` and `v_scale` tensors must be provided.
+The operator will output `present_key` and `present_value` in same format as the `past_key` and `past_value`.
+
+For 4-bit quantization, the data type is uint8 where each byte contains two 4-bit values. The bit width of quantized KV cache can be set using `kv_cache_bit_width` attribute.
+
+The shapes of the k_scale, v_scale tensors shall be broadcastable to present_key shape.
+
+**Quantization Modes (`k_quant_type`, `v_quant_type` attributes):**
+- **"NONE"**: No quantization.
+- **"PER_TENSOR"**: A single scale for the entire tensor. Scale example shape: `[1]`.
+- **"PER_CHANNEL"**: A scale for each channel. Scale example shape: `[1, num_heads_k, 1, head_size]`.
 
 #### Version
 
@@ -2539,6 +2550,10 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dl>
 <dt><tt>do_rotary</tt> : int</dt>
 <dd>Whether to use rotary position embedding. Default value is 0.</dd>
+<dt><tt>k_quant_type</tt> : string</dt>
+<dd>Quantization type for K cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.</dd>
+<dt><tt>kv_cache_bit_width</tt> : int</dt>
+<dd>Bit width of quantized KV cache. Supported values are 8 and 4.</dd>
 <dt><tt>kv_num_heads</tt> : int (required)</dt>
 <dd>Number of attention heads for k and v</dd>
 <dt><tt>local_window_size</tt> : int</dt>
@@ -2555,9 +2570,11 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Use a smooth factor in softmax.</dd>
 <dt><tt>softcap</tt> : float</dt>
 <dd>Softcap value for attention weights. Default value is 0.</dd>
+<dt><tt>v_quant_type</tt> : string</dt>
+<dd>Quantization type for V cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.</dd>
 </dl>
 
-#### Inputs (7 - 12)
+#### Inputs (7 - 14)
 
 <dl>
 <dt><tt>query</tt> : T</dt>
@@ -2566,9 +2583,9 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Key with shape (batch_size, kv_sequence_length, kv_hidden_size) </dd>
 <dt><tt>value</tt> (optional) : T</dt>
 <dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
-<dt><tt>past_key</tt> (optional) : T</dt>
+<dt><tt>past_key</tt> (optional) : T_CACHE</dt>
 <dd>past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
-<dt><tt>past_value</tt> (optional) : T</dt>
+<dt><tt>past_value</tt> (optional) : T_CACHE</dt>
 <dd>past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
 <dt><tt>seqlens_k</tt> : M</dt>
 <dd>1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).</dd>
@@ -2584,16 +2601,20 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
 <dt><tt>head_sink</tt> (optional) : T</dt>
 <dd>1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.</dd>
+<dt><tt>k_scale</tt> (optional) : T_KV_SCALE</dt>
+<dd>Scale tensor for past_key.</dd>
+<dt><tt>v_scale</tt> (optional) : T_KV_SCALE</dt>
+<dd>Scale tensor for past_value.</dd>
 </dl>
 
 #### Outputs (3 - 4)
 
 <dl>
 <dt><tt>output</tt> : T</dt>
 <dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
-<dt><tt>present_key</tt> : T</dt>
+<dt><tt>present_key</tt> : T_CACHE</dt>
 <dd>present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
-<dt><tt>present_value</tt> : T</dt>
+<dt><tt>present_value</tt> : T_CACHE</dt>
 <dd>present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
 <dt><tt>output_qk</tt> (optional) : T</dt>
 <dd>Values of QK matrix multiplication, either before or after softmax normalization</dd>
@@ -2604,6 +2625,10 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dl>
 <dt><tt>T</tt> : tensor(float16), tensor(bfloat16), tensor(float)</dt>
 <dd>Constrain input and output to float tensors.</dd>
+<dt><tt>T_CACHE</tt> : tensor(float), tensor(float16), tensor(bfloat16), tensor(uint8), tensor(int8), tensor(float8e4m3fn)</dt>
+<dd>Constrain KV cache types.</dd>
+<dt><tt>T_KV_SCALE</tt> : tensor(float)</dt>
+<dd>Constrain KV cache scale types.</dd>
 <dt><tt>M</tt> : tensor(int32)</dt>
 <dd>Constrain mask to int tensor.</dd>
 </dl>
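
To make the scale shapes documented above concrete, here is a hedged NumPy sketch of dequantizing an INT8 KV cache in BNSH layout. It is illustrative only and not code from this commit; it assumes a symmetric, scale-only scheme (no zero points), and it does not show 4-bit unpacking since the nibble packing order is not specified here.

```python
# Illustrative only (not from this commit): dequantizing an int8 KV cache
# stored in BNSH layout, using the two scale shapes documented above.
import numpy as np

batch, num_heads_k, cache_len, head_size = 1, 8, 1024, 128
present_key_int8 = np.random.randint(
    -128, 128, size=(batch, num_heads_k, cache_len, head_size), dtype=np.int8)

# PER_TENSOR: a single scale of shape [1] broadcasts over the whole tensor.
k_scale_per_tensor = np.array([0.02], dtype=np.float32)
key_per_tensor = present_key_int8.astype(np.float32) * k_scale_per_tensor

# PER_CHANNEL: one scale per (head, channel) with shape [1, num_heads_k, 1, head_size];
# it broadcasts over the batch and cache_sequence_length dimensions.
k_scale_per_channel = np.full((1, num_heads_k, 1, head_size), 0.02, dtype=np.float32)
key_per_channel = present_key_int8.astype(np.float32) * k_scale_per_channel

print(key_per_tensor.shape, key_per_channel.shape)  # both (1, 8, 1024, 128)
```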

docs/OperatorKernels.md

Lines changed: 3 additions & 3 deletions

@@ -577,7 +577,7 @@ Do not modify directly.*
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
 |MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
@@ -1003,7 +1003,7 @@ Do not modify directly.*
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)<br/> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(int8)<br/> **T_KV_SCALE** = tensor(float)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1484,7 +1484,7 @@ Do not modify directly.*
 |FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
 |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|

onnxruntime/contrib_ops/cpu/bert/attention_common.h

Lines changed: 7 additions & 0 deletions

@@ -59,6 +59,13 @@ enum class QKOutputType : int {
   AFTER_SOFTMAX = 2
 };
 
+// Enum to define quantization granularity.
+enum class KVQuantizationType : int {
+  NONE = 0,
+  PER_TENSOR = 1,
+  PER_CHANNEL = 2,
+};
+
 constexpr bool LAYOUT_BSNH = false;
 constexpr bool LAYOUT_BNSH = true;
 
onnxruntime/contrib_ops/cpu/bert/attention_parameters.h

Lines changed: 5 additions & 0 deletions

@@ -96,6 +96,11 @@ struct GroupQueryAttentionParameters : AttentionParameters {
   AttentionQkvFormat past_kv_format;
   int zeros_count;
   int* zero_ptr;
+
+  // Quantization parameters for KV cache
+  KVQuantizationType k_quant_type = KVQuantizationType::NONE;
+  KVQuantizationType v_quant_type = KVQuantizationType::NONE;
+  int kv_cache_bit_width = 0;
 };
 
 // Parameters deduced from node attributes and inputs/outputs.

onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc

Lines changed: 2 additions & 1 deletion

@@ -70,7 +70,8 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
       seqlens_k,
       total_seqlen_tensor,
       scale_,
-      softcap_));
+      softcap_,
+      0));
 
   ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
       attention_bias,
