feat: Enable FP8 (E4M3/E5M2) in concat_mla_k to optimize long-context prefill performance and refactor type dispatch for BF16/FP16 #3129
Conversation
No actionable comments were generated in the recent review. 🎉
📝 Walkthrough

Adds FP8 (E4M3, E5M2) support to concat_mla alongside the existing BF16/FP16 path: introduces dtype-specific vector traits and 16-bit non-atomic PTX helpers, replaces the FP16-only dispatch with a new FP16+FP8 dispatch macro, updates docstrings, and adds tests covering BF16/FP16/FP8 and edge cases.
Code Review
This pull request adds support for FP8 data types to the concat_mla_k kernel by introducing a ConcatMLAVecTraits template for compile-time dispatch of vector types and memory instructions. It also includes new 16-bit global load/store utilities and comprehensive unit tests for correctness across various dtypes and input configurations. Feedback suggests refactoring the kernel's dispatch logic into the traits structure to improve maintainability and reduce code repetition.
concat_mla_k previously only supported BF16/FP16, causing crashes when vLLM's chunked prefill path passes FP8 quantized K tensors.

Changes:
- utils.cuh: add ld_na_global_s16/st_na_global_s16 for 2-byte vectorized access
- concat_mla.cuh: add ConcatMLAVecTraits template for compile-time type dispatch (BF16/FP16 -> int2/int, FP8 -> int/short), using if constexpr for zero overhead
- concat_mla.cu: add DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8 macro for FP8 dispatch
- concat_ops.py: update docstring to list FP8 dtypes
- test_concat_mla.py: add comprehensive pytest for all supported dtypes

Tested end-to-end on GB300 with DeepSeek-R1-0528-FP4 (ISL=128K, 16 prompts):
- Without fix: RuntimeError "k and k_rope must have the same dtype"
- With fix: 16/16 success, Median TTFT -27% vs BF16 baseline

Signed-off-by: Albert Cheng <albecheng@nvidia.com>
Force-pushed 3bafeb4 to 4d93ec2
/bot run
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
csrc/concat_mla.cu (1)
101-105: ⚠️ Potential issue | 🟠 Major

Reject strides the vector kernel cannot represent.

The binding only checks last-dim contiguity, but the kernel later turns element strides into vector strides with `>> 2` / `>> 1`. Views with last-dim stride 1 but head/token strides like 129 pass validation and then get silently truncated, reading/writing the wrong rows. Add launcher checks for the vector-stride assumptions before dispatch.

Suggested validation before dispatch:

```diff
   int64_t k_stride_0 = k.stride(0);
   int k_stride_1 = k.stride(1);
   int64_t k_nope_stride_0 = k_nope.stride(0);
   int k_nope_stride_1 = k_nope.stride(1);
   int64_t k_rope_stride_0 = k_rope.stride(0);
+
+  // ConcatMLAKKernel reinterprets nope rows as 4-element vectors and rope rows
+  // as 2-element vectors. Reject strided views that would truncate in vector
+  // stride arithmetic.
+  TVM_FFI_ICHECK_EQ(k_stride_0 % 4, 0) << "k token stride must be divisible by 4";
+  TVM_FFI_ICHECK_EQ(k_stride_1 % 4, 0) << "k head stride must be divisible by 4";
+  TVM_FFI_ICHECK_EQ(k_nope_stride_0 % 4, 0)
+      << "k_nope token stride must be divisible by 4";
+  TVM_FFI_ICHECK_EQ(k_nope_stride_1 % 4, 0)
+      << "k_nope head stride must be divisible by 4";
+  TVM_FFI_ICHECK_EQ(k_rope_stride_0 % 2, 0)
+      << "k_rope token stride must be divisible by 2";
   ffi::CUDADeviceGuard device_guard(k.device().device_id);
```
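The silent truncation this comment warns about is easy to see off-device: a right shift simply drops any remainder in the stride. Below is a minimal host-side sketch of the representability check; the helper names are hypothetical, not FlashInfer code.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// An element stride is representable as a vector stride only when it is
// divisible by the vector lane count: the kernel computes
// `vec_stride = elem_stride >> log2(lanes)`, which silently drops remainders.
inline bool is_vector_representable(std::int64_t elem_stride, std::int64_t lanes) {
  return elem_stride % lanes == 0;
}

// Hypothetical launcher-side guard: convert or throw with a clear message
// instead of letting the kernel read/write the wrong rows.
inline std::int64_t to_vector_stride(std::int64_t elem_stride, std::int64_t lanes) {
  if (!is_vector_representable(elem_stride, lanes)) {
    throw std::invalid_argument("stride " + std::to_string(elem_stride) +
                                " is not divisible by vector width " +
                                std::to_string(lanes));
  }
  return elem_stride / lanes;
}
```

For example, a head stride of 129 with 4-element vectors is rejected here, while `129 >> 2` would silently yield 32, i.e. the same vector stride as a contiguous stride of 128.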
include/flashinfer/concat_mla.cuh (1)

193-210: ⚠️ Potential issue | 🟠 Major

Guard the final iteration to avoid reading uninitialized `next`.

On the last unrolled iteration (when `i = HEAD_CHUNK_SIZE - 1`), the condition `i + 1 < HEAD_CHUNK_SIZE` is false, so `next` is never assigned but `cur = next` still reads it. This is undefined behavior. Add the guard to prevent the read:

Suggested fix:

```diff
       nope_src += nope_src_stride_v;
       nope_dst += nope_dst_stride_v;
       rope_dst += rope_dst_stride_v;
-      cur = next;
+      if (i + 1 < HEAD_CHUNK_SIZE) {
+        cur = next;
+      }
     }
   }
```
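The guarded-rotate pattern can be exercised off-device. The sketch below is a simplified host-side analogue of the kernel's software-pipelined loop (plain arrays instead of PTX loads; the function name is illustrative), showing that with the guard `next` is never read on the final iteration.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Software-pipelined copy: prefetch element i+1 while storing element i.
// The guard on `cur = next` mirrors the suggested fix: on the last iteration
// `next` was never loaded, so rotating it into `cur` would read garbage.
template <int HEAD_CHUNK_SIZE, typename Vec, std::size_t N>
void pipelined_copy(const std::array<Vec, N>& src, std::array<Vec, N>& dst) {
  static_assert(HEAD_CHUNK_SIZE >= 1, "at least one iteration");
  static_assert(HEAD_CHUNK_SIZE <= static_cast<int>(N), "chunk fits buffer");
  Vec cur = src[0];  // prologue: first load
  Vec next{};
  for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
    if (i + 1 < HEAD_CHUNK_SIZE) next = src[i + 1];  // prefetch stage
    dst[i] = cur;                                    // drain stage
    if (i + 1 < HEAD_CHUNK_SIZE) cur = next;         // guarded rotate
  }
}
```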
📒 Files selected for processing (5)

- csrc/concat_mla.cu
- flashinfer/concat_ops.py
- include/flashinfer/concat_mla.cuh
- include/flashinfer/utils.cuh
- tests/utils/test_concat_mla.py
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
include/flashinfer/concat_mla.cuh (1)
46-120: Consider collapsing the four traits specializations into two (or one `if constexpr` block).

The `nv_half` and `nv_bfloat16` specializations are byte-for-byte identical, and the `__nv_fp8_e4m3` / `__nv_fp8_e5m2` specializations are likewise identical. Since the traits only depend on `sizeof(DType)`, one generic primary template (or a shared base keyed on element size) would remove ~40 lines of duplication and matches the "`if constexpr` for compile-time vector type selection" design mentioned in the PR description. Something along these lines:

♻️ Example consolidation

```cpp
template <typename DType>
struct ConcatMLAVecTraits {
  static_assert(sizeof(DType) == 1 || sizeof(DType) == 2,
                "ConcatMLAVecTraits only supports 1B (FP8) or 2B (FP16/BF16) DTypes");
  using NopeVec = std::conditional_t<sizeof(DType) == 2, int2, int>;
  using RopeVec = std::conditional_t<sizeof(DType) == 2, int, short>;

  static __forceinline__ __device__ NopeVec load_nope(const NopeVec* ptr) {
    if constexpr (sizeof(DType) == 2)
      return ld_na_global_v2(ptr);
    else
      return ld_na_global_v1(ptr);
  }
  static __forceinline__ __device__ RopeVec load_rope(const RopeVec* ptr) {
    if constexpr (sizeof(DType) == 2)
      return ld_na_global_v1(ptr);
    else
      return ld_na_global_s16(ptr);
  }
  static __forceinline__ __device__ void store_nope(NopeVec* ptr, NopeVec v) {
    if constexpr (sizeof(DType) == 2)
      st_na_global_v2(ptr, v);
    else
      st_na_global_v1(ptr, v);
  }
  static __forceinline__ __device__ void store_rope(RopeVec* ptr, RopeVec v) {
    if constexpr (sizeof(DType) == 2)
      st_na_global_v1(ptr, v);
    else
      st_na_global_s16(ptr, v);
  }
};
```

Not blocking — purely a DRY/maintainability win; the static asserts in `ConcatMLAKKernel` already guard against mismatches.
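The size-keyed selection that consolidation relies on can be checked on the host with stub element types; the real `nv_half` / `__nv_fp8_e4m3` need CUDA headers, so every type below (and the `VecTraitsSketch` name) is an illustrative stand-in, not FlashInfer code.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Host stand-ins: all the traits care about is sizeof(DType).
struct Elem2B { std::uint16_t v; };      // plays the role of nv_half / nv_bfloat16
struct Elem1B { std::uint8_t v; };       // plays the role of __nv_fp8_e4m3 / e5m2
struct Int2Host { std::int32_t x, y; };  // host stand-in for CUDA's int2

template <typename DType>
struct VecTraitsSketch {
  static_assert(sizeof(DType) == 1 || sizeof(DType) == 2,
                "only 1B (FP8) or 2B (FP16/BF16) element types");
  // 2B elements: nope = 8B vector, rope = 4B vector.
  // 1B elements: nope = 4B vector, rope = 2B vector.
  using NopeVec = std::conditional_t<sizeof(DType) == 2, Int2Host, std::int32_t>;
  using RopeVec = std::conditional_t<sizeof(DType) == 2, std::int32_t, std::int16_t>;
};
```

The same `sizeof`-keyed `std::conditional_t` then drives the `if constexpr` branches for the load/store helpers, so only one primary template is needed.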
Inline comment:

In `csrc/tvm_ffi_utils.h`:

- Around lines 169-181: The macro `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8` has misaligned trailing backslashes on the `_DISPATCH_CASE_F16` / `_DISPATCH_CASE_BF16` / `_DISPATCH_CASE_FP8_*` lines; run clang-format (or manually adjust the whitespace) on `csrc/tvm_ffi_utils.h` to normalize indentation so the trailing `\` of those `_DISPATCH_CASE_*` lines aligns with the other lines in the macro block, then re-run pre-commit to ensure the file is formatted consistently.
📒 Files selected for processing (3)

- csrc/concat_mla.cu
- csrc/tvm_ffi_utils.h
- include/flashinfer/concat_mla.cuh
…fier order

- Move DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8 from concat_mla.cu to tvm_ffi_utils.h so it can be reused by other operators
- Normalize static member function qualifier order to `static __forceinline__ __device__` (consistent with codebase style)

Signed-off-by: Albert Cheng <albecheng@nvidia.com>
Force-pushed 9a95c0c to 896d3ac
Summary
Enable FP8 (E4M3/E5M2) support in `concat_mla_k`, fixing a crash that blocks FP8 chunked prefill for all MLA models (DeepSeek-V2/V3/R1) on long-context workloads.

Motivation
For long-context inference (ISL >= 4K) in vLLM, chunked prefill + FP8 quantization (`use_prefill_query_quantization: true`) is critical for reducing TTFT and improving throughput. The FP8 FMHA kernel is ~1.35x faster than BF16 at 128K context, but this path was completely unusable because `concat_mla_k` rejected FP8 inputs.

In vLLM's `_compute_prefill_context` (the chunked prefill path for MLA), K/V tensors are cast to FP8 before being passed to `flashinfer_concat_mla_k`. The kernel uses `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16`, which only dispatches BF16/FP16, causing a RuntimeError.

A vLLM-side workaround (my PR #39841, reordering the cast to after the concat) works, but it introduces an extra BF16-to-FP8 round-trip and does not address the root cause. The proper fix is enabling FP8 at the kernel level; we keep the vLLM-side change as a temporary workaround.
Changes

- `include/flashinfer/utils.cuh`: add `ld_na_global_s16` / `st_na_global_s16` for 2-byte vectorized load and store (FP8 rope = 64 elements × 1B = 2B/thread)
- `include/flashinfer/concat_mla.cuh`: add `ConcatMLAVecTraits<DType>` template for compile-time vector type selection (BF16/FP16 → int2/int, FP8 → int or short) with `if constexpr` dispatch
- `csrc/concat_mla.cu`: add `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8` macro extending dispatch to FP8 E4M3/E5M2
- `flashinfer/concat_ops.py`: update docstring to list the supported FP8 dtypes
- `tests/utils/test_concat_mla.py`: add tests covering all supported dtypes

Design
The key insight is that `concat_mla_k` is pure memory movement, so FP8 support only requires adjusting the vectorized load and store widths:

- BF16/FP16: nope loads use `int2` (8B/thread × 32 threads); rope is 64 elem × 2B = 128B, loaded as `int` (4B/thread × 32 threads)
- FP8: nope loads use `int` (4B/thread × 32 threads); rope is 64 elem × 1B = 64B, loaded as `short` (2B/thread × 32 threads)

`if constexpr` ensures we add no runtime overhead.
End-to-end on GB300 (DeepSeek-R1-0528-FP4, DP=4, chunked prefill, ISL=128K, 16 requests):
Test Plan
- `pytest tests/utils/test_concat_mla.py`: bit-exact correctness for all 4 dtypes
- End-to-end with `use_prefill_query_quantization=true`: all requests succeed