@@ -590,6 +590,9 @@ cvt_fp16_to_fp4_expert(
590590 uint32_t * SFout, int32_t * mask, bool use_silu_and_mul, int n_experts) {
591591#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
592592 using PackedVecT = PackedVec<Type, CVT_FP16_TO_FP4_ELTS_PER_THREAD>;
593+ // Packed fp4 output type: 8 fp4 elts fit in 32 bits, 16 fp4 elts in 64 bits.
594+ using PackedFp4OutT =
595+ std::conditional_t <CVT_FP16_TO_FP4_ELTS_PER_THREAD == 16 , uint64_t , uint32_t >;
593596 static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
594597 (CVT_FP4_SF_VEC_SIZE / CVT_FP16_TO_FP4_ELTS_PER_THREAD);
595598 static_assert (sizeof (PackedVecT) == sizeof (Type) * CVT_FP16_TO_FP4_ELTS_PER_THREAD,
@@ -653,9 +656,9 @@ cvt_fp16_to_fp4_expert(
653656 }
654657
655658 // Get the output tensor offset.
656- // Same as inOffset because 8 elements are packed into one uint32_t.
659+ // Same as inOffset because CVT_FP16_TO_FP4_ELTS_PER_THREAD elements are
660+ // packed into one PackedFp4OutT (uint32_t for 8 elts, uint64_t for 16 elts).
657661 int64_t outOffset = rowIdx * colsPerRow + colIdx;
658- auto & out_pos = out[outOffset];
659662
660663 // Get the global scaling factor, which will be applied to the SF.
661664 // Note SFScale is the same as next GEMM's alpha, which is
@@ -672,7 +675,7 @@ cvt_fp16_to_fp4_expert(
672675 CVT_FP4_NUM_THREADS_PER_SF>(
673676 rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
674677
675- out_pos =
678+ reinterpret_cast <PackedFp4OutT*>(out)[outOffset] =
676679 cvt_warp_fp16_to_fp4<Type, CVT_FP4_SF_VEC_SIZE, CVT_FP16_TO_FP4_ELTS_PER_THREAD, UE8M0_SF>(
677680 in_vec, SFScaleVal, sf_out);
678681 }
0 commit comments