diff --git a/csrc/concat_mla.cu b/csrc/concat_mla.cu index 932b26a65e..7da9cdaaf6 100644 --- a/csrc/concat_mla.cu +++ b/csrc/concat_mla.cu @@ -84,7 +84,7 @@ void concat_mla_k(TensorView k, TensorView k_nope, TensorView k_rope) { ffi::CUDADeviceGuard device_guard(k.device().device_id); const cudaStream_t stream = get_stream(k.device()); - bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(k.dtype(), c_type, [&] { + bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8(k.dtype(), c_type, [&] { cudaError_t status = ConcatMLAK( static_cast(k.data_ptr()), static_cast(k_nope.data_ptr()), static_cast(k_rope.data_ptr()), num_tokens, k_stride_0, k_stride_1, diff --git a/csrc/tvm_ffi_utils.h b/csrc/tvm_ffi_utils.h index b0150aecfb..2b485175ad 100644 --- a/csrc/tvm_ffi_utils.h +++ b/csrc/tvm_ffi_utils.h @@ -166,6 +166,20 @@ constexpr DLDevice cpu = DLDevice{kDLCPU, 0}; } \ }() +#define DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8(dlpack_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (encode_dlpack_dtype(dlpack_dtype)) { \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ + default: \ + TVM_FFI_ICHECK(false) << __PRETTY_FUNCTION__ << " failed to dispatch data type " \ + << (dlpack_dtype).code << " " << (dlpack_dtype).bits; \ + return false; \ + } \ + }() + #ifdef FLASHINFER_ENABLE_F32 #define _DISPATCH_CASE_F32(c_type, ...) \ case float32_code: { \ diff --git a/flashinfer/concat_ops.py b/flashinfer/concat_ops.py index 8957092a22..6e5f97618b 100644 --- a/flashinfer/concat_ops.py +++ b/flashinfer/concat_ops.py @@ -42,9 +42,12 @@ def concat_mla_k( - k_nope: per-head nope values - k_rope: shared rope values (broadcast to all heads) + Supported dtypes: ``torch.bfloat16``, ``torch.float16``, + ``torch.float8_e4m3fn``, ``torch.float8_e5m2``. + Key optimizations: - Warp-based processing with software pipelining - - Vectorized memory access (int2 for nope, int for rope) + - Vectorized memory access (compile-time dispatch per dtype) - L2 prefetching for next row while processing current - Register reuse for rope values across all heads in a chunk @@ -67,6 +70,7 @@ def concat_mla_k( >>> num_heads = 128 >>> nope_dim = 128 >>> rope_dim = 64 + >>> # BF16 example >>> k = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, dtype=torch.bfloat16, device="cuda") >>> k_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=torch.bfloat16, device="cuda") >>> k_rope = torch.randn(num_tokens, 1, rope_dim, dtype=torch.bfloat16, device="cuda") diff --git a/include/flashinfer/concat_mla.cuh b/include/flashinfer/concat_mla.cuh index 097a41f2ce..8109cc3659 100644 --- a/include/flashinfer/concat_mla.cuh +++ b/include/flashinfer/concat_mla.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include "utils.cuh" @@ -33,6 +34,91 @@ constexpr int MLA_K_HEAD_DIM = MLA_QK_NOPE_HEAD_DIM + MLA_QK_ROPE_HEAD_DIM; constexpr int MLA_HEAD_CHUNK_SIZE = 16; constexpr int MLA_NUM_HEAD_CHUNKS = MLA_NUM_LOCAL_HEADS / MLA_HEAD_CHUNK_SIZE; +// ======================= Vec Traits for DType dispatch ======================= +// BF16/FP16: int2 (8B) for nope, int (4B) for rope → ld/st_v2 / ld/st_v1 +// FP8: int (4B) for nope, short (2B) for rope → ld/st_v1 / ld/st_s16 +// Elements-per-vec is always 4 (nope) and 2 (rope), so stride arithmetic is +// identical across all DTypes; only the vector width and PTX instructions differ. + +template +struct ConcatMLAVecTraits; + +template <> +struct ConcatMLAVecTraits { + using NopeVec = int2; + using RopeVec = int; + + static __forceinline__ __device__ NopeVec load_nope(const NopeVec* ptr) { + return ld_na_global_v2(reinterpret_cast(ptr)); + } + static __forceinline__ __device__ RopeVec load_rope(const RopeVec* ptr) { + return ld_na_global_v1(reinterpret_cast(ptr)); + } + static __forceinline__ __device__ void store_nope(NopeVec* ptr, NopeVec val) { + st_na_global_v2(reinterpret_cast(ptr), val); + } + static __forceinline__ __device__ void store_rope(RopeVec* ptr, RopeVec val) { + st_na_global_v1(reinterpret_cast(ptr), val); + } +}; + +template <> +struct ConcatMLAVecTraits { + using NopeVec = int2; + using RopeVec = int; + + static __forceinline__ __device__ NopeVec load_nope(const NopeVec* ptr) { + return ld_na_global_v2(reinterpret_cast(ptr)); + } + static __forceinline__ __device__ RopeVec load_rope(const RopeVec* ptr) { + return ld_na_global_v1(reinterpret_cast(ptr)); + } + static __forceinline__ __device__ void store_nope(NopeVec* ptr, NopeVec val) { + st_na_global_v2(reinterpret_cast(ptr), val); + } + static __forceinline__ __device__ void store_rope(RopeVec* ptr, RopeVec val) { + st_na_global_v1(reinterpret_cast(ptr), val); + } +}; + +template <> +struct ConcatMLAVecTraits<__nv_fp8_e4m3> { + using NopeVec = int; + using RopeVec = short; + + static __forceinline__ __device__ NopeVec load_nope(const NopeVec* ptr) { + return ld_na_global_v1(reinterpret_cast(ptr)); + } + static __forceinline__ __device__ RopeVec load_rope(const RopeVec* ptr) { + return ld_na_global_s16(reinterpret_cast(ptr)); + } + static __forceinline__ __device__ void store_nope(NopeVec* ptr, NopeVec val) { + st_na_global_v1(reinterpret_cast(ptr), val); + } + static __forceinline__ __device__ void store_rope(RopeVec* ptr, RopeVec val) { + st_na_global_s16(reinterpret_cast(ptr), val); + } +}; + +template <> +struct ConcatMLAVecTraits<__nv_fp8_e5m2> { + using NopeVec = int; + using RopeVec = short; + + static __forceinline__ __device__ NopeVec load_nope(const NopeVec* ptr) { + return ld_na_global_v1(reinterpret_cast(ptr)); + } + static __forceinline__ __device__ RopeVec load_rope(const RopeVec* ptr) { + return ld_na_global_s16(reinterpret_cast(ptr)); + } + static __forceinline__ __device__ void store_nope(NopeVec* ptr, NopeVec val) { + st_na_global_v1(reinterpret_cast(ptr), val); + } + static __forceinline__ __device__ void store_rope(RopeVec* ptr, RopeVec val) { + st_na_global_s16(reinterpret_cast(ptr), val); + } +}; + // ======================= Optimized Kernel ======================= /*! * \brief Optimized CUDA kernel for concatenating k_nope and k_rope for MLA @@ -45,11 +131,11 @@ constexpr int MLA_NUM_HEAD_CHUNKS = MLA_NUM_LOCAL_HEADS / MLA_HEAD_CHUNK_SIZE; * * Key optimizations: * - Warp-based processing: each warp handles one (token, head_chunk) pair - * - Vectorized memory access: int2 (8B) for nope, int (4B) for rope + * - Vectorized memory access via ConcatMLAVecTraits (compile-time dispatch) * - L2 prefetching: prefetch next row while processing current * - Register reuse: rope is loaded once and written to all heads in chunk * - * \tparam DType Data type (nv_bfloat16 or nv_half) + * \tparam DType Data type (nv_bfloat16, nv_half, __nv_fp8_e4m3, __nv_fp8_e5m2) */ template __global__ void ConcatMLAKKernel(DType* __restrict__ k, const DType* __restrict__ k_nope, @@ -70,59 +156,53 @@ __global__ void ConcatMLAKKernel(DType* __restrict__ k, const DType* __restrict_ if (token_id >= num_tokens) return; - // Vector types for efficient memory access - // NopeVec: 8B/thread, 32 threads = 256B/row (covers nope_dim bf16 elements) - // RopeVec: 4B/thread, 32 threads = 128B/row (covers rope_dim bf16 elements) - using NopeVec = int2; - using RopeVec = int; + using Traits = ConcatMLAVecTraits; + using NopeVec = typename Traits::NopeVec; + using RopeVec = typename Traits::RopeVec; static_assert(sizeof(NopeVec) * 32 == QK_NOPE_HEAD_DIM * sizeof(DType), "nope vec mismatch"); static_assert(sizeof(RopeVec) * 32 == QK_ROPE_HEAD_DIM * sizeof(DType), "rope vec mismatch"); const int head_row0 = head_chunk_id * HEAD_CHUNK_SIZE; - // Source pointer for k_nope (indexed by token and head) - const int2* __restrict__ nope_src = - reinterpret_cast(k_nope + token_id * k_nope_stride_0 + - head_row0 * k_nope_stride_1) + + const NopeVec* __restrict__ nope_src = + reinterpret_cast(k_nope + token_id * k_nope_stride_0 + + head_row0 * k_nope_stride_1) + lane_id; - // Destination pointers for output k (nope part and rope part) - int2* __restrict__ nope_dst = - reinterpret_cast(k + token_id * k_stride_0 + head_row0 * k_stride_1) + lane_id; + NopeVec* __restrict__ nope_dst = + reinterpret_cast(k + token_id * k_stride_0 + head_row0 * k_stride_1) + lane_id; - int* __restrict__ rope_dst = reinterpret_cast(k + token_id * k_stride_0 + - head_row0 * k_stride_1 + QK_NOPE_HEAD_DIM) + - lane_id; + RopeVec* __restrict__ rope_dst = + reinterpret_cast(k + token_id * k_stride_0 + head_row0 * k_stride_1 + + QK_NOPE_HEAD_DIM) + + lane_id; - // Stride calculations for vector types - const int nope_src_stride_v = (k_nope_stride_1 >> 2); // int2 covers 4 bf16 + const int nope_src_stride_v = (k_nope_stride_1 >> 2); // 4 elements per vec for all DTypes const int nope_dst_stride_v = (k_stride_1 >> 2); - const int rope_dst_stride_v = (k_stride_1 >> 1); // int covers 2 bf16 + const int rope_dst_stride_v = (k_stride_1 >> 1); // 2 elements per vec for all DTypes - // Load rope value once - it's shared across all heads - const int* rope_base = reinterpret_cast(k_rope + token_id * k_rope_stride_0); - const RopeVec rope_val = ld_na_global_v1(rope_base + lane_id); + // Load rope value once - shared across all heads + const RopeVec* rope_ptr = + reinterpret_cast(k_rope + token_id * k_rope_stride_0) + lane_id; + RopeVec rope_val = Traits::load_rope(rope_ptr); // Prefetch first nope row and load it prefetch_L2(nope_src); - NopeVec cur = ld_na_global_v2(nope_src); + NopeVec cur = Traits::load_nope(nope_src); // Process all heads in this chunk with software pipelining #pragma unroll for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) { NopeVec next; if (i + 1 < HEAD_CHUNK_SIZE) { - // Prefetch and load next row while processing current - const int2* next_src = nope_src + nope_src_stride_v; + const NopeVec* next_src = nope_src + nope_src_stride_v; prefetch_L2(next_src); - next = ld_na_global_v2(next_src); + next = Traits::load_nope(next_src); } - // Write current nope and rope values - st_na_global_v2(nope_dst, cur); - st_na_global_v1(rope_dst, rope_val); + Traits::store_nope(nope_dst, cur); + Traits::store_rope(rope_dst, rope_val); - // Advance pointers nope_src += nope_src_stride_v; nope_dst += nope_dst_stride_v; rope_dst += rope_dst_stride_v; diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index c7edf5ab57..787a6de6d6 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -446,6 +446,22 @@ __forceinline__ __device__ int get_lane_id() { return lane_id; } +/*! + * \brief Non-atomic global load for short (2 bytes) with cache streaming hint + */ +__forceinline__ __device__ short ld_na_global_s16(const short* addr) { + short val; + asm volatile("ld.global.cs.b16 %0, [%1];" : "=h"(val) : "l"(addr)); + return val; +} + +/*! + * \brief Non-atomic global store for short (2 bytes) with cache streaming hint + */ +__forceinline__ __device__ void st_na_global_s16(short* addr, short val) { + asm volatile("st.global.cs.b16 [%0], %1;" ::"l"(addr), "h"(val)); +} + /*! * \brief Non-atomic global load for int (4 bytes) with cache streaming hint */ diff --git a/tests/utils/test_concat_mla.py b/tests/utils/test_concat_mla.py new file mode 100644 index 0000000000..75a9747c5e --- /dev/null +++ b/tests/utils/test_concat_mla.py @@ -0,0 +1,191 @@ +""" +Tests for concat_mla_k kernel — verifies correctness across BF16, FP16, and FP8 dtypes. + +concat_mla_k is a pure memory movement operation (copy + broadcast), so the output +must be **bit-exact** compared to the PyTorch slice-assign reference. +""" + +import pytest +import torch + +from flashinfer.concat_ops import concat_mla_k +from flashinfer.utils import get_compute_capability + +NUM_LOCAL_HEADS = 128 +QK_NOPE_HEAD_DIM = 128 +QK_ROPE_HEAD_DIM = 64 +K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM + + +def _reference_concat(k_nope: torch.Tensor, k_rope: torch.Tensor) -> torch.Tensor: + """PyTorch reference: slice-assign with broadcast.""" + k = torch.empty( + (*k_nope.shape[:-1], K_HEAD_DIM), + dtype=k_nope.dtype, + device=k_nope.device, + ) + k[..., :QK_NOPE_HEAD_DIM] = k_nope + k[..., QK_NOPE_HEAD_DIM:] = k_rope + return k + + +def _make_tensors(num_tokens: int, dtype: torch.dtype, device: str = "cuda"): + """Create contiguous k_nope, k_rope, and pre-allocated output k.""" + # Generate in BF16 then cast — FP8 doesn't support randn directly + k_nope = ( + torch.randn( + num_tokens, + NUM_LOCAL_HEADS, + QK_NOPE_HEAD_DIM, + device=device, + dtype=torch.bfloat16, + ) + .to(dtype) + .contiguous() + ) + k_rope = ( + torch.randn( + num_tokens, 1, QK_ROPE_HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + .to(dtype) + .contiguous() + ) + k = torch.empty(num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, dtype=dtype, device=device) + return k, k_nope, k_rope + + +# ────────────────────────── Core correctness tests ────────────────────────── + + +@pytest.mark.parametrize("num_tokens", [1, 32, 1024, 8192]) +@pytest.mark.parametrize( + "dtype", + [ + torch.bfloat16, + torch.float16, + pytest.param(torch.float8_e4m3fn, id="fp8_e4m3"), + pytest.param(torch.float8_e5m2, id="fp8_e5m2"), + ], +) +def test_concat_mla_k_correctness(num_tokens, dtype): + """Bit-exact correctness: flashinfer output == PyTorch reference.""" + if dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + major, minor = get_compute_capability(torch.device("cuda")) + if (major, minor) < (8, 9): + pytest.skip("FP8 requires SM >= 89 (Ada/Hopper)") + + k, k_nope, k_rope = _make_tensors(num_tokens, dtype) + concat_mla_k(k, k_nope, k_rope) + + ref = _reference_concat(k_nope, k_rope) + + # Pure copy — must be bit-exact + if dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + assert torch.equal(k.view(torch.uint8), ref.view(torch.uint8)), ( + f"Mismatch for dtype={dtype}, num_tokens={num_tokens}." + ) + else: + assert torch.equal(k, ref), ( + f"Mismatch for dtype={dtype}, num_tokens={num_tokens}. " + f"max abs diff = {(k.to(torch.float32) - ref.to(torch.float32)).abs().max().item()}" + ) + + +# ────────────────────────── Zero-token edge case ────────────────────────── + + +@pytest.mark.parametrize( + "dtype", + [torch.bfloat16, torch.float16, torch.float8_e4m3fn], +) +def test_concat_mla_k_zero_tokens(dtype): + """num_tokens=0 should return immediately without error.""" + if dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + major, minor = get_compute_capability(torch.device("cuda")) + if (major, minor) < (8, 9): + pytest.skip("FP8 requires SM >= 89") + + k, k_nope, k_rope = _make_tensors(0, dtype) + concat_mla_k(k, k_nope, k_rope) # should not crash + + +# ────────────────────────── Strided (non-contiguous last dim) inputs ────── + + +@pytest.mark.parametrize( + "dtype", + [ + torch.bfloat16, + pytest.param(torch.float8_e4m3fn, id="fp8_e4m3"), + ], +) +def test_concat_mla_k_strided_inputs(dtype): + """Verify correctness when k_nope is a slice of a larger contiguous tensor.""" + if dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + major, minor = get_compute_capability(torch.device("cuda")) + if (major, minor) < (8, 9): + pytest.skip("FP8 requires SM >= 89") + + num_tokens = 2048 + + # k_nope is a slice — last-dim contiguous but has a stride gap on dim-1 + nope_container = torch.randn( + num_tokens, + NUM_LOCAL_HEADS, + QK_NOPE_HEAD_DIM + 128, + device="cuda", + dtype=torch.bfloat16, + ).to(dtype) + k_nope = nope_container[:, :, :QK_NOPE_HEAD_DIM] + + k_rope = ( + torch.randn( + num_tokens, 1, QK_ROPE_HEAD_DIM, device="cuda", dtype=torch.bfloat16 + ) + .to(dtype) + .contiguous() + ) + + k = torch.empty(num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, dtype=dtype, device="cuda") + concat_mla_k(k, k_nope, k_rope) + + ref = _reference_concat(k_nope, k_rope) + if dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + assert torch.equal(k.view(torch.uint8), ref.view(torch.uint8)) + else: + assert torch.equal(k, ref) + + +# ────────────────────────── Cross-dtype guard ────────────────────────── + + +def test_concat_mla_k_dtype_mismatch_raises(): + """Passing mismatched dtypes should raise an error from the C++ side.""" + num_tokens = 64 + k_nope = torch.randn( + num_tokens, + NUM_LOCAL_HEADS, + QK_NOPE_HEAD_DIM, + device="cuda", + dtype=torch.bfloat16, + ) + k_rope = torch.randn( + num_tokens, + 1, + QK_ROPE_HEAD_DIM, + device="cuda", + dtype=torch.float16, # intentional mismatch + ) + k = torch.empty( + num_tokens, + NUM_LOCAL_HEADS, + K_HEAD_DIM, + device="cuda", + dtype=torch.bfloat16, + ) + with pytest.raises(RuntimeError): + concat_mla_k(k, k_nope, k_rope) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])