Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion csrc/concat_mla.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@

using namespace flashinfer;

// Dispatch a runtime DLPack dtype to a concrete device C++ type bound to
// `c_type`, then run the functor passed in __VA_ARGS__ under that binding.
// Handled dtypes: fp16, bf16, fp8 e4m3, fp8 e5m2 — each _DISPATCH_CASE_*
// macro is expected to expand to the matching `case` label that defines
// `c_type` and invokes the functor. The whole expansion is an immediately
// invoked lambda returning bool; an unhandled dtype trips TVM_FFI_ICHECK
// (logging the dtype's code/bits) and yields false.
// NOTE(review): comments live above the #define on purpose — a `//` comment
// inside a backslash-continued macro would swallow the continuation.
#define DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8(dlpack_dtype, c_type, ...)               \
  [&]() -> bool {                                                                        \
    switch (encode_dlpack_dtype(dlpack_dtype)) {                                         \
      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                            \
      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                           \
      _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__)                                       \
      _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__)                                       \
      default:                                                                           \
        TVM_FFI_ICHECK(false) << __PRETTY_FUNCTION__ << " failed to dispatch data type " \
                              << (dlpack_dtype).code << " " << (dlpack_dtype).bits;      \
        return false;                                                                    \
    }                                                                                    \
  }()

/*!
* \brief Concatenate k_nope and k_rope tensors for MLA attention
*
Expand Down Expand Up @@ -84,7 +98,7 @@ void concat_mla_k(TensorView k, TensorView k_nope, TensorView k_rope) {
ffi::CUDADeviceGuard device_guard(k.device().device_id);
const cudaStream_t stream = get_stream(k.device());

bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(k.dtype(), c_type, [&] {
bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8(k.dtype(), c_type, [&] {
cudaError_t status = ConcatMLAK<c_type>(
static_cast<c_type*>(k.data_ptr()), static_cast<c_type*>(k_nope.data_ptr()),
static_cast<c_type*>(k_rope.data_ptr()), num_tokens, k_stride_0, k_stride_1,
Expand Down
6 changes: 5 additions & 1 deletion flashinfer/concat_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ def concat_mla_k(
- k_nope: per-head nope values
- k_rope: shared rope values (broadcast to all heads)

Supported dtypes: ``torch.bfloat16``, ``torch.float16``,
``torch.float8_e4m3fn``, ``torch.float8_e5m2``.

Key optimizations:
- Warp-based processing with software pipelining
- Vectorized memory access (int2 for nope, int for rope)
- Vectorized memory access (compile-time dispatch per dtype)
- L2 prefetching for next row while processing current
- Register reuse for rope values across all heads in a chunk

Expand All @@ -67,6 +70,7 @@ def concat_mla_k(
>>> num_heads = 128
>>> nope_dim = 128
>>> rope_dim = 64
>>> # BF16 example
>>> k = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, dtype=torch.bfloat16, device="cuda")
>>> k_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=torch.bfloat16, device="cuda")
>>> k_rope = torch.randn(num_tokens, 1, rope_dim, dtype=torch.bfloat16, device="cuda")
Expand Down
142 changes: 111 additions & 31 deletions include/flashinfer/concat_mla.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

#include "utils.cuh"
Expand All @@ -33,6 +34,91 @@ constexpr int MLA_K_HEAD_DIM = MLA_QK_NOPE_HEAD_DIM + MLA_QK_ROPE_HEAD_DIM;
constexpr int MLA_HEAD_CHUNK_SIZE = 16;
constexpr int MLA_NUM_HEAD_CHUNKS = MLA_NUM_LOCAL_HEADS / MLA_HEAD_CHUNK_SIZE;

// ======================= Vec Traits for DType dispatch =======================
// BF16/FP16: int2 (8B) for nope, int (4B) for rope → ld/st_v2 / ld/st_v1
// FP8: int (4B) for nope, short (2B) for rope → ld/st_v1 / ld/st_s16
// Elements-per-vec is always 4 (nope) and 2 (rope), so stride arithmetic is
// identical across all DTypes; only the vector width and PTX instructions differ.

/*!
 * \brief Per-dtype vectorized-access traits used by the MLA concat kernel.
 *
 * Each specialization provides:
 *  - NopeVec / RopeVec: the per-thread vector word covering 4 (nope) or
 *    2 (rope) elements of DType.
 *  - load_nope / store_nope and load_rope / store_rope: cache-streaming
 *    global accesses of those vector words (thin wrappers over the
 *    ld/st_na_global_* helpers in utils.cuh).
 *
 * Only the four supported dtypes are specialized; using any other DType is a
 * compile-time error because the primary template is left undefined.
 */
template <typename DType>
struct ConcatMLAVecTraits;

// Shared implementation for 2-byte element types (fp16 / bf16):
// int2 (8B, 4 elems) nope vectors, int (4B, 2 elems) rope vectors.
// Factored out because the fp16 and bf16 specializations are identical.
struct ConcatMLAVecTraits2B {
  using NopeVec = int2;
  using RopeVec = int;

  __forceinline__ __device__ static NopeVec load_nope(const NopeVec* ptr) {
    return ld_na_global_v2(ptr);
  }
  __forceinline__ __device__ static RopeVec load_rope(const RopeVec* ptr) {
    return ld_na_global_v1(ptr);
  }
  __forceinline__ __device__ static void store_nope(NopeVec* ptr, NopeVec val) {
    st_na_global_v2(ptr, val);
  }
  __forceinline__ __device__ static void store_rope(RopeVec* ptr, RopeVec val) {
    st_na_global_v1(ptr, val);
  }
};

// Shared implementation for 1-byte element types (fp8 e4m3 / e5m2):
// int (4B, 4 elems) nope vectors, short (2B, 2 elems) rope vectors.
// Factored out because the two fp8 specializations are identical.
struct ConcatMLAVecTraits1B {
  using NopeVec = int;
  using RopeVec = short;

  __forceinline__ __device__ static NopeVec load_nope(const NopeVec* ptr) {
    return ld_na_global_v1(ptr);
  }
  __forceinline__ __device__ static RopeVec load_rope(const RopeVec* ptr) {
    return ld_na_global_s16(ptr);
  }
  __forceinline__ __device__ static void store_nope(NopeVec* ptr, NopeVec val) {
    st_na_global_v1(ptr, val);
  }
  __forceinline__ __device__ static void store_rope(RopeVec* ptr, RopeVec val) {
    st_na_global_s16(ptr, val);
  }
};

// The public ConcatMLAVecTraits<DType> interface (type aliases + static
// methods) is inherited unchanged from the shared implementations above.
template <>
struct ConcatMLAVecTraits<nv_half> : ConcatMLAVecTraits2B {};

template <>
struct ConcatMLAVecTraits<nv_bfloat16> : ConcatMLAVecTraits2B {};

template <>
struct ConcatMLAVecTraits<__nv_fp8_e4m3> : ConcatMLAVecTraits1B {};

template <>
struct ConcatMLAVecTraits<__nv_fp8_e5m2> : ConcatMLAVecTraits1B {};

// ======================= Optimized Kernel =======================
/*!
* \brief Optimized CUDA kernel for concatenating k_nope and k_rope for MLA
Expand All @@ -45,11 +131,11 @@ constexpr int MLA_NUM_HEAD_CHUNKS = MLA_NUM_LOCAL_HEADS / MLA_HEAD_CHUNK_SIZE;
*
* Key optimizations:
* - Warp-based processing: each warp handles one (token, head_chunk) pair
* - Vectorized memory access: int2 (8B) for nope, int (4B) for rope
* - Vectorized memory access via ConcatMLAVecTraits (compile-time dispatch)
* - L2 prefetching: prefetch next row while processing current
* - Register reuse: rope is loaded once and written to all heads in chunk
*
* \tparam DType Data type (nv_bfloat16 or nv_half)
* \tparam DType Data type (nv_bfloat16, nv_half, __nv_fp8_e4m3, __nv_fp8_e5m2)
*/
template <typename DType>
__global__ void ConcatMLAKKernel(DType* __restrict__ k, const DType* __restrict__ k_nope,
Expand All @@ -70,59 +156,53 @@ __global__ void ConcatMLAKKernel(DType* __restrict__ k, const DType* __restrict_

if (token_id >= num_tokens) return;

// Vector types for efficient memory access
// NopeVec: 8B/thread, 32 threads = 256B/row (covers nope_dim bf16 elements)
// RopeVec: 4B/thread, 32 threads = 128B/row (covers rope_dim bf16 elements)
using NopeVec = int2;
using RopeVec = int;
using Traits = ConcatMLAVecTraits<DType>;
using NopeVec = typename Traits::NopeVec;
using RopeVec = typename Traits::RopeVec;
static_assert(sizeof(NopeVec) * 32 == QK_NOPE_HEAD_DIM * sizeof(DType), "nope vec mismatch");
static_assert(sizeof(RopeVec) * 32 == QK_ROPE_HEAD_DIM * sizeof(DType), "rope vec mismatch");

const int head_row0 = head_chunk_id * HEAD_CHUNK_SIZE;

// Source pointer for k_nope (indexed by token and head)
const int2* __restrict__ nope_src =
reinterpret_cast<const int2*>(k_nope + token_id * k_nope_stride_0 +
head_row0 * k_nope_stride_1) +
const NopeVec* __restrict__ nope_src =
reinterpret_cast<const NopeVec*>(k_nope + token_id * k_nope_stride_0 +
head_row0 * k_nope_stride_1) +
lane_id;

// Destination pointers for output k (nope part and rope part)
int2* __restrict__ nope_dst =
reinterpret_cast<int2*>(k + token_id * k_stride_0 + head_row0 * k_stride_1) + lane_id;
NopeVec* __restrict__ nope_dst =
reinterpret_cast<NopeVec*>(k + token_id * k_stride_0 + head_row0 * k_stride_1) + lane_id;

int* __restrict__ rope_dst = reinterpret_cast<int*>(k + token_id * k_stride_0 +
head_row0 * k_stride_1 + QK_NOPE_HEAD_DIM) +
lane_id;
RopeVec* __restrict__ rope_dst =
reinterpret_cast<RopeVec*>(k + token_id * k_stride_0 + head_row0 * k_stride_1 +
QK_NOPE_HEAD_DIM) +
lane_id;

// Stride calculations for vector types
const int nope_src_stride_v = (k_nope_stride_1 >> 2); // int2 covers 4 bf16
const int nope_src_stride_v = (k_nope_stride_1 >> 2); // 4 elements per vec for all DTypes
const int nope_dst_stride_v = (k_stride_1 >> 2);
const int rope_dst_stride_v = (k_stride_1 >> 1); // int covers 2 bf16
const int rope_dst_stride_v = (k_stride_1 >> 1); // 2 elements per vec for all DTypes

// Load rope value once - it's shared across all heads
const int* rope_base = reinterpret_cast<const int*>(k_rope + token_id * k_rope_stride_0);
const RopeVec rope_val = ld_na_global_v1(rope_base + lane_id);
// Load rope value once - shared across all heads
const RopeVec* rope_ptr =
reinterpret_cast<const RopeVec*>(k_rope + token_id * k_rope_stride_0) + lane_id;
RopeVec rope_val = Traits::load_rope(rope_ptr);

// Prefetch first nope row and load it
prefetch_L2(nope_src);
NopeVec cur = ld_na_global_v2(nope_src);
NopeVec cur = Traits::load_nope(nope_src);

// Process all heads in this chunk with software pipelining
#pragma unroll
for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
NopeVec next;
if (i + 1 < HEAD_CHUNK_SIZE) {
// Prefetch and load next row while processing current
const int2* next_src = nope_src + nope_src_stride_v;
const NopeVec* next_src = nope_src + nope_src_stride_v;
prefetch_L2(next_src);
next = ld_na_global_v2(next_src);
next = Traits::load_nope(next_src);
}

// Write current nope and rope values
st_na_global_v2(nope_dst, cur);
st_na_global_v1(rope_dst, rope_val);
Traits::store_nope(nope_dst, cur);
Traits::store_rope(rope_dst, rope_val);

// Advance pointers
nope_src += nope_src_stride_v;
nope_dst += nope_dst_stride_v;
rope_dst += rope_dst_stride_v;
Expand Down
16 changes: 16 additions & 0 deletions include/flashinfer/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,22 @@ __forceinline__ __device__ int get_lane_id() {
return lane_id;
}

/*!
 * \brief Non-atomic global load for short (2 bytes) with cache streaming hint
 *
 * Emits PTX `ld.global.cs.b16`: the `.cs` (cache-streaming, evict-first)
 * qualifier marks the data as streamed — read once and unlikely to be reused —
 * so it is evicted early instead of displacing resident cache lines. 2-byte
 * companion to the 4/8-byte ld_na_global_v1/v2 helpers; used for fp8 rope
 * vectors (2 x 1-byte elements per thread).
 */
__forceinline__ __device__ short ld_na_global_s16(const short* addr) {
  short val;
  // "=h" binds a 16-bit register for the result; "l" passes the 64-bit address.
  asm volatile("ld.global.cs.b16 %0, [%1];" : "=h"(val) : "l"(addr));
  return val;
}

/*!
 * \brief Non-atomic global store for short (2 bytes) with cache streaming hint
 *
 * Emits PTX `st.global.cs.b16`: the `.cs` (cache-streaming, evict-first)
 * qualifier marks the written line as streamed so it does not linger in cache.
 * 2-byte companion to st_na_global_v1/v2; mirror of ld_na_global_s16, used for
 * fp8 rope vectors (2 x 1-byte elements per thread).
 */
__forceinline__ __device__ void st_na_global_s16(short* addr, short val) {
  // "l" passes the 64-bit address; "h" passes the 16-bit value to store.
  asm volatile("st.global.cs.b16 [%0], %1;" ::"l"(addr), "h"(val));
}

/*!
* \brief Non-atomic global load for int (4 bytes) with cache streaming hint
*/
Expand Down
Loading
Loading