
Commit 019a2b1

Support fp8 kv cache

1 parent 9aa8deb

22 files changed: +949 -150 lines

cmake/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
@@ -104,6 +104,7 @@ option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled do
 cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
 option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF)
 option(onnxruntime_USE_INT4_KV_CACHE "Build cuda kernels for int4 kv cache" OFF)
+option(onnxruntime_USE_FP8_KV_CACHE "Build cuda kernels for fp8 kv cache" ON)
 option(onnxruntime_QUICK_BUILD "Speed up build by skipping some kernels for faster development" OFF)

 option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
@@ -783,6 +784,11 @@ if (onnxruntime_USE_CUDA)
     message( STATUS "Enable int4 kv cache for CUDA EP")
     list(APPEND ORT_PROVIDER_FLAGS -DUSE_INT4_KV_CACHE=1)
   endif()
+
+  if (onnxruntime_USE_FP8_KV_CACHE)
+    message( STATUS "Enable fp8 kv cache for CUDA EP")
+    list(APPEND ORT_PROVIDER_FLAGS -DUSE_FP8_KV_CACHE=1)
+  endif()
 endif()

 if (onnxruntime_USE_CUDA_INTERFACE AND (NOT onnxruntime_USE_CUDA))
@@ -1442,6 +1448,15 @@ if (Git_FOUND)
   if (onnxruntime_USE_INT4_KV_CACHE)
     string(APPEND ORT_BUILD_INFO "int4-kv-cache=1, ")
   endif()
+  if (onnxruntime_USE_FP8_KV_CACHE)
+    string(APPEND ORT_BUILD_INFO "fp8-kv-cache=1, ")
+  endif()
+  if (onnxruntime_DUMP_TENSOR)
+    string(APPEND ORT_BUILD_INFO "dump-tensor=1, ")
+  endif()
+  if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
+    string(APPEND ORT_BUILD_INFO "dump-node=1, ")
+  endif()
 endif()
 string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}")
 configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h)
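
For orientation: the -DUSE_FP8_KV_CACHE=1 provider flag appended above reaches the CUDA sources as an ordinary preprocessor definition, so when the option is OFF the fp8 path is compiled out entirely. A minimal sketch of the gating pattern it enables, mirroring the registration hunk in group_query_attention.cc below:

    // Compiled only when CMake appends -DUSE_FP8_KV_CACHE=1; otherwise the
    // fp8 kernel registrations do not exist in the binary at all.
    #ifdef USE_FP8_KV_CACHE
    REGISTER_KERNEL_TYPED(MLFloat16, Float8E4M3FN)
    REGISTER_KERNEL_TYPED(BFloat16, Float8E4M3FN)
    #endif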

onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

Lines changed: 22 additions & 4 deletions
@@ -63,6 +63,10 @@ REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16)
 REGISTER_KERNEL_TYPED(BFloat16, BFloat16)
 REGISTER_KERNEL_TYPED(MLFloat16, int8_t)
 REGISTER_KERNEL_TYPED(BFloat16, int8_t)
+#ifdef USE_FP8_KV_CACHE
+REGISTER_KERNEL_TYPED(MLFloat16, Float8E4M3FN)
+REGISTER_KERNEL_TYPED(BFloat16, Float8E4M3FN)
+#endif
 #ifdef USE_INT4_KV_CACHE
 REGISTER_KERNEL_TYPED(MLFloat16, uint8_t)
 REGISTER_KERNEL_TYPED(BFloat16, uint8_t)
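
The new registrations make Float8E4M3FN a valid kv-cache element type U. As a rough illustration of what a PER_TENSOR fp8 cache stores, here is a hedged CUDA sketch of E4M3 quantize/dequantize; the kernel and helper names are illustrative rather than the ones used by the actual XQA kernels, it assumes CUDA 11.8+ for cuda_fp8.h, and it assumes the single per-tensor scale has already been read out as a float:

    #include <cuda_fp16.h>
    #include <cuda_fp8.h>

    // Per-tensor quantization: q = saturate_to_e4m3(x / scale), with
    // inv_scale = 1.0f / scale. E4M3FN has no infinities; the conversion
    // saturates to +/-448.
    __global__ void QuantizeKvToFp8(const half* kv, __nv_fp8_e4m3* kv_q,
                                    float inv_scale, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        kv_q[i] = __nv_fp8_e4m3(__half2float(kv[i]) * inv_scale);
      }
    }

    // Dequantization on load: x ~= scale * q.
    __device__ inline float DequantFp8(__nv_fp8_e4m3 q, float scale) {
      return static_cast<float>(q) * scale;
    }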
@@ -292,6 +296,8 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
   parameters.past_present_share_buffer = (data.past_key == data.present_key);

   bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE);
+  constexpr bool is_int8 = std::is_same<U, int8_t>::value;
+  constexpr bool is_fp8 = std::is_same<U, Float8E4M3FN>::value;

   // Allocate XQA scratch if needed (only for Flash Decoding path)
   IAllocatorUniquePtr<void> xqa_scratch_buffer;
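
These two flags resolve at compile time per template instantiation of GroupQueryAttention<T, U>, so for the (MLFloat16, Float8E4M3FN) kernel registered above, is_fp8 is a compile-time true and int8-only checks are trivially false. A self-contained illustration of the pattern; Fp8E4M3 is a stand-in for ORT's Float8E4M3FN, and the actual code uses plain bools rather than if constexpr:

    #include <cstdint>
    #include <type_traits>

    struct Fp8E4M3 { uint8_t bits; };  // stand-in for ORT's Float8E4M3FN

    template <typename U>
    void DispatchByCacheType() {
      constexpr bool is_int8 = std::is_same<U, int8_t>::value;
      constexpr bool is_fp8 = std::is_same<U, Fp8E4M3>::value;
      if constexpr (is_fp8) {
        // fp8 cache path
      } else if constexpr (is_int8) {
        // int8 cache path
      } else {
        // unquantized path
      }
    }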
@@ -315,18 +321,30 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
       parameters.local_window_size == -1) {
     int group_size = parameters.num_heads / parameters.kv_num_heads;

-    bool is_int8_quantized_supported = (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
+    bool is_int8_quantized_supported = is_int8 &&
+                                       (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
                                         v_quant_type_ == KVQuantizationType::PER_TENSOR &&
                                         data.k_scale == data.v_scale &&  // XQA requires k_scale and v_scale to be the same. Here requires k_scale and v_scale are same tensor.
-                                        parameters.kv_cache_bit_width == 8 &&
                                         (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
                                         (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32));

+#ifdef USE_FP8_KV_CACHE
+    bool is_fp8_quantized_supported = is_fp8 &&
+                                      (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
+                                       v_quant_type_ == KVQuantizationType::PER_TENSOR &&
+                                       data.k_scale == data.v_scale &&
+                                       (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
+                                       (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32) &&
+                                       (device_prop.major >= 9 || (device_prop.major == 8 && device_prop.minor == 9)));  // FP8 requires SM89+ (Ada Lovelace)
+#else
+    constexpr bool is_fp8_quantized_supported = false;
+#endif
+
     bool is_non_quantized_supported = !is_inputs_quantized &&
                                       (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
                                       (64 % group_size == 0);

-    data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported);
+    data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported || is_fp8_quantized_supported);

     if (data.use_xqa) {
       size_t xqa_internal_bytes = onnxruntime::contrib::cuda::GetXQAScratchSize(
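
The device_prop test on the last line of the fp8 condition is the key difference from the int8 path: fp8 needs Ada (SM 8.9) or Hopper and newer (SM 9.0+) hardware. Pulled out as a standalone helper for clarity; the function name is illustrative:

    #include <cuda_runtime.h>

    // FP8 kv cache requires SM 8.9 (Ada Lovelace) or SM 9.0+; this is
    // exactly the device_prop test in the condition above.
    inline bool IsFp8ArchSupported(const cudaDeviceProp& prop) {
      return prop.major >= 9 || (prop.major == 8 && prop.minor == 9);
    }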
@@ -336,7 +354,7 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
           parameters.kv_num_heads,
           parameters.head_size,
           parameters.seqlen_present_kv_cache,
-          parameters.k_quant_type != KVQuantizationType::NONE ? XqaQuantType::kInt8 : XqaQuantType::kNone,
+          parameters.k_quant_type != KVQuantizationType::NONE ? (is_fp8 ? XqaQuantType::kFp8 : XqaQuantType::kInt8) : XqaQuantType::kNone,
           std::is_same<T, BFloat16>::value);
       assert(xqa_internal_bytes > 0);
       // Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding
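
With fp8 in the mix, the scratch-size query distinguishes three quantization modes instead of two. A sketch of the selection logic embedded in the ternary above; the helper name is illustrative, while the enum values are the ones appearing in the diff:

    // Stand-in enums mirroring the types that appear in the diff.
    enum class KVQuantizationType { NONE, PER_TENSOR };
    enum class XqaQuantType { kNone, kInt8, kFp8 };

    // Illustrative helper equivalent to the ternary above.
    XqaQuantType SelectXqaQuantType(KVQuantizationType k_quant_type, bool is_fp8) {
      if (k_quant_type == KVQuantizationType::NONE) return XqaQuantType::kNone;
      return is_fp8 ? XqaQuantType::kFp8 : XqaQuantType::kInt8;
    }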
