Commit abcbdad

update comments
1 parent 019a2b1 commit abcbdad

2 files changed: +5 -5 lines changed

onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu

Lines changed: 1 addition & 1 deletion
@@ -858,7 +858,7 @@ Status DequantizeFlashAttentionFallback(
   return Status::OK();
 }
 
-// Use Flash Attention for float key and value, then quantize key/value to int8 to save to k/v cache.
+// Use Flash Attention for float key and value, then quantize key/value (int8/fp8/int4) to save to k/v cache.
 template <typename T, typename U>
 Status FlashAttentionAndQuantizeKV(
     const cudaDeviceProp& device_prop,

onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh

Lines changed: 4 additions & 4 deletions
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 #pragma once
 
-// Enable quantized KV cache support for INT8/INT4
+// Enable quantized KV cache support for INT8/INT4/FP8
 #define KV_QUANT_SUPPORTED 1
 
 #include <cuda_fp16.h>
@@ -49,7 +49,7 @@ struct TypeConverter<__nv_bfloat16> {
 // ============================================================================
 //
 // This file implements symmetric quantization for KV cache in GroupQueryAttention.
-// Supports INT4 and INT8 with PER_TENSOR and PER_CHANNEL quantization modes.
+// Supports INT4, INT8, and FP8 (E4M3) with PER_TENSOR and PER_CHANNEL quantization modes.
 //
 // QUANTIZATION SCHEME:
 // -------------------
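
For orientation, a minimal sketch of the symmetric per-tensor scheme described in the comment above; the helper names and the representable maxima used here (127 for INT8, 7 for INT4, 448 for FP8 E4M3) are illustrative assumptions rather than code from this file. A PER_CHANNEL mode would compute one such scale per channel or head instead of one per tensor.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative only: symmetric quantization uses a single scale and no zero point.
// qmax is the largest representable magnitude of the target type (assumed:
// 127 for INT8, 7 for INT4, 448 for FP8 E4M3).
inline float ComputeSymmetricScale(const std::vector<float>& x, float qmax) {
  float amax = 0.0f;
  for (float v : x) amax = std::max(amax, std::fabs(v));
  return amax > 0.0f ? amax / qmax : 1.0f;
}

inline int8_t QuantizeInt8(float v, float scale) {
  float q = std::nearbyint(v / scale);
  return static_cast<int8_t>(std::min(127.0f, std::max(-127.0f, q)));
}

inline float DequantizeInt8(int8_t q, float scale) {
  return static_cast<float>(q) * scale;  // symmetric: no zero-point offset
}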
@@ -96,7 +96,7 @@ struct TypeConverter<__nv_bfloat16> {
 // - Conversion: Native CUDA cast via __nv_cvt_float_to_fp8/fp8_to_float
 // ============================================================================
 
-// Dequantization Kernel: Converts Quantized (Int8/Int4) KV cache back to Floating Point (T).
+// Dequantization Kernel: Converts Quantized (Int8/Int4/FP8) KV cache back to Floating Point (T).
 // Iterates over every individual element with one thread per element.
 template <typename T, typename T_QUANT, typename T_SCALE>
 __global__ void DequantizeKernel(T* dequantized_data,
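
As a rough, self-contained illustration of the one-thread-per-element dequantization pattern the comment describes, the sketch below fixes T = half, T_QUANT = int8_t, T_SCALE = float with a single per-tensor scale; the kernel and launcher names are hypothetical and are not the DequantizeKernel/LaunchDequantizeKV from this file.

#include <cuda_fp16.h>
#include <cstdint>

// Illustrative only: each thread dequantizes exactly one element.
__global__ void DequantizeInt8ToHalfSketch(half* out, const int8_t* in,
                                           float scale, int64_t num_elements) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < num_elements) {
    out[idx] = __float2half(static_cast<float>(in[idx]) * scale);
  }
}

// Example launch: size the grid so every element gets exactly one thread.
inline void LaunchDequantizeSketch(cudaStream_t stream, half* out, const int8_t* in,
                                   float scale, int64_t num_elements) {
  constexpr int kThreads = 256;
  unsigned int blocks = static_cast<unsigned int>((num_elements + kThreads - 1) / kThreads);
  DequantizeInt8ToHalfSketch<<<blocks, kThreads, 0, stream>>>(out, in, scale, num_elements);
}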
@@ -195,7 +195,7 @@ Status LaunchDequantizeKV(cudaStream_t stream, T* dequantized_data,
   return CUDA_CALL(cudaGetLastError());
 }
 
-// Quantization Kernel: Converts Floating Point (T) cache to Quantized (Int8/Int4) values.
+// Quantization Kernel: Converts Floating Point (T) cache to Quantized (Int8/Int4/FP8) values.
 // Note: This kernel is used to quantize a full input tensor, e.g. during graph initialization
 // or fallback paths. The main prompt path uses the fused UnpackRoPEAppend kernel.
 template <typename T, typename T_QUANT, typename T_SCALE>
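
A matching sketch of the quantization direction, again with hypothetical names and a single per-tensor scale. The FP8 path uses __nv_cvt_float_to_fp8 from cuda_fp8.h, which the file's comments cite as the native conversion; the E4M3 interpretation and saturation to finite values are assumptions of this sketch.

#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cstdint>

// Illustrative only: one thread per element, symmetric INT8 quantization.
__global__ void QuantizeHalfToInt8Sketch(int8_t* out, const half* in,
                                         float inv_scale, int64_t num_elements) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < num_elements) {
    float q = rintf(__half2float(in[idx]) * inv_scale);
    out[idx] = static_cast<int8_t>(fminf(127.0f, fmaxf(-127.0f, q)));
  }
}

// Illustrative only: FP8 E4M3 variant via the native float-to-fp8 conversion.
__global__ void QuantizeHalfToFp8Sketch(__nv_fp8_storage_t* out, const half* in,
                                        float inv_scale, int64_t num_elements) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < num_elements) {
    float scaled = __half2float(in[idx]) * inv_scale;
    out[idx] = __nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3);
  }
}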
