Commit 5b1d53b

Fix scaling factor usage in FP8 quant.
1 parent 809b222 commit 5b1d53b

1 file changed: 6 additions & 4 deletions

include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
@@ -1175,7 +1175,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleV
 // ============================== Quant Device Function ==============================
 template <typename T, typename PackedType, int ELTS_PER_THREAD>
 inline __device__ void quant_fp8(PackedVec<PackedType, T> packedAccum, void* quantOutPtr,
-                                 float* outputScale, uint32_t threadOffset) {
+                                 float invOutputScale, uint32_t threadOffset) {
   static_assert(ELTS_PER_THREAD == 8 || ELTS_PER_THREAD == 4, "ELTS_PER_THREAD must be 8 or 4");
   using QuantizedPackedType = std::conditional_t<ELTS_PER_THREAD == 8, float2, float>;

@@ -1184,7 +1184,7 @@ inline __device__ void quant_fp8(PackedVec<PackedType, T> packedAccum, void* qua
 #pragma unroll
   for (int i = 0; i < ELTS_PER_THREAD; i++) {
     quantizedAccum.elements[i] =
-        __nv_fp8_e4m3(toFloat<T>(packedAccum.elements[i]) * (*outputScale));
+        __nv_fp8_e4m3(toFloat<T>(packedAccum.elements[i]) * invOutputScale);
   }
   reinterpret_cast<QuantizedPackedType*>(&quantOut[threadOffset])[0] = quantizedAccum.packed;
 }
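For context, here is a minimal sketch of the quant/dequant convention the fix restores, assuming outputScale is the dequantization scale consumed downstream. This is illustrative code, not part of the diff; the function names are hypothetical.

#include <cuda_fp8.h>

// Quantize: q = x / scale, written as a multiply by the precomputed reciprocal.
__device__ __nv_fp8_e4m3 quantize_fp8_sketch(float x, float invScale) {
  return __nv_fp8_e4m3(x * invScale);
}

// Dequantize: x ~= q * scale recovers an approximation of the original value.
__device__ float dequantize_fp8_sketch(__nv_fp8_e4m3 q, float scale) {
  return static_cast<float>(q) * scale;
}

Under this convention, the old code's multiply by *outputScale would leave values inflated by scale squared after dequantization; multiplying by the reciprocal is what makes the round trip consistent.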
@@ -1373,7 +1373,8 @@ __global__ void __launch_bounds__(config::kMaxBlockSize) oneshotAllreduceFusionK
     }

     if constexpr (QType == QuantType::kFP8) {
-      quant::quant_fp8<T, PackedType, kELTS_PER_THREAD>(packedAccum, quantOutPtr, outputScale,
+      float invOutputScale = 1.0f / (*outputScale);  // We need to apply inv_scale to the output
+      quant::quant_fp8<T, PackedType, kELTS_PER_THREAD>(packedAccum, quantOutPtr, invOutputScale,
                                                          threadOffset);
     }
 #if CUDA_VERSION >= 12080
@@ -1805,7 +1806,8 @@ __global__ __launch_bounds__(config::kMaxBlockSize) void rmsNormLamport_fusion(
       *reinterpret_cast<float4*>(&outputNorm[blockLoadOffset + threadLoadOffset]) = rOut.packed;
     }
     if constexpr (QType == QuantType::kFP8) {
-      quant::quant_fp8<T, float4, kELTS_PER_LOAD>(rOut, quantOut, outputScale,
+      float invOutputScale = 1.0f / (*outputScale);
+      quant::quant_fp8<T, float4, kELTS_PER_LOAD>(rOut, quantOut, invOutputScale,
                                                    blockLoadOffset + threadLoadOffset);
     }
 #if CUDA_VERSION >= 12080
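Both call sites above follow the same pattern: dereference the device-side scale once, form its reciprocal, and pass it by value so the unrolled per-element loop reduces to a single multiply with no per-element pointer loads. A hypothetical standalone kernel (not from the diff) showing that pattern:

#include <cuda_fp8.h>

// Read the dequant scale once per thread, then quantize with its reciprocal.
__global__ void quant_fp8_demo(const float* __restrict__ in, __nv_fp8_e4m3* __restrict__ out,
                               const float* outputScale, int n) {
  float invOutputScale = 1.0f / (*outputScale);  // one divide per thread
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = __nv_fp8_e4m3(in[i] * invOutputScale);
  }
}

Passing the reciprocal by value also lets the compiler keep it in a register across the unrolled loop, rather than re-reading through the pointer each iteration.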
