Skip to content

Commit bfcc10f

Browse files
committed
Use cuda intrinistic instead of PTX
1 parent d1810b2 commit bfcc10f

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -835,20 +835,20 @@ __device__ __forceinline__ void vec_convert(
835835
out[j] = DstT(static_cast<float>(in[j]));
836836
}
837837

838-
// BF16 → FP8 e4m3: paired PTX cvt.rn.satfinite.e4m3x2.bf16x2 (SM100+, Blackwell).
838+
// BF16 → FP8 e4m3: use CUDA intrinsic (SM100+, Blackwell).
839+
// Inline PTX "cvt.rn.satfinite.e4m3x2.bf16x2 %h, %r" is rejected by SM100a ptxas
840+
// ("Unexpected instruction types for cvt") because SM100a requires a 32-bit output
841+
// register for this instruction. __nv_fp8x2_e4m3(bfloat162) emits the correct form.
839842
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
840843
template <size_t VEC_SIZE, std::enable_if_t<(VEC_SIZE % 2 == 0), int> = 0>
841844
__device__ __forceinline__ void vec_convert(
842845
flashinfer::vec_t<__nv_fp8_e4m3, VEC_SIZE>& out,
843846
flashinfer::vec_t<__nv_bfloat16, VEC_SIZE> const& in) {
844-
uint32_t const* src_u32 = reinterpret_cast<uint32_t const*>(&in);
845-
uint16_t* dst_u16 = reinterpret_cast<uint16_t*>(&out);
847+
__nv_fp8x2_e4m3* out_fp8x2 = reinterpret_cast<__nv_fp8x2_e4m3*>(&out);
848+
__nv_bfloat162 const* in_bf16x2 = reinterpret_cast<__nv_bfloat162 const*>(&in);
846849
#pragma unroll
847-
for (int p = 0; p < VEC_SIZE / 2; ++p) {
848-
uint16_t d;
849-
asm volatile("cvt.rn.satfinite.e4m3x2.bf16x2 %0, %1;" : "=h"(d) : "r"(src_u32[p]));
850-
dst_u16[p] = d;
851-
}
850+
for (int p = 0; p < static_cast<int>(VEC_SIZE) / 2; ++p)
851+
out_fp8x2[p] = __nv_fp8x2_e4m3(in_bf16x2[p]);
852852
}
853853
#endif
854854

0 commit comments

Comments
 (0)