@@ -1175,7 +1175,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleV
11751175// ============================== Quant Device Function ==============================
11761176template <typename T, typename PackedType, int ELTS_PER_THREAD>
11771177inline __device__ void quant_fp8(PackedVec<PackedType, T> packedAccum, void* quantOutPtr,
1178-                                 float* outputScale, uint32_t threadOffset) {
1178+                                 float invOutputScale, uint32_t threadOffset) {
11791179  static_assert(ELTS_PER_THREAD == 8 || ELTS_PER_THREAD == 4, "ELTS_PER_THREAD must be 8 or 4");
11801180  using QuantizedPackedType = std::conditional_t<ELTS_PER_THREAD == 8, float2, float>;
11811181
@@ -1184,7 +1184,7 @@ inline __device__ void quant_fp8(PackedVec<PackedType, T> packedAccum, void* qua
11841184#pragma unroll
11851185  for (int i = 0; i < ELTS_PER_THREAD; i++) {
11861186    quantizedAccum.elements[i] =
1187-        __nv_fp8_e4m3(toFloat<T>(packedAccum.elements[i]) * (*outputScale));
1187+        __nv_fp8_e4m3(toFloat<T>(packedAccum.elements[i]) * invOutputScale);
11881188  }
11891189  reinterpret_cast<QuantizedPackedType*>(&quantOut[threadOffset])[0] = quantizedAccum.packed;
11901190}
@@ -1373,7 +1373,8 @@ __global__ void __launch_bounds__(config::kMaxBlockSize) oneshotAllreduceFusionK
13731373 }
13741374
13751375    if constexpr (QType == QuantType::kFP8) {
1376-      quant::quant_fp8<T, PackedType, kELTS_PER_THREAD>(packedAccum, quantOutPtr, outputScale,
1376+      float invOutputScale = 1.0f / (*outputScale);  // We need to apply inv_scale to the output
1377+      quant::quant_fp8<T, PackedType, kELTS_PER_THREAD>(packedAccum, quantOutPtr, invOutputScale,
13771378                                                        threadOffset);
13781379    }
13791380#if CUDA_VERSION >= 12080
@@ -1805,7 +1806,8 @@ __global__ __launch_bounds__(config::kMaxBlockSize) void rmsNormLamport_fusion(
18051806    *reinterpret_cast<float4*>(&outputNorm[blockLoadOffset + threadLoadOffset]) = rOut.packed;
18061807  }
18071808  if constexpr (QType == QuantType::kFP8) {
1808-    quant::quant_fp8<T, float4, kELTS_PER_LOAD>(rOut, quantOut, outputScale,
1809+    float invOutputScale = 1.0f / (*outputScale);
1810+    quant::quant_fp8<T, float4, kELTS_PER_LOAD>(rOut, quantOut, invOutputScale,
18091811                                                blockLoadOffset + threadLoadOffset);
18101812  }
18111813#if CUDA_VERSION >= 12080
0 commit comments