@@ -835,20 +835,20 @@ __device__ __forceinline__ void vec_convert(
835835 out[j] = DstT (static_cast <float >(in[j]));
836836}
837837
// BF16 → FP8 e4m3 pairwise conversion, SM100+ (Blackwell) only.
//
// NOTE: the hand-written inline PTX "cvt.rn.satfinite.e4m3x2.bf16x2 %h, %r"
// is rejected by the SM100a ptxas ("Unexpected instruction types for cvt"):
// on SM100a this instruction requires a 32-bit output register, while the
// "=h" constraint supplies a 16-bit one. Constructing __nv_fp8x2_e4m3 from a
// __nv_bfloat162 lets the compiler emit the accepted form of the conversion.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Converts a vector of VEC_SIZE bf16 values into VEC_SIZE fp8 (e4m3) values,
// two elements at a time. VEC_SIZE must be even (enforced via enable_if).
template <size_t VEC_SIZE, std::enable_if_t<(VEC_SIZE % 2 == 0), int> = 0>
__device__ __forceinline__ void vec_convert(
    flashinfer::vec_t<__nv_fp8_e4m3, VEC_SIZE>& out,
    flashinfer::vec_t<__nv_bfloat16, VEC_SIZE> const& in) {
  // Reinterpret both vectors as arrays of packed 2-element lanes so each
  // iteration converts one bf16 pair into one fp8 pair.
  auto* dst_pairs = reinterpret_cast<__nv_fp8x2_e4m3*>(&out);
  auto const* src_pairs = reinterpret_cast<__nv_bfloat162 const*>(&in);
  constexpr int kNumPairs = static_cast<int>(VEC_SIZE) / 2;
#pragma unroll
  for (int i = 0; i < kNumPairs; ++i) {
    dst_pairs[i] = __nv_fp8x2_e4m3(src_pairs[i]);
  }
}
#endif
854854
0 commit comments