|
// Licensed under the MIT License.
#pragma once

-// Enable quantized KV cache support for INT8/INT4
+// Enable quantized KV cache support for INT8/INT4/FP8
#define KV_QUANT_SUPPORTED 1

#include <cuda_fp16.h>
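
For context, a compile-time flag like this is normally consumed through preprocessor guards; a minimal sketch (not taken from the patch) of how the gate might be used:

#if KV_QUANT_SUPPORTED
// hypothetical: compile the quantized (INT8/INT4/FP8) KV cache kernels
#else
// hypothetical: fall back to the full-precision KV cache path only
#endif
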
@@ -49,7 +49,7 @@ struct TypeConverter<__nv_bfloat16> {
// ============================================================================
//
// This file implements symmetric quantization for KV cache in GroupQueryAttention.
-// Supports INT4 and INT8 with PER_TENSOR and PER_CHANNEL quantization modes.
+// Supports INT4, INT8, and FP8 (E4M3) with PER_TENSOR and PER_CHANNEL quantization modes.
//
// QUANTIZATION SCHEME:
// -------------------
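
The scheme named here is symmetric, i.e. zero-point-free. As a rough illustration, and not code from this file, scale selection could look like the sketch below; the helper name and q_max are hypothetical (q_max would be 127 for INT8, 7 for INT4, and e.g. 448 for FP8 E4M3). PER_TENSOR computes one such scale for the whole tensor, PER_CHANNEL one per head/channel slice.

// Hypothetical illustration of symmetric scale selection (not the file's code).
#include <algorithm>
#include <cmath>
#include <cstddef>

inline float ComputeSymmetricScale(const float* data, size_t count, float q_max) {
  float abs_max = 0.0f;
  for (size_t i = 0; i < count; ++i) {
    abs_max = std::max(abs_max, std::fabs(data[i]));  // running max(|x|)
  }
  return abs_max > 0.0f ? abs_max / q_max : 1.0f;     // guard against a zero scale
}

// quantize:   q     = clamp(round(x / scale), -q_max, q_max)
// dequantize: x_hat = q * scale

With a per-channel mode, the same absmax reduction would simply run once per head or channel, trading a little extra scale storage for tighter value ranges.
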
@@ -96,7 +96,7 @@ struct TypeConverter<__nv_bfloat16> {
// - Conversion: Native CUDA cast via __nv_cvt_float_to_fp8/fp8_to_float
// ============================================================================

-// Dequantization Kernel: Converts Quantized (Int8/Int4) KV cache back to Floating Point (T).
+// Dequantization Kernel: Converts Quantized (Int8/Int4/FP8) KV cache back to Floating Point (T).
// Iterates over every individual element with one thread per element.
template <typename T, typename T_QUANT, typename T_SCALE>
__global__ void DequantizeKernel(T* dequantized_data,
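
The header comment above points at the native FP8 conversion intrinsics from cuda_fp8.h. A minimal sketch of that round trip, assuming CUDA 11.8+ and the E4M3 interpretation used here (the public intrinsics convert FP8 back through __half_raw rather than directly to float):

// Hypothetical helper sketch, not the file's code.
#include <cuda_fp8.h>   // __nv_fp8_storage_t, __nv_cvt_* intrinsics (CUDA 11.8+)
#include <cuda_fp16.h>  // __half, __half_raw, __half2float

// float -> FP8 E4M3, saturating to the finite range.
__device__ inline __nv_fp8_storage_t FloatToFp8E4M3(float x) {
  return __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3);
}

// FP8 E4M3 -> float, going through __half_raw.
__device__ inline float Fp8E4M3ToFloat(__nv_fp8_storage_t q) {
  __half_raw h = __nv_cvt_fp8_to_halfraw(q, __NV_E4M3);
  return __half2float(__half(h));
}

Unlike the INT4/INT8 paths, FP8 stores a genuine floating-point value, so dequantization is presumably just this cast plus the scale multiply when one is applied.
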
@@ -195,7 +195,7 @@ Status LaunchDequantizeKV(cudaStream_t stream, T* dequantized_data,
  return CUDA_CALL(cudaGetLastError());
}

-// Quantization Kernel: Converts Floating Point (T) cache to Quantized (Int8/Int4) values.
+// Quantization Kernel: Converts Floating Point (T) cache to Quantized (Int8/Int4/FP8) values.
// Note: This kernel is used to quantize a full input tensor, e.g. during graph initialization
// or fallback paths. The main prompt path uses the fused UnpackRoPEAppend kernel.
template <typename T, typename T_QUANT, typename T_SCALE>
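
The actual kernel is templated over T/T_QUANT/T_SCALE and supports per-channel scales; as a minimal one-thread-per-element sketch of the same idea (hypothetical name and layout: float in, INT8 out, single per-tensor scale):

#include <cstdint>

// Hypothetical simplification, not the file's kernel.
__global__ void QuantizeKernelSketch(int8_t* quantized, const float* input,
                                     const float* scale,  // single per-tensor scale
                                     int num_elements) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num_elements) return;                 // one thread per element
  float q = rintf(input[idx] / scale[0]);          // symmetric: no zero point
  q = fmaxf(-127.0f, fminf(127.0f, q));            // clamp to the INT8 range
  quantized[idx] = static_cast<int8_t>(q);
}

A launch of roughly num_elements / 256 blocks of 256 threads would mirror the element-per-thread structure described in the kernel comments above.
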
|