diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 42432041a8b01..cab8092e30260 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -103,6 +103,7 @@ cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention ke
option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF)
cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF)
+option(onnxruntime_USE_INT4_KV_CACHE "Build cuda kernels for int4 kv cache" OFF)
option(onnxruntime_QUICK_BUILD "Speed up build by skipping some kernels for faster development" OFF)
option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
@@ -125,6 +126,7 @@ option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF)
option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF)
option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF)
+option(onnxruntime_DUMP_TENSOR "Dump tensor inside kernel." OFF)
cmake_dependent_option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS_ENABLE_DUMP_TO_SQLDB "Build dump debug information about node inputs and outputs with support for sql database." OFF "onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS" OFF)
# When loading a delay loaded DLL, Windows searches the main EXE's folder first.
@@ -627,7 +629,6 @@ else()
check_cxx_compiler_flag(-Wparentheses HAS_PARENTHESES)
check_cxx_compiler_flag(-Wshorten-64-to-32 HAS_SHORTEN_64_TO_32)
check_cxx_compiler_flag(-Wstrict-aliasing HAS_STRICT_ALIASING)
- check_nvcc_compiler_flag(-Wstrict-aliasing NVCC_HAS_STRICT_ALIASING)
check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW)
check_cxx_compiler_flag(-Wtautological-pointer-compare HAS_TAUTOLOGICAL_POINTER_COMPARE)
check_cxx_compiler_flag(-Wundefined-var-template HAS_UNDEFINED_VAR_TEMPLATE)
@@ -774,8 +775,13 @@ if (onnxruntime_USE_CUDA)
endif()
if (onnxruntime_QUICK_BUILD)
- message( STATUS "Quick build mode: Flash attention limited to fp16 only")
- list(APPEND ORT_PROVIDER_FLAGS -DORT_QUICK_BUILD=1)
+ message( STATUS "Quick build mode: Flash attention limited to head dimension 128 only")
+ list(APPEND ORT_PROVIDER_FLAGS -DORT_QUICK_BUILD=1)
+ endif()
+
+ if (onnxruntime_USE_INT4_KV_CACHE)
+ message( STATUS "Enable int4 kv cache for CUDA EP")
+ list(APPEND ORT_PROVIDER_FLAGS -DUSE_INT4_KV_CACHE=1)
endif()
endif()
@@ -1433,6 +1439,9 @@ if (Git_FOUND)
if (onnxruntime_QUICK_BUILD)
string(APPEND ORT_BUILD_INFO "quick-build=1, ")
endif()
+ if (onnxruntime_USE_INT4_KV_CACHE)
+ string(APPEND ORT_BUILD_INFO "int4-kv-cache=1, ")
+ endif()
endif()
string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}")
configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h)
@@ -1446,6 +1455,8 @@ if (onnxruntime_USE_CUDA)
find_package(CUDAToolkit REQUIRED)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.8)
+ add_definitions("-DENABLE_BF16")
+ message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_BF16 flag")
add_definitions("-DENABLE_FP8")
message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_FP8 flag")
endif()
@@ -1779,6 +1790,10 @@ if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
add_compile_definitions(DEBUG_NODE_INPUTS_OUTPUTS)
endif()
+if (onnxruntime_DUMP_TENSOR)
+ add_compile_definitions(DUMP_TENSOR_LEVEL=1)
+endif()
+
if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS)
if (NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
message(FATAL_ERROR "External custom operator schemas feature is only supported on Linux")
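A note for orientation, not part of the patch: these options surface in C++ as compile definitions (`USE_INT4_KV_CACHE=1` from `onnxruntime_USE_INT4_KV_CACHE`, `DUMP_TENSOR_LEVEL=1` from `onnxruntime_DUMP_TENSOR`, and `ENABLE_BF16` on CUDA 11.8+), which kernels later in this diff test with `#ifdef`/`#if`. A minimal sketch; the helper function name is hypothetical:

```cpp
// Hypothetical helper, assuming it is compiled with the definitions added above.
bool Int4KvCacheCompiledIn() {
#ifdef USE_INT4_KV_CACHE
  return true;   // int4 KV cache CUDA kernels were built
#else
  return false;  // kv_cache_bit_width == 4 is rejected during input validation
#endif
}
```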
diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake
index f77a5dd78fcc5..f6efcb3fad6a9 100644
--- a/cmake/onnxruntime_providers_cpu.cmake
+++ b/cmake/onnxruntime_providers_cpu.cmake
@@ -28,12 +28,10 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS
# Quick build mode: Filter flash attention kernels for faster development iteration.
-# - We keep only hdim128 fp16 flash attention kernels in quick build mode.
+# - We keep only hdim128 flash attention kernels (fp16 and bf16) in quick build mode.
# - All other listed head dimensions are excluded (e.g., 32, 64, 96, 192, 256).
-# - This regex matches both `flash_fwd_hdim*` and `flash_fwd_split_hdim*` kernels.
# If new head dimensions are added or removed, update this list to match the supported set.
if(onnxruntime_QUICK_BUILD)
-  message(STATUS "Quick build mode enabled: Only building hdim128 fp16 flash attention kernels")
+  message(STATUS "Quick build mode enabled: Only building hdim128 flash attention kernels")
list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|192|256)")
- list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*_bf16")
endif()
file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index b7b68ff324d9f..149ecdb969bd5 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2520,15 +2520,26 @@ This version of the operator has been available since version 1 of the 'com.micr
### **com.microsoft.GroupQueryAttention**
- Group Query Self/Cross Attention.
+ Group Query Self/Cross Attention with KV Cache Quantization Support.
- *Highly recommend using k-v cache share buffer for both CPU and CUDA. Enabled through IOBinding past and present kv.
- Supports different number of heads for q and kv for CPU and CUDA.
- Only supports causal and local attention.
- Supports rotary position embedding for CPU and CUDA.
- Supports packed input for CPU and CUDA.
- Supports continuous decoding for batch_size == 1 for CPU and CUDA.
+ This operator implements causal grouped-query attention with past state (KV cache) support.
+ It also supports optional float8, int8 or int4 quantization of the KV cache to reduce the memory footprint.
+
+ **Cache Format:**
+ The past and present KV cache tensors are expected in BNSH format: `(batch_size, kv_num_heads, cache_sequence_length, head_size)`, where `cache_sequence_length` is the length of the cached key/value sequences, or the maximum sequence length when past and present buffer sharing is used.
+
+ **Quantization:**
+ When quantization is enabled, the `past_key` and `past_value` inputs can be of type `float8e4m3fn`, `uint8` or `int8`. The corresponding `k_scale` and `v_scale` tensors must be provided.
+ The operator outputs `present_key` and `present_value` in the same format as `past_key` and `past_value`.
+
+ For 4-bit quantization, the data type is uint8, where each byte packs two 4-bit values. The bit width of the quantized KV cache is set with the `kv_cache_bit_width` attribute.
+
+ The shapes of the `k_scale` and `v_scale` tensors shall be broadcastable to the `present_key` shape.
+
+ **Quantization Modes (`k_quant_type`, `v_quant_type` attributes):**
+ - **"NONE"**: No quantization.
+ - **"PER_TENSOR"**: A single scale for the entire tensor. Scale example shape: `[1]`.
+ - **"PER_CHANNEL"**: A scale for each channel. Scale example shape: `[1, num_heads_k, 1, head_size]`.
#### Version
@@ -2539,6 +2550,10 @@ This version of the operator has been available since version 1 of the 'com.micr
- do_rotary : int
- Whether to use rotary position embedding. Default value is 0.
+- k_quant_type : string
+- Quantization type for K cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.
+- kv_cache_bit_width : int
+- Bit width of quantized KV cache. Supported values are 8 and 4.
- kv_num_heads : int (required)
- Number of attention heads for k and v
- local_window_size : int
@@ -2555,9 +2570,11 @@ This version of the operator has been available since version 1 of the 'com.micr
- Use a smooth factor in softmax.
- softcap : float
- Softcap value for attention weights. Default value is 0.
+- v_quant_type : string
+- Quantization type for V cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.
-#### Inputs (7 - 12)
+#### Inputs (7 - 14)
- query : T
@@ -2566,9 +2583,9 @@ This version of the operator has been available since version 1 of the 'com.micr
- Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
- value (optional) : T
- Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
-- past_key (optional) : T
+- past_key (optional) : T_CACHE
- past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
-- past_value (optional) : T
+- past_value (optional) : T_CACHE
- past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
- seqlens_k : M
- 1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).
@@ -2584,6 +2601,10 @@ This version of the operator has been available since version 1 of the 'com.micr
- additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
- head_sink (optional) : T
- 1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.
+- k_scale (optional) : T_KV_SCALE
+- Scale tensor for past_key.
+- v_scale (optional) : T_KV_SCALE
+- Scale tensor for past_value.
#### Outputs (3 - 4)
@@ -2591,9 +2612,9 @@ This version of the operator has been available since version 1 of the 'com.micr
- output : T
- 3D output tensor with shape (batch_size, sequence_length, hidden_size)
-- present_key : T
+- present_key : T_CACHE
- present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
-- present_value : T
+- present_value : T_CACHE
- present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
- output_qk (optional) : T
- Values of QK matrix multiplication, either before or after softmax normalization
@@ -2604,6 +2625,10 @@ This version of the operator has been available since version 1 of the 'com.micr
- T : tensor(float16), tensor(bfloat16), tensor(float)
- Constrain input and output to float tensors.
+- T_CACHE : tensor(float), tensor(float16), tensor(bfloat16), tensor(uint8), tensor(int8), tensor(float8e4m3fn)
+- Constrain KV cache types.
+- T_KV_SCALE : tensor(float)
+- Constrain KV cache scale types.
- M : tensor(int32)
- Constrain mask to int tensor.
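To make the documented int4 layout concrete, here is an illustrative host-side reference, not the PR's CUDA implementation; the function name `QuantizePackInt4` and the low-nibble-first packing order are assumptions. It quantizes one K head with a per-channel scale and packs two 4-bit values per byte, which is why dimension 3 of the cache becomes head_size / 2:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize one K head (length head_size, assumed even) to signed int4 with a per-channel
// scale and pack two values per byte, matching the documented uint8 cache layout.
std::vector<uint8_t> QuantizePackInt4(const std::vector<float>& k_head,
                                      const std::vector<float>& channel_scale) {
  const size_t head_size = k_head.size();
  std::vector<uint8_t> packed(head_size / 2);
  for (size_t i = 0; i < head_size; i += 2) {
    auto quantize = [&](size_t idx) {
      // Symmetric quantization to the signed 4-bit range [-8, 7].
      int q = static_cast<int>(std::lround(k_head[idx] / channel_scale[idx]));
      q = std::max(-8, std::min(7, q));
      return static_cast<uint8_t>(q & 0x0F);  // two's-complement nibble
    };
    // Assumed packing order: low nibble holds element i, high nibble holds element i + 1.
    packed[i / 2] = static_cast<uint8_t>(quantize(i) | (quantize(i + 1) << 4));
  }
  return packed;
}
```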
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 08840c623b709..0230f2866fcb4 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -577,7 +577,7 @@ Do not modify directly.*
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)|
@@ -1003,7 +1003,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)<br> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(int8)<br> **T_KV_SCALE** = tensor(float)|
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1484,7 +1484,7 @@ Do not modify directly.*
|FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index 80d374d3f0b25..e886adac03f27 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -59,6 +59,13 @@ enum class QKOutputType : int {
AFTER_SOFTMAX = 2
};
+// Enum to define quantization granularity.
+enum class KVQuantizationType : int {
+ NONE = 0,
+ PER_TENSOR = 1,
+ PER_CHANNEL = 2,
+};
+
constexpr bool LAYOUT_BSNH = false;
constexpr bool LAYOUT_BNSH = true;
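These enum values correspond to the operator's 'NONE' / 'PER_TENSOR' / 'PER_CHANNEL' attribute strings. A self-contained sketch of that mapping, mirroring the case-insensitive parser added to the CUDA GroupQueryAttention kernel later in this diff (the free-function name here is illustrative):

```cpp
#include <algorithm>
#include <cctype>
#include <string>

enum class KVQuantizationType : int { NONE = 0, PER_TENSOR = 1, PER_CHANNEL = 2 };

// Case-insensitive parse of the k_quant_type / v_quant_type attribute strings;
// unrecognized values fall back to NONE, matching the kernel's behavior.
KVQuantizationType ParseKVQuantizationType(std::string s) {
  std::transform(s.begin(), s.end(), s.begin(),
                 [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
  if (s == "PER_TENSOR") return KVQuantizationType::PER_TENSOR;
  if (s == "PER_CHANNEL") return KVQuantizationType::PER_CHANNEL;
  return KVQuantizationType::NONE;
}
```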
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
index 4ad11dce7e093..9a123e80adc18 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -96,6 +96,11 @@ struct GroupQueryAttentionParameters : AttentionParameters {
AttentionQkvFormat past_kv_format;
int zeros_count;
int* zero_ptr;
+
+ // Quantization parameters for KV cache
+ KVQuantizationType k_quant_type = KVQuantizationType::NONE;
+ KVQuantizationType v_quant_type = KVQuantizationType::NONE;
+ int kv_cache_bit_width = 0;
};
// Parameters deduced from node attributes and inputs/outputs.
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index eb1560ac8e341..0fe37ae68c012 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -70,7 +70,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
seqlens_k,
total_seqlen_tensor,
scale_,
- softcap_));
+ softcap_,
+                                                                  0 /* kv_cache_bit_width */));
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
attention_bias,
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index f7c54ad456925..1515adb1fc6ff 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -97,7 +97,7 @@ Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const
}
template
-Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_num_heads, int head_size,
+Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_num_heads, int head_size, int kv_cache_bit_width,
int& past_sequence_length) {
const auto& past_key_dims = past_key->Shape().GetDims();
const auto& past_value_dims = past_value->Shape().GetDims();
@@ -141,15 +141,18 @@ Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_
// We assume all sequence in past kv are right-padded to max or past sequence length
past_sequence_length = static_cast(past_key_dims[2]);
- if (past_key_dims[3] != head_size) {
+  // For a 4-bit quantized KV cache, the actual dimension is head_size / 2 because two 4-bit values (nibbles) are packed into one byte.
+  // Note that Check_QKV has already verified that head_size is a multiple of 8.
+ int packed_head_size = (kv_cache_bit_width == 4) ? (head_size / 2) : head_size;
+ if (past_key_dims[3] != packed_head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 3 should be same as head_size, got ",
- past_key_dims[3]);
+ past_key_dims[3], " expected ", packed_head_size);
}
- if (past_value_dims[3] != head_size) {
+ if (past_value_dims[3] != packed_head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' dimension 3 should be same as head_size, got ",
- past_value_dims[3]);
+ past_value_dims[3], " expected ", packed_head_size);
}
return Status::OK();
}
@@ -203,7 +206,8 @@ Status CheckInputs(const T* query,
const T* seqlens_k,
const T* total_seqlen,
float scale,
- float softcap) {
+ float softcap,
+ int kv_cache_bit_width) {
// Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
// past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
// past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
@@ -221,6 +225,18 @@ Status CheckInputs(const T* query,
num_heads % kv_num_heads);
}
+#ifdef USE_INT4_KV_CACHE
+ if (kv_cache_bit_width != 0 && kv_cache_bit_width != 4 && kv_cache_bit_width != 8) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "kv_cache_bit_width must be 0, 4 or 8. Got kv_cache_bit_width == ", kv_cache_bit_width);
+ }
+#else
+ if (kv_cache_bit_width != 0 && kv_cache_bit_width != 8) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "kv_cache_bit_width must be 0 or 8. Got kv_cache_bit_width == ", kv_cache_bit_width);
+ }
+#endif
+
int batch_size = 0;
int sequence_length = 0;
int q_hidden_size = 0;
@@ -239,7 +255,7 @@ Status CheckInputs(const T* query,
// Check past-present KV
int32_t past_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
- ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, batch_size, kv_num_heads, head_size, past_sequence_length));
+ ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, batch_size, kv_num_heads, head_size, kv_cache_bit_width, past_sequence_length));
} else if (past_key != nullptr || past_value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
@@ -329,12 +345,13 @@ Status CheckInputs(const T* query,
const T* total_seqlen,
float scale,
float softcap,
+ int kv_cache_bit_width,
int max_threads_per_block) {
if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}
- return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap);
+ return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap, kv_cache_bit_width);
}
template
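A concrete instance of the packed-head-size check above: with head_size = 128, a 4-bit cache stores 64 bytes in dimension 3, while 8-bit or unquantized caches keep 128. A standalone sketch of the same arithmetic; the helper name is hypothetical:

```cpp
#include <cstdint>

// Hypothetical helper mirroring CheckPast's expectation for dimension 3 of the KV cache.
constexpr int64_t ExpectedCacheLastDim(int head_size, int kv_cache_bit_width) {
  // For 4-bit caches two values share one byte, so the stored dimension is halved.
  return (kv_cache_bit_width == 4) ? head_size / 2 : head_size;
}

static_assert(ExpectedCacheLastDim(128, 4) == 64, "int4 cache packs two values per byte");
static_assert(ExpectedCacheLastDim(128, 8) == 128, "int8 cache keeps the full head size");
static_assert(ExpectedCacheLastDim(128, 0) == 128, "unquantized cache keeps the full head size");
```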
diff --git a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h
index 415612582ee4b..9669e0b8622c0 100644
--- a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h
+++ b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h
@@ -46,11 +46,11 @@
#define DUMP_TENSOR_D(...)
#endif
-#if (defined(__GNUC__) || defined(__clang__)) && !defined(NDEBUG)
-#define DEBUG_PRINTF(fmt, ...) \
+#if (defined(__GNUC__) || defined(__clang__)) && (DUMP_TENSOR_LEVEL > 0)
+#define DUMP_PRINTF(fmt, ...) \
std::printf("[DEBUG] " fmt "\n", ##__VA_ARGS__)
#else
-#define DEBUG_PRINTF(fmt, ...) \
- do { \
+#define DUMP_PRINTF(fmt, ...) \
+ do { \
} while (0)
#endif
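A self-contained sketch of how the renamed DUMP_PRINTF behaves, assuming DUMP_TENSOR_LEVEL is supplied by the new onnxruntime_DUMP_TENSOR build option; when the level is 0 (or the compiler is not GCC/Clang) the macro expands to an empty do/while and the call compiles away:

```cpp
#include <cstdio>

#ifndef DUMP_TENSOR_LEVEL
#define DUMP_TENSOR_LEVEL 0
#endif

#if (defined(__GNUC__) || defined(__clang__)) && (DUMP_TENSOR_LEVEL > 0)
#define DUMP_PRINTF(fmt, ...) std::printf("[DEBUG] " fmt "\n", ##__VA_ARGS__)
#else
#define DUMP_PRINTF(fmt, ...) \
  do {                        \
  } while (0)
#endif

int main() {
  int batch_size = 2, sequence_length = 1;
  // Prints "[DEBUG] batch_size=2, sequence_length=1" only when DUMP_TENSOR_LEVEL > 0.
  DUMP_PRINTF("batch_size=%d, sequence_length=%d", batch_size, sequence_length);
  return 0;
}
```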
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h
index 1622bb6622412..486bf05bd86d5 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h
@@ -145,18 +145,21 @@ struct PackedMultiHeadAttentionData {
bool use_memory_efficient_attention;
};
-template <typename T>
+template <typename T, typename U>
struct GroupQueryAttentionData {
// Input Tensors
const T* query = nullptr;
const T* key = nullptr;
const T* value = nullptr;
- const T* past_key = nullptr;
- const T* past_value = nullptr;
+ const U* past_key = nullptr;
+ const U* past_value = nullptr;
const T* cos_cache = nullptr;
const T* sin_cache = nullptr;
const T* head_sink = nullptr;
+ const float* k_scale = nullptr;
+ const float* v_scale = nullptr;
+
// Total sequence length for each batch. It has shape [batch_size].
int* total_seq_lens = nullptr;
@@ -186,13 +189,18 @@ struct GroupQueryAttentionData {
// Output Tensors
T* output = nullptr;
- void* present_key = nullptr;
- void* present_value = nullptr;
+ U* present_key = nullptr;
+ U* present_value = nullptr;
// Kernel Flags
bool use_flash_attention = false;
bool use_memory_efficient_attention = false;
bool use_flash_attention_fast_decode = false;
+ bool use_xqa = false;
+
+ // XQA buffer
+ void* xqa_buffer = nullptr;
+ size_t xqa_buffer_bytes = 0;
};
template
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
index 83f94a31d1786..7aed9fe10afbd 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -145,12 +145,6 @@ bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_hea
template
bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k) {
#ifdef ORT_QUICK_BUILD
- // In quick build mode, only fp16 flash attention is built
- constexpr bool is_bf16 = std::is_same::value;
- if (is_bf16) {
- return false;
- }
-
if (head_size != 128) {
return false;
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h
index eb1d6501a80f6..0e6c733313f06 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h
@@ -66,14 +66,6 @@
#define LOCAL_SWITCH BOOL_SWITCH
#endif
-#ifdef ORT_QUICK_BUILD
-// Quick build mode: only fp16 kernels are compiled
-#define FP16_SWITCH(COND, ...) \
- [&] { \
- using elem_type = cutlass::half_t; \
- return __VA_ARGS__(); \
- }()
-#else
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
@@ -84,7 +76,6 @@
return __VA_ARGS__(); \
} \
}()
-#endif
#ifdef ORT_QUICK_BUILD
// Quick build mode: only hdim128 kernels are compiled
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 29ef660e562e0..39154ca395fc1 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -1,15 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+#include
#include
#include
#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/cuda_type_conversion.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/group_query_attention.h"
#include "contrib_ops/cpu/bert/group_query_attention_helper.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+#include "contrib_ops/cuda/bert/xqa/xqa_loader.h"
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
#include "contrib_ops/cpu/utils/debug_macros.h"
@@ -21,23 +24,49 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-#define REGISTER_KERNEL_TYPED(T)                                         \
-  ONNX_OPERATOR_TYPED_KERNEL_EX(                                         \
-      GroupQueryAttention,                                               \
-      kMSDomain,                                                         \
-      1,                                                                 \
-      T,                                                                 \
-      kCudaExecutionProvider,                                            \
-      (*KernelDefBuilder::Create())                                      \
-          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())         \
-          .TypeConstraint("M", {DataTypeImpl::GetTensorType<int32_t>()}) \
-          .MayInplace(3, 1)                                              \
-          .MayInplace(4, 2)                                              \
-          .InputMemoryType(OrtMemTypeCPUInput, 6),                       \
-      GroupQueryAttention<T>);
-
-REGISTER_KERNEL_TYPED(MLFloat16)
-REGISTER_KERNEL_TYPED(BFloat16)
+namespace {
+// Map string attribute to quantization type enum
+KVQuantizationType StringToKVQuantizationType(std::string s) {
+ std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::toupper(c); });
+ if (s == "NONE") {
+ return KVQuantizationType::NONE;
+ }
+ if (s == "PER_TENSOR") {
+ return KVQuantizationType::PER_TENSOR;
+ }
+
+ if (s == "PER_CHANNEL") {
+ return KVQuantizationType::PER_CHANNEL;
+ }
+ return KVQuantizationType::NONE;
+}
+} // namespace
+
+#define REGISTER_KERNEL_TYPED(T, U)                                            \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                               \
+      GroupQueryAttention,                                                     \
+      kMSDomain,                                                               \
+      1,                                                                       \
+      T##_##U,                                                                 \
+      kCudaExecutionProvider,                                                  \
+      (*KernelDefBuilder::Create())                                            \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())               \
+          .TypeConstraint("T_CACHE", DataTypeImpl::GetTensorType<U>())         \
+          .TypeConstraint("T_KV_SCALE", DataTypeImpl::GetTensorType<float>())  \
+          .TypeConstraint("M", {DataTypeImpl::GetTensorType<int32_t>()})       \
+          .MayInplace(3, 1) /* past_key and present_key */                     \
+          .MayInplace(4, 2) /* past_value and present_value */                 \
+          .InputMemoryType(OrtMemTypeCPUInput, 6), /* total_sequence_length */ \
+      GroupQueryAttention<T, U>);
+
+REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16, BFloat16)
+REGISTER_KERNEL_TYPED(MLFloat16, int8_t)
+REGISTER_KERNEL_TYPED(BFloat16, int8_t)
+#ifdef USE_INT4_KV_CACHE
+REGISTER_KERNEL_TYPED(MLFloat16, uint8_t)
+REGISTER_KERNEL_TYPED(BFloat16, uint8_t)
+#endif
constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE";
@@ -51,8 +80,8 @@ constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE";
// - Quantized KV Cache (Int8/Int4) via GroupQueryAttentionData
// - Flash Attention and Memory Efficient Attention backends
//
-template <typename T>
-GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
+template <typename T, typename U>
+GroupQueryAttention<T, U>::GroupQueryAttention(const OpKernelInfo& info)
: CudaKernel(info) {
int64_t num_heads = 0;
int64_t kv_num_heads = 0;
@@ -69,6 +98,14 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info)
softcap_ = info.GetAttrOrDefault("softcap", 0.0f);
use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1;
+ k_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("k_quant_type", "NONE"));
+ v_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("v_quant_type", "NONE"));
+  kv_cache_bit_width_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("kv_cache_bit_width", 0));
+
+ bool is_quantized = (k_quant_type_ != KVQuantizationType::NONE || v_quant_type_ != KVQuantizationType::NONE);
+ int default_enable_xqa = is_quantized ? 1 : 0;
+ enable_xqa_ = (std::is_same_v || std::is_same_v) && ParseEnvironmentVariableWithDefault("ORT_ENABLE_XQA", default_enable_xqa) != 0;
+
kernel_options_ = this->GetAttentionKernelOptions();
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
@@ -99,8 +136,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info)
// 9. position_ids (Tensor) - Position indices for RoPE
// 10. attention_bias (Tensor) - Not supported in this kernel
// 11. head_sink (Tensor) - Attention sink for GPT-OSS
-template <typename T>
-Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
+template <typename T, typename U>
+Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) const {
const Tensor* query = context->Input(0);
const Tensor* key = context->Input(1);
const Tensor* value = context->Input(2);
@@ -119,6 +156,35 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
const Tensor* position_ids = context->Input(9);
const Tensor* attention_bias = context->Input(10);
const Tensor* head_sink = context->Input(11);
+  const Tensor* k_scale = context->Input<Tensor>(12);
+  const Tensor* v_scale = context->Input<Tensor>(13);
+
+ if (k_quant_type_ != KVQuantizationType::NONE) {
+ if (k_scale == nullptr) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "k_scale must be provided when k_quant_type is not NONE");
+ }
+
+    if (k_scale->DataType() != DataTypeImpl::GetType<float>()) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "k_scale must be float tensor");
+ }
+ }
+
+ if (v_quant_type_ != KVQuantizationType::NONE) {
+ if (v_scale == nullptr) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "v_scale must be provided when v_quant_type is not NONE");
+ }
+    if (v_scale->DataType() != DataTypeImpl::GetType<float>()) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "v_scale must be float tensor");
+ }
+ }
if (attention_bias != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -127,8 +193,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
auto& device_prop = GetDeviceProp();
GroupQueryAttentionParameters parameters;
-  typedef typename ToCudaType<T>::MappedType CudaT;
-  GroupQueryAttentionData<CudaT> data;
+
+  typedef typename onnxruntime::cuda::OrtToCudaType<T>::type CudaT;
+  typedef typename onnxruntime::cuda::OrtToCudaType<U>::type CudaU;
+  GroupQueryAttentionData<CudaT, CudaU> data;
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
key,
@@ -144,6 +212,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
total_seqlen,
scale_,
softcap_,
+ kv_cache_bit_width_,
device_prop.maxThreadsPerBlock));
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
@@ -155,6 +224,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr;
parameters.zeros_count = kZerosCount;
parameters.zero_ptr = zeros_.get();
+ parameters.k_quant_type = k_quant_type_;
+ parameters.v_quant_type = v_quant_type_;
+ parameters.kv_cache_bit_width = kv_cache_bit_width_;
parameters.do_rotary = do_rotary_;
parameters.rotary_interleaved = rotary_interleaved_;
@@ -176,8 +248,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
Tensor* output = context->Output(0, output_shape);
// Set up present KV output shapes
+ // For 4-bit quantization, we pack two 4-bit values into one uint8 byte.
+ // Therefore, the dense head size in the tensor shape is halved (rounded up).
+ int dense_head_size = (parameters.kv_cache_bit_width == 4) ? (parameters.head_size + 1) / 2 : parameters.head_size;
std::vector present_dims = {
- parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, parameters.head_size};
+ parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, dense_head_size};
TensorShape present_shape(present_dims);
context->Output(1, present_shape); // present_key
@@ -203,24 +278,103 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data());
// Handle Past/Present pointers
- data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data());
- data.present_key = reinterpret_cast(context->Output(1)->MutableData());
- data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data());
- data.present_value = reinterpret_cast(context->Output(2)->MutableData());
+  data.k_scale = k_scale == nullptr ? nullptr : reinterpret_cast<const float*>(k_scale->DataRaw());
+  data.v_scale = v_scale == nullptr ? nullptr : reinterpret_cast<const float*>(v_scale->DataRaw());
+
+ data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data());
+ data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data());
+
+ data.present_key = reinterpret_cast(context->Output(1)->MutableData());
+ data.present_value = reinterpret_cast(context->Output(2)->MutableData());
// Compute past_present_share_buffer early since it's needed for flash attention path selection.
// This compares the final pointer values after quantization handling.
parameters.past_present_share_buffer = (data.past_key == data.present_key);
+ bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE);
+
+ // Allocate XQA scratch if needed (only for Flash Decoding path)
+ IAllocatorUniquePtr xqa_scratch_buffer;
+ // Check conditions to enable XQA (Extreme Query Attention) kernel for optimized decoding.
+ // XQA is a highly optimized kernel for generation phase (seq_len=1).
+ // Constraints:
+ // 1. Compute Capability SM 8.0+ (Ampere or newer).
+ // 2. Not the first prompt (decoding phase).
+ // 3. Sequence length is 1.
+ // 4. Past and Present KV cache share the same buffer (required for XQA specific memory access).
+ // 5. No Softcap (XQA doesn't support softcap).
+ // 6. Standard Softmax (no smooth softmax).
+ // 7. No local window attention (global attention only).
+ if (enable_xqa_ &&
+ (device_prop.major >= 8) &&
+ !parameters.is_first_prompt &&
+ parameters.sequence_length == 1 &&
+ parameters.past_present_share_buffer &&
+ parameters.softcap == 0.0f &&
+ !parameters.use_smooth_softmax &&
+ parameters.local_window_size == -1) {
+ int group_size = parameters.num_heads / parameters.kv_num_heads;
+
+ bool is_int8_quantized_supported = (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
+ v_quant_type_ == KVQuantizationType::PER_TENSOR &&
+                                          data.k_scale == data.v_scale &&  // XQA uses a single KV cache scale, so k_scale and v_scale must be the same tensor.
+ parameters.kv_cache_bit_width == 8 &&
+ (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
+ (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32));
+
+ bool is_non_quantized_supported = !is_inputs_quantized &&
+ (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
+ (64 % group_size == 0);
+
+ data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported);
+
+ if (data.use_xqa) {
+ size_t xqa_internal_bytes = onnxruntime::contrib::cuda::GetXQAScratchSize(
+ GetDeviceProp(),
+ parameters.batch_size,
+ parameters.num_heads,
+ parameters.kv_num_heads,
+ parameters.head_size,
+ parameters.seqlen_present_kv_cache,
+ parameters.k_quant_type != KVQuantizationType::NONE ? XqaQuantType::kInt8 : XqaQuantType::kNone,
+          std::is_same<T, BFloat16>::value);
+ assert(xqa_internal_bytes > 0);
+ // Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding
+ size_t xqa_total_bytes = xqa_internal_bytes;
+ if (parameters.do_rotary) {
+ // 1. Q_rotated buffer: B * N * H * sizeof(T) (if rotary)
+ // 2. K_rotated buffer: B * Nk * H * sizeof(T) (if rotary)
+ size_t element_size = sizeof(CudaT);
+ size_t q_bytes = parameters.batch_size * parameters.num_heads * parameters.head_size * element_size;
+ size_t k_bytes = parameters.batch_size * parameters.kv_num_heads * parameters.head_size * element_size;
+ q_bytes = (q_bytes + 255) / 256 * 256;
+ k_bytes = (k_bytes + 255) / 256 * 256;
+ xqa_total_bytes += q_bytes + k_bytes;
+ }
+
+ xqa_scratch_buffer = this->GetScratchBuffer(xqa_total_bytes, context->GetComputeStream());
+ data.xqa_buffer = xqa_scratch_buffer.get();
+ data.xqa_buffer_bytes = xqa_internal_bytes;
+
+ if (parameters.do_rotary) {
+ data.qkv_buffer = reinterpret_cast(reinterpret_cast(data.xqa_buffer) + xqa_internal_bytes);
+ }
+ }
+ }
+
#if USE_FLASH_ATTENTION
- bool use_flash_attention = !disable_flash_attention_ &&
+ bool use_flash_attention = !data.use_xqa &&
+ !disable_flash_attention_ &&
onnxruntime::flash::is_supported(device_prop,
parameters.head_size,
parameters.num_heads,
parameters.kv_num_heads);
data.use_flash_attention = use_flash_attention;
- data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.past_present_share_buffer;
+ data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.past_present_share_buffer && !is_inputs_quantized;
if (use_flash_attention) {
// Allocate Flash specific buffers (Softmax LSE, Accum)
@@ -245,6 +399,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream());
out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream());
+  auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
+ if (softmax_lse_accum_bytes > 0) {
+    // Zero-initialize the accumulation buffer ourselves: the standard Flash kernel does not clear it
+    // globally, and zeros are safe here because the kernel overwrites entries (e.g. with -inf) where needed.
+ CUDA_RETURN_IF_ERROR(cudaMemsetAsync(softmax_lse_accum_buffer.get(), 0, softmax_lse_accum_bytes, cuda_stream));
+ }
+ if (out_accum_bytes > 0) {
+ CUDA_RETURN_IF_ERROR(cudaMemsetAsync(out_accum_buffer.get(), 0, out_accum_bytes, cuda_stream));
+ }
+
data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get());
data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get());
data.out_accum = reinterpret_cast(out_accum_buffer.get());
@@ -281,7 +445,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
}
#if USE_MEMORY_EFFICIENT_ATTENTION
- if (!data.use_flash_attention) {
+ if (!data.use_xqa && !data.use_flash_attention) {
// Fall back to memory efficient attention.
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
@@ -313,6 +477,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
// This ensures allocation logic stays in sync with kernel usage
auto buffer_req = GQABufferRequirements::Compute(
parameters,
+ data.use_xqa,
data.use_flash_attention,
data.use_flash_attention_fast_decode,
data.use_memory_efficient_attention);
@@ -351,11 +516,29 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
data.head_sink = reinterpret_cast(head_sink->Data());
}
- cublasHandle_t cublas = GetCublasHandle(context);
+#if DUMP_TENSOR_LEVEL > 0
+ DUMP_TENSOR_INIT();
+ // Dump Scales
+ if (data.k_scale) {
+ if (parameters.k_quant_type == KVQuantizationType::PER_TENSOR) {
+ DUMP_TENSOR("k_scale", data.k_scale, 1, 1);
+ } else if (parameters.k_quant_type == KVQuantizationType::PER_CHANNEL) {
+ DUMP_TENSOR("k_scale", data.k_scale, parameters.kv_num_heads, 1, parameters.head_size);
+ }
+ }
+ if (data.v_scale) {
+ if (parameters.v_quant_type == KVQuantizationType::PER_TENSOR) {
+ DUMP_TENSOR("v_scale", data.v_scale, 1, 1);
+ } else if (parameters.v_quant_type == KVQuantizationType::PER_CHANNEL) {
+ DUMP_TENSOR("v_scale", data.v_scale, parameters.kv_num_heads, 1, parameters.head_size);
+ }
+ }
+#endif
-  ORT_RETURN_IF_ERROR(QkvToContext<CudaT>(
-      device_prop, cublas, context->GetComputeStream(), parameters, data));
+  cublasHandle_t cublas = GetCublasHandle(context);
+  ORT_RETURN_IF_ERROR((QkvToContext<CudaT, CudaU>(
+      device_prop, cublas, context->GetComputeStream(), parameters, data)));
return Status::OK();
}
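When rotary embedding is enabled, the XQA scratch sizing above appends two extra regions (rotated Q and rotated K) to the kernel's internal workspace, each rounded up to a 256-byte boundary. A standalone sketch of that padding arithmetic; the helper name is hypothetical:

```cpp
#include <cstddef>

// Round a byte count up to the next multiple of 256, as done for the rotary scratch regions.
constexpr size_t AlignUp256(size_t bytes) {
  return (bytes + 255) / 256 * 256;
}

// Example: batch_size=1, num_heads=32, kv_num_heads=8, head_size=104, fp16 elements (2 bytes each).
static_assert(AlignUp256(1 * 32 * 104 * 2) == 6656, "6656 is already a multiple of 256");
static_assert(AlignUp256(1 * 8 * 104 * 2) == 1792, "1664 bytes round up to 1792");
```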
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
index 2536da9fe1379..75a613724e746 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -15,7 +15,7 @@ namespace cuda {
using namespace onnxruntime::cuda;
-template <typename T>
+template <typename T, typename U>
class GroupQueryAttention final : public CudaKernel {
public:
GroupQueryAttention(const OpKernelInfo& info);
@@ -35,6 +35,11 @@ class GroupQueryAttention final : public CudaKernel {
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
bool disable_flash_decode_;
+ bool enable_xqa_;
+
+ KVQuantizationType k_quant_type_;
+ KVQuantizationType v_quant_type_;
+ int kv_cache_bit_width_;
static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
IAllocatorUniquePtr zeros_;
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index 0b6da63b31af6..e0fa993db29bd 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -42,11 +42,14 @@ limitations under the License.
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cpu/bert/attention_common.h"
#include "contrib_ops/cuda/bert/group_query_attention_qkv.cuh"
+#include "contrib_ops/cuda/bert/group_query_attention_qdq.cuh"
+#include "contrib_ops/cuda/bert/xqa/xqa_loader.h"
#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
#include "contrib_ops/cuda/bert/rotary_common.cuh"
#include "contrib_ops/cuda/bert/transformer_common.h"
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/cuda_type_conversion.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
@@ -74,44 +77,43 @@ namespace cuda {
// 3. Ensuring synchronization between past and present KV caches when necessary.
// 4. Launching the UnpackRoPEQuantizeAppend kernel to unpack, apply RoPE, and update caches.
// 5. Returning strict Q, K, V pointers ready for the core attention operation.
-template <typename T>
+template <typename T, typename U>
Status PrepareQKV(
cudaStream_t stream,
const int max_threads_per_block,
const GroupQueryAttentionParameters& parameters,
-    GroupQueryAttentionData<T>& data,
- const T*& q,
- const T*& k,
- const T*& v) {
+    GroupQueryAttentionData<T, U>& data,
+ const T*& q) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
-  using CudaT = typename ToCudaType<T>::MappedType;
-  CudaT* q_out = data.qkv_buffer;
+  typedef typename onnxruntime::cuda::OrtToCudaType<T>::type CudaT;
+  typedef typename onnxruntime::cuda::OrtToCudaType<U>::type CudaU;
+  CudaT* q_out = reinterpret_cast<CudaT*>(data.qkv_buffer);
if (!parameters.is_packed_qkv && !parameters.do_rotary) {
q_out = nullptr;
}
-  CudaT* k_final_ptr = reinterpret_cast<CudaT*>(data.present_key);
-  CudaT* v_final_ptr = reinterpret_cast<CudaT*>(data.present_value);
-  int final_max_seqlen = parameters.seqlen_present_kv_cache;
-  bool final_is_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
+  CudaT* k = reinterpret_cast<CudaT*>(data.present_key);
+  CudaT* v = reinterpret_cast<CudaT*>(data.present_value);
+ int max_cache_length = parameters.seqlen_present_kv_cache;
+ bool is_cache_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
if (!parameters.past_present_share_buffer) {
- size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * final_max_seqlen * head_size * sizeof(CudaT);
+ size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * max_cache_length * head_size * sizeof(CudaU);
CUDA_CALL_THROW(cudaMemsetAsync(data.present_key, 0, kv_buffer_size, stream));
CUDA_CALL_THROW(cudaMemsetAsync(data.present_value, 0, kv_buffer_size, stream));
}
+ // Copy past KV to present KV if needed
if (!parameters.past_present_share_buffer && data.past_key != nullptr && parameters.seqlen_past_kv_cache > 0) {
- bool is_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
- if (is_bnsh) {
- size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(CudaT);
- size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * head_size * sizeof(CudaT);
+ if (is_cache_bnsh) {
+ size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(CudaU);
+ size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * head_size * sizeof(CudaU);
size_t width = src_pitch;
size_t height = (size_t)batch_size * kv_num_heads;
@@ -120,8 +122,8 @@ Status PrepareQKV(
CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height,
cudaMemcpyDeviceToDevice, stream));
} else {
- size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * kv_num_heads * head_size * sizeof(CudaT);
- size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * kv_num_heads * head_size * sizeof(CudaT);
+ size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * kv_num_heads * head_size * sizeof(CudaU);
+ size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * kv_num_heads * head_size * sizeof(CudaU);
size_t width = src_pitch;
size_t height = (size_t)batch_size;
@@ -132,17 +134,17 @@ Status PrepareQKV(
}
}
- ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppendKV(
+ ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend(
parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr,
parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query),
parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key),
parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value),
- q_out, k_final_ptr, v_final_ptr,
+ q_out, k, v, data.k_scale, data.v_scale,
num_heads, kv_num_heads, head_size, sequence_length, batch_size,
- final_max_seqlen, data.past_seq_lens,
+ max_cache_length, data.past_seq_lens,
reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache),
parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
- final_is_bnsh,
+ is_cache_bnsh, parameters.k_quant_type, parameters.kv_cache_bit_width,
stream, max_threads_per_block));
if (q_out != nullptr) {
@@ -150,17 +152,16 @@ Status PrepareQKV(
} else {
q = reinterpret_cast(data.query);
}
- k = reinterpret_cast(k_final_ptr);
- v = reinterpret_cast(v_final_ptr);
+
return Status::OK();
}
////////// Auxiliary Kernels for KV prep
// Concat new to past in present. Supports past BSNH or past BNSH
-template
+template
Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters,
- GroupQueryAttentionData& data,
+ GroupQueryAttentionData& data,
const void* new_key,
const void* new_value,
cudaStream_t stream,
@@ -190,12 +191,12 @@ Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters,
is_bsnh,
data.past_seq_lens,
data.total_seq_lens,
- data.past_key,
- data.past_value,
+ reinterpret_cast(data.past_key),
+ reinterpret_cast(data.past_value),
reinterpret_cast(new_key),
reinterpret_cast(new_value),
- data.present_key,
- data.present_value,
+ reinterpret_cast(data.present_key),
+ reinterpret_cast(data.present_value),
stream,
max_threads_per_block,
past_only,
@@ -207,9 +208,9 @@ Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters,
}
// Concat new to kv buffer in place
-template
+template
Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters,
- GroupQueryAttentionData& data,
+ GroupQueryAttentionData& data,
const void* new_key,
const void* new_value,
bool is_new_kv_bnsh_format,
@@ -230,8 +231,8 @@ Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters,
parameters.sequence_length,
reinterpret_cast(new_key),
reinterpret_cast(new_value),
- data.present_key,
- data.present_value,
+ reinterpret_cast(data.present_key),
+ reinterpret_cast(data.present_value),
is_past_kv_bnsh_format,
is_new_kv_bnsh_format,
stream,
@@ -553,13 +554,105 @@ Status LaunchGetSequenceLengths(
}
// Trace function for debugging
-#define ORT_GQA_TRACE(func_name) \
- DEBUG_PRINTF("[GQA %s] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, past_present_share_buffer: %d", \
- func_name, \
- static_cast(parameters.is_packed_qkv), \
- static_cast(parameters.is_first_prompt), \
- static_cast(parameters.is_subsequent_prompt), \
- static_cast(parameters.past_present_share_buffer));
+#define ORT_GQA_TRACE(func_name) \
+ DUMP_PRINTF("[GQA %s] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, past_present_share_buffer: %d", \
+ func_name, \
+ static_cast(parameters.is_packed_qkv), \
+ static_cast(parameters.is_first_prompt), \
+ static_cast(parameters.is_subsequent_prompt), \
+ static_cast(parameters.past_present_share_buffer));
+
+////////// Kernels (supports right padding but not left padding)
+// ExtremeDecoding: a fused prep kernel handles unpack, RoPE and (optionally quantized) KV append, then the XQA kernel computes attention.
+// Only the decoding phase (sequence_length == 1) with a shared past/present KV buffer uses this path; prompt processing does not.
+template <typename T, typename U>
+Status ExtremeDecoding(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ GroupQueryAttentionParameters& parameters,
+    GroupQueryAttentionData<T, U>& data,
+ float scale) {
+ ORT_GQA_TRACE("ExtremeDecoding");
+
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ // const int kv_sequence_length = parameters.sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int kv_num_heads = parameters.kv_num_heads;
+ const int head_size = parameters.head_size;
+ AttentionQkvFormat past_kv_format = parameters.past_kv_format;
+ // bool is_causal = parameters.is_unidirectional;
+ // bool is_bf16 = std::is_same::value;
+
+  typedef typename onnxruntime::cuda::OrtToCudaType<T>::type CudaT;
+  bool past_bsnh = (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+
+  // Fused preprocessing: unpack, apply RoPE to Q and K, quantize K/V, and append K/V to the cache.
+  // This replaces the individual manual steps (rotate Q, rotate K, quantize, strided copy).
+  CudaT* q_rot_ptr = reinterpret_cast<CudaT*>(data.qkv_buffer);
+  const CudaT* q_input_for_xqa = q_rot_ptr;
+  if (q_rot_ptr == nullptr) {
+    q_input_for_xqa = reinterpret_cast<const CudaT*>(data.query);
+ }
+
+ ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend(
+      parameters.is_packed_qkv ? reinterpret_cast<const CudaT*>(data.query) : nullptr,
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.query),
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.key),
+      parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.value),
+ q_rot_ptr, // unpacked_q (can be null if !do_rotary)
+ data.present_key,
+ data.present_value,
+ data.k_scale,
+ data.v_scale,
+ num_heads,
+ kv_num_heads,
+ head_size,
+ sequence_length,
+ batch_size,
+ parameters.seqlen_present_kv_cache, // max_seqlen (capacity)
+ data.past_seq_lens,
+ data.cos_cache,
+ data.sin_cache,
+ parameters.do_rotary ? parameters.rotary_dim : 0,
+ data.position_ids,
+ parameters.rotary_interleaved,
+ !past_bsnh, // is_cache_bnsh
+ parameters.k_quant_type,
+ parameters.kv_cache_bit_width,
+ stream,
+ device_prop.maxThreadsPerBlock));
+
+ // Determine workspace size for XQA
+ void* xqa_workspace = data.xqa_buffer;
+ size_t xqa_workspace_size = data.xqa_buffer_bytes;
+
+  // Launch the XQA kernel.
+ Status status = onnxruntime::contrib::cuda::LaunchXQAKernel(
+ device_prop,
+ stream,
+ q_input_for_xqa,
+ data.present_key,
+ data.present_value,
+ data.output,
+ batch_size,
+ num_heads,
+ kv_num_heads,
+ parameters.head_size,
+ parameters.seqlen_present_kv_cache, // max_seq_len (Capacity)
+ scale,
+ past_bsnh,
+ data.past_seq_lens,
+ data.k_scale, // kv_cache_scale
+ // Map KVQuantizationType (0=NONE, 1=TENSOR, 2=CHANNEL) to XqaQuantType (0=FP16/BF16, 1=INT8, 2=FP8)
+ (parameters.k_quant_type == KVQuantizationType::NONE) ? onnxruntime::contrib::cuda::XqaQuantType::kNone : onnxruntime::contrib::cuda::XqaQuantType::kInt8,
+ xqa_workspace,
+ xqa_workspace_size);
+
+  // Any failure status from the XQA launch is returned to the caller.
+
+ return status;
+}
////////// Kernels (supports right padding but not left padding)
// Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path.
@@ -568,12 +661,12 @@ Status LaunchGetSequenceLengths(
// Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path.
// Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path.
-template <typename T>
+template <typename T, typename U>
Status FlashDecoding(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
GroupQueryAttentionParameters& parameters,
-    GroupQueryAttentionData<T>& data,
+    GroupQueryAttentionData<T, U>& data,
float scale) {
assert(!parameters.is_first_prompt && parameters.past_present_share_buffer);
@@ -587,7 +680,7 @@ Status FlashDecoding(
const int head_size = parameters.head_size;
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
bool is_causal = parameters.is_unidirectional;
- bool is_bf16 = std::is_same::value;
+ bool is_bf16 = std::is_same::value || std::is_same::value;
void* query = reinterpret_cast(const_cast(data.query));
void* key;
@@ -613,6 +706,9 @@ Status FlashDecoding(
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+ DUMP_PRINTF("[FlashDecoding] key=%p, value=%p, present_key=%p, present_value=%p, seqlens_k=%p, is_packed_qkv=%d",
+ key, value, present_key, present_value, seqlens_k, static_cast(parameters.is_packed_qkv));
+
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
device_prop, stream, query, present_key, present_value, key, value, data.output,
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache,
@@ -629,12 +725,12 @@ Status FlashDecoding(
// Use extra kernel(s) for unpacking, rotary and kv append.
// Flash attention is used for attention only.
-template
+template
Status FlashAttention(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
GroupQueryAttentionParameters& parameters,
- GroupQueryAttentionData& data,
+ GroupQueryAttentionData& data,
float scale) {
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
@@ -646,21 +742,14 @@ Status FlashAttention(
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
bool is_causal = parameters.is_unidirectional;
- bool is_bf16 = std::is_same::value;
+ bool is_bf16 = std::is_same::value || std::is_same::value;
DUMP_TENSOR_INIT();
const T* q_prep = nullptr;
- const T* k_prep = nullptr;
- const T* v_prep = nullptr;
- ORT_RETURN_IF_ERROR(PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep, k_prep, v_prep));
+ ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep)));
void* query = const_cast<T*>(q_prep);
- (void)k_prep; // Key/value are now processed by PrepareQKV
- (void)v_prep;
-
- bool use_packed_for_fa = false;
-
void* present_key = data.present_key;
void* present_value = data.present_value;
@@ -677,6 +766,9 @@ Status FlashAttention(
// Use padded seq lens for first prompt since mha_fwd_kvcache assumes uniform seqlen_q.
int* seq_lens = parameters.is_first_prompt ? data.padded_seq_lens : data.total_seq_lens;
+ // After PrepareQKV, the input for flash attention is no longer packed.
+ constexpr bool is_packed_qkv_for_flash = false;
+
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
device_prop, stream, query, present_key, present_value,
kernel_new_k, kernel_new_v,
@@ -689,22 +781,214 @@ Status FlashAttention(
parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum), parameters.local_window_size - 1,
- parameters.rotary_interleaved, use_packed_for_fa, 0, 1));
+ parameters.rotary_interleaved, is_packed_qkv_for_flash, 0, 1));
DUMP_TENSOR("Total Seq Lens", data.total_seq_lens, batch_size, 1);
DUMP_TENSOR("Past Seq Lens", data.past_seq_lens, batch_size, 1);
return Status::OK();
}
+
+// Fallback path for decoding with a quantized KV cache when XQA is not usable (e.g., due to softcap or a local window).
+// We dequantize the cache and run standard Flash Attention.
+template
+Status DequantizeFlashAttentionFallback(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ GroupQueryAttentionParameters& parameters,
+ GroupQueryAttentionData& data,
+ float scale) {
+ assert(!parameters.is_first_prompt);  // This fallback handles decoding only; the first prompt is not supported here.
+ assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE);
+
+ ORT_GQA_TRACE("DequantizeFlashAttentionFallback");
+
+ // We need to dequantize the entire KV cache (present_key/value) into a float/half buffer (data.qkv_buffer).
+ // Layout in qkv_buffer: [Q (rotated)] [K_dequantized] [V_dequantized]
+ typedef typename onnxruntime::cuda::OrtToCudaType<T>::type CudaT;
+ CudaT* q_rot = reinterpret_cast<CudaT*>(data.qkv_buffer);
+ size_t q_elements = static_cast(parameters.batch_size) * parameters.sequence_length * parameters.num_heads * parameters.head_size;
+ size_t k_elements = static_cast(parameters.batch_size) * parameters.seqlen_present_kv_cache * parameters.kv_num_heads * parameters.head_size;
+ CudaT* k_dequant = q_rot + q_elements;
+ CudaT* v_dequant = k_dequant + k_elements;
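+
+ // Illustrative sizing with hypothetical shapes (not taken from the model): batch_size=2, num_heads=32,
+ // kv_num_heads=8, head_size=128, sequence_length=1, seqlen_present_kv_cache=4096 gives
+ // q_elements = 2*1*32*128 = 8192 and k_elements = 2*4096*8*128 = 8388608, so qkv_buffer must provide
+ // at least q_elements + 2*k_elements CudaT elements for this fallback.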
+
+ // Step 1: Update Quantized Cache
+ // We use LaunchUnpackRoPEAppend to unpack the new QKV, apply RoPE, and append to the quantized cache.
+ // This will also put rotated Q into q_rot.
+ ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend(
+ parameters.is_packed_qkv ? reinterpret_cast<const CudaT*>(data.query) : nullptr,
+ parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.query),
+ parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.key),
+ parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.value),
+ q_rot, data.present_key, data.present_value, data.k_scale, data.v_scale,
+ parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.sequence_length, parameters.batch_size,
+ parameters.seqlen_present_kv_cache, data.past_seq_lens,
+ reinterpret_cast<const CudaT*>(data.cos_cache), reinterpret_cast<const CudaT*>(data.sin_cache),
+ parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
+ (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH),
+ parameters.k_quant_type, parameters.kv_cache_bit_width,
+ stream, device_prop.maxThreadsPerBlock));
+
+ // Step 2: Dequantize Entire Cache
+ // We now have the updated quantized cache in data.present_key/value. We need to dequantize it to k_dequant/v_dequant.
+ bool is_bsnh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+
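+ // Minimal sketch of the expected per-element math (an assumption based on the single k_scale/v_scale
+ // tensors, i.e. symmetric, scale-only quantization): dequantized = static_cast<CudaT>(quantized) * scale,
+ // with one scale in TENSOR mode and one scale per channel in CHANNEL mode; the exact scheme is
+ // defined by LaunchDequantizeKV.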
+ if (parameters.kv_cache_bit_width == 8) {
+ ORT_RETURN_IF_ERROR((LaunchDequantizeKV(
+ stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale,
+ nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache,
+ parameters.head_size, 8, parameters.k_quant_type, is_bsnh)));
+
+ ORT_RETURN_IF_ERROR((LaunchDequantizeKV(
+ stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale,
+ nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache,
+ parameters.head_size, 8, parameters.v_quant_type, is_bsnh)));
+#ifdef USE_INT4_KV_CACHE
+ } else if (parameters.kv_cache_bit_width == 4) {
+ // 4-bit cache path, compiled only when USE_INT4_KV_CACHE is defined.
+ ORT_RETURN_IF_ERROR((LaunchDequantizeKV(
+ stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale,
+ nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache,
+ parameters.head_size, 4, parameters.k_quant_type, is_bsnh)));
+
+ ORT_RETURN_IF_ERROR((LaunchDequantizeKV(
+ stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale,
+ nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache,
+ parameters.head_size, 4, parameters.v_quant_type, is_bsnh)));
+#endif
+ }
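+
+ // Note: if kv_cache_bit_width is 4 but the build does not define USE_INT4_KV_CACHE, neither branch above
+ // runs and k_dequant/v_dequant stay uninitialized; presumably the operator rejects 4-bit caches in such
+ // builds before reaching this point.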
+
+ // Step 3: Run Flash Attention on dequantized k/v
+ bool is_causal = parameters.is_unidirectional;
+ bool is_bf16 = std::is_same::value || std::is_same::value;
+
+ // Use total_seq_lens here since k_dequant/v_dequant hold both the past and the new tokens.
+ void* seqlens_k_ptr = const_cast(reinterpret_cast(data.total_seq_lens));
+ int local_window_size = parameters.local_window_size > 0 ? parameters.local_window_size - 1 : -1;
+
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
+ device_prop, stream, q_rot, k_dequant, v_dequant, nullptr /*new K*/, nullptr /*new V*/, data.output,
+ reinterpret_cast<void*>(data.softmax_lse), seqlens_k_ptr, nullptr /*cos_cache*/, nullptr /*sin_cache*/,
+ /*cache_batch_idx*/ nullptr, /*leftpad_k*/ nullptr, reinterpret_cast(const_cast(data.head_sink)), /*block_table*/ nullptr,
+ parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.sequence_length,
+ parameters.seqlen_present_kv_cache, parameters.sequence_length, 0 /*rotary_dim = 0 as it is already rotated*/,
+ scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, is_bsnh, parameters.num_splits,
+ reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
+ local_window_size, parameters.rotary_interleaved, false,
+ 0, 1));
+
+ return Status::OK();
+}
+
+// Use Flash Attention with unquantized (fp16/bf16) key and value for the prompt, then quantize key/value into the KV cache.
+template
+Status FlashAttentionAndQuantizeKV(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ GroupQueryAttentionParameters& parameters,
+ GroupQueryAttentionData& data,
+ float scale) {
+ assert(parameters.is_first_prompt); // Only support first prompt for this function.
+ assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE);
+
+ const int max_threads_per_block = device_prop.maxThreadsPerBlock;
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_num_heads = parameters.kv_num_heads;
+ const int num_heads = parameters.num_heads;
+ const int head_size = parameters.head_size;
+
+ ORT_GQA_TRACE("FlashAttentionAndQuantizeKV");
+
+ bool past_bsnh = parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+
+ size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size;
+ size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size;
+
+ using CudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
+ CudaT* q_final = reinterpret_cast<CudaT*>(data.qkv_buffer);
+
+ // For FlashAttentionAndQuantizeKV, we need float K and V for attention.
+ // We'll write them to qkv_buffer.
+ CudaT* k_final = q_final + q_elements;
+ CudaT* v_final = k_final + k_elements;
+
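+ // 1. Unpack QKV and apply RoPE, writing unquantized Q/K/V into the BSNH scratch buffer;
+ //    quantization into the present cache is deferred until after attention (step 3).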
+ ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend(
+ parameters.is_packed_qkv ? reinterpret_cast<const CudaT*>(data.query) : nullptr,
+ parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.query),
+ parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.key),
+ parameters.is_packed_qkv ? nullptr : reinterpret_cast<const CudaT*>(data.value),
+ q_final, k_final, v_final, nullptr, nullptr,
+ num_heads, kv_num_heads, head_size, sequence_length, batch_size,
+ sequence_length, data.past_seq_lens,
+ reinterpret_cast<const CudaT*>(data.cos_cache), reinterpret_cast<const CudaT*>(data.sin_cache),
+ parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
+ false, // BSNH for scratch
+ KVQuantizationType::NONE,
+ 0, // bit_width is 0 since we are not quantizing here.
+ stream, max_threads_per_block));
+
+ // 2. Run Float Flash Attention
+ bool is_causal = parameters.is_unidirectional;
+ bool is_bf16 = std::is_same::value || std::is_same::value;
+
+ int local_window_size = parameters.local_window_size > 0 ? parameters.local_window_size - 1 : -1;
+
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
+ device_prop, stream, q_final, k_final, v_final, data.output,
+ reinterpret_cast<void*>(data.softmax_lse),
+ batch_size, num_heads, kv_num_heads, head_size, sequence_length, sequence_length,
+ scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax,
+ parameters.num_splits,
+ reinterpret_cast<void*>(data.softmax_lse_accum),
+ reinterpret_cast<void*>(data.out_accum),
+ true, // kv_bsnh = true (BSNH)
+ local_window_size));
+
+ // 3. Quantize K and V to present cache
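+ // (Sketch of the presumed per-element quantization, assuming the same symmetric scale-only scheme as the
+ // dequantization fallback above: q = clamp(round(x / scale), min, max), e.g. +/-127 for int8; the exact
+ // rounding and clamping are defined by LaunchQuantizeKV.)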
+ if (parameters.k_quant_type != KVQuantizationType::NONE) {
+ if (parameters.kv_cache_bit_width == 8) {
+ ORT_RETURN_IF_ERROR((LaunchQuantizeKV(
+ stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale,
+ nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache,
+ head_size, 8, parameters.k_quant_type, true, past_bsnh)));
+#ifdef USE_INT4_KV_CACHE
+ } else if (parameters.kv_cache_bit_width == 4) {
+ ORT_RETURN_IF_ERROR((LaunchQuantizeKV(
+ stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale,
+ nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache,
+ head_size, 4, parameters.k_quant_type, true, past_bsnh)));
+#endif
+ }
+ }
+
+ if (parameters.v_quant_type != KVQuantizationType::NONE) {
+ if (parameters.kv_cache_bit_width == 8) {
+ ORT_RETURN_IF_ERROR((LaunchQuantizeKV(
+ stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale,
+ nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache,
+ head_size, 8, parameters.v_quant_type, true, past_bsnh)));
+#ifdef USE_INT4_KV_CACHE
+ } else if (parameters.kv_cache_bit_width == 4) {
+ ORT_RETURN_IF_ERROR((LaunchQuantizeKV(
+ stream, reinterpret_cast(data.present_value), reinterpret_cast