@@ -63,6 +63,10 @@ REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16)
 REGISTER_KERNEL_TYPED(BFloat16, BFloat16)
 REGISTER_KERNEL_TYPED(MLFloat16, int8_t)
 REGISTER_KERNEL_TYPED(BFloat16, int8_t)
+#ifdef USE_FP8_KV_CACHE
+REGISTER_KERNEL_TYPED(MLFloat16, Float8E4M3FN)
+REGISTER_KERNEL_TYPED(BFloat16, Float8E4M3FN)
+#endif
 #ifdef USE_INT4_KV_CACHE
 REGISTER_KERNEL_TYPED(MLFloat16, uint8_t)
 REGISTER_KERNEL_TYPED(BFloat16, uint8_t)
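Aside (not part of the patch): a minimal, self-contained sketch of the build-flag gating used in the registration hunk above. RegisterKernel, MyHalf, and MyFp8E4M3 are hypothetical stand-ins for ONNX Runtime's REGISTER_KERNEL_TYPED macro and types; only the #ifdef USE_FP8_KV_CACHE pattern is illustrated.

// Sketch only: hypothetical stand-ins, not the ORT registration macro.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct MyHalf {};     // stand-in for MLFloat16
struct MyFp8E4M3 {};  // stand-in for Float8E4M3FN

static std::vector<std::string> g_kernels;

// Records that a kernel specialized for <T, U> is available.
template <typename T, typename U>
void RegisterKernel(const std::string& name) {
  g_kernels.push_back(name);
}

void RegisterAll() {
  RegisterKernel<MyHalf, int8_t>("GQA<MLFloat16, int8_t>");
#ifdef USE_FP8_KV_CACHE
  // Compiled in only when the FP8 KV-cache build flag is defined.
  RegisterKernel<MyHalf, MyFp8E4M3>("GQA<MLFloat16, Float8E4M3FN>");
#endif
}

int main() {
  RegisterAll();
  for (const auto& k : g_kernels) std::cout << k << "\n";
  return 0;
}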
@@ -292,6 +296,8 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) const
   parameters.past_present_share_buffer = (data.past_key == data.present_key);

   bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE);
+  constexpr bool is_int8 = std::is_same<U, int8_t>::value;
+  constexpr bool is_fp8 = std::is_same<U, Float8E4M3FN>::value;

   // Allocate XQA scratch if needed (only for Flash Decoding path)
   IAllocatorUniquePtr<void> xqa_scratch_buffer;
@@ -315,18 +321,30 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) const
       parameters.local_window_size == -1) {
     int group_size = parameters.num_heads / parameters.kv_num_heads;

-    bool is_int8_quantized_supported = (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
+    bool is_int8_quantized_supported = is_int8 &&
+                                       (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
                                         v_quant_type_ == KVQuantizationType::PER_TENSOR &&
                                         data.k_scale == data.v_scale &&  // XQA requires k_scale and v_scale to be the same. Here requires k_scale and v_scale are same tensor.
-                                        parameters.kv_cache_bit_width == 8 &&
                                         (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
                                         (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32));

+#ifdef USE_FP8_KV_CACHE
+    bool is_fp8_quantized_supported = is_fp8 &&
+                                      (k_quant_type_ == KVQuantizationType::PER_TENSOR &&
+                                       v_quant_type_ == KVQuantizationType::PER_TENSOR &&
+                                       data.k_scale == data.v_scale &&
+                                       (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
+                                       (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32) &&
+                                       (device_prop.major >= 9 || (device_prop.major == 8 && device_prop.minor == 9)));  // FP8 requires SM89+ (Ada Lovelace)
+#else
+    constexpr bool is_fp8_quantized_supported = false;
+#endif
+
     bool is_non_quantized_supported = !is_inputs_quantized &&
                                       (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) &&
                                       (64 % group_size == 0);

-    data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported);
+    data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported || is_fp8_quantized_supported);

     if (data.use_xqa) {
       size_t xqa_internal_bytes = onnxruntime::contrib::cuda::GetXQAScratchSize(
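Aside (not part of the patch): the SM gate in the FP8 branch above can be checked against any device with the standard CUDA runtime query, as in this self-contained sketch. SupportsFp8KvCache is a hypothetical helper; the threshold simply mirrors the device_prop.major/minor test in the diff (SM 8.9 Ada Lovelace, or SM 9.0+ Hopper).

// Minimal sketch: query compute capability and apply the same SM 8.9 / 9.0+ gate
// used for the FP8 KV-cache path. SupportsFp8KvCache is a hypothetical helper.
#include <cstdio>
#include <cuda_runtime.h>

bool SupportsFp8KvCache(const cudaDeviceProp& device_prop) {
  // FP8 requires SM89+ (Ada Lovelace) or SM90+ (Hopper).
  return device_prop.major >= 9 || (device_prop.major == 8 && device_prop.minor == 9);
}

int main() {
  int device = 0;
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    std::printf("no CUDA device found\n");
    return 1;
  }
  std::printf("SM %d.%d, FP8 KV cache eligible: %s\n",
              prop.major, prop.minor, SupportsFp8KvCache(prop) ? "yes" : "no");
  return 0;
}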
@@ -336,7 +354,7 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) const
           parameters.kv_num_heads,
           parameters.head_size,
           parameters.seqlen_present_kv_cache,
-          parameters.k_quant_type != KVQuantizationType::NONE ? XqaQuantType::kInt8 : XqaQuantType::kNone,
+          parameters.k_quant_type != KVQuantizationType::NONE ? (is_fp8 ? XqaQuantType::kFp8 : XqaQuantType::kInt8) : XqaQuantType::kNone,
           std::is_same<T, BFloat16>::value);
       assert(xqa_internal_bytes > 0);
       // Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding
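Aside (not part of the patch): a compilable sketch of the compile-time dispatch used for the GetXQAScratchSize argument above. Float8E4M3FN here is a local stand-in for the ORT type, and SelectXqaQuantType is a hypothetical helper whose enum values mirror the ones in the diff.

// Sketch of the compile-time dispatch: the KV-cache element type U selects
// the XQA quantization mode; unquantized caches fall back to kNone.
#include <cstdint>
#include <iostream>
#include <type_traits>

struct Float8E4M3FN {};  // stand-in for onnxruntime::Float8E4M3FN

enum class XqaQuantType { kNone, kInt8, kFp8 };

template <typename U>
XqaQuantType SelectXqaQuantType(bool kv_cache_quantized) {
  if (!kv_cache_quantized) return XqaQuantType::kNone;
  constexpr bool is_fp8 = std::is_same<U, Float8E4M3FN>::value;
  return is_fp8 ? XqaQuantType::kFp8 : XqaQuantType::kInt8;
}

int main() {
  std::cout << static_cast<int>(SelectXqaQuantType<int8_t>(true)) << "\n";        // 1 -> kInt8
  std::cout << static_cast<int>(SelectXqaQuantType<Float8E4M3FN>(true)) << "\n";  // 2 -> kFp8
  std::cout << static_cast<int>(SelectXqaQuantType<int8_t>(false)) << "\n";       // 0 -> kNone
  return 0;
}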