
feat: Enable FP8 (E4M3/E5M2) in concat_mla_k to optimize long-context prefill performance and refactor type dispatch for BF16/FP16 #3129

Open
qiching wants to merge 2 commits into flashinfer-ai:main from qiching:fix/concat-mla-k-fp8-support

Conversation

@qiching
Collaborator

@qiching qiching commented Apr 21, 2026

Summary

Enable FP8 (E4M3/E5M2) support in concat_mla_k, fixing a crash that blocks FP8 chunked prefill for all MLA models (DeepSeek-V2/V3/R1) on long-context workloads.

Motivation

For long-context inference (ISL >= 4K) in vLLM, chunked prefill + FP8 quantization (use_prefill_query_quantization: true) is critical for reducing TTFT and improving throughput. The FP8 FMHA kernel is ~1.35x faster than BF16 at 128K context, but this path was completely unusable because concat_mla_k rejected FP8 inputs.

In vLLM's _compute_prefill_context (the chunked prefill path for MLA), K/V tensors are cast to FP8 before being passed to flashinfer_concat_mla_k:

if use_fp8_prefill:
    kv_nope = kv_nope.to(prefill_metadata.q_data_type)  # BF16 to FP8
    k_pe = k_pe.to(prefill_metadata.q_data_type)
k_nope, v = kv_nope.split(...)
k = self._concat_k_nope_k_pe(k_nope, k_pe)  # ← crash: FP8 not supported

The kernel uses DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16, which only dispatches BF16/FP16, so FP8 inputs raise a RuntimeError.

A vLLM-side workaround (my PR #39841, which reorders the cast to after the concat) works, but it introduces an extra BF16-to-FP8 round-trip and does not address the root cause. The proper fix is enabling FP8 at the kernel level; the vLLM change is kept only as a temporary workaround.
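For illustration, the dispatch pattern that the new macro implements can be sketched on the host. This is a hypothetical stand-in, not the actual FlashInfer macro: plain fixed-width integers model the byte widths of nv_half / nv_bfloat16 (2 B) and __nv_fp8_e4m3 / __nv_fp8_e5m2 (1 B), and the DType enum stands in for DLPack dtype descriptors.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical dtype tags standing in for DLPack dtype descriptors.
enum class DType { kFloat16, kBFloat16, kFloat8E4M3, kFloat8E5M2, kFloat32 };

// Sketch of the dispatch pattern: map a runtime dtype tag onto a
// compile-time element type (modelled here by its byte width) and run the
// templated body once per case. The real macro instantiates the kernel with
// nv_half / nv_bfloat16 / __nv_fp8_e4m3 / __nv_fp8_e5m2 instead.
template <typename Fn>
int dispatch_fp16_fp8(DType dt, Fn&& fn) {
  switch (dt) {
    case DType::kFloat16:    return fn(std::uint16_t{});  // 2 B, like nv_half
    case DType::kBFloat16:   return fn(std::uint16_t{});  // 2 B, like nv_bfloat16
    case DType::kFloat8E4M3: return fn(std::uint8_t{});   // 1 B, like __nv_fp8_e4m3
    case DType::kFloat8E5M2: return fn(std::uint8_t{});   // 1 B, like __nv_fp8_e5m2
    default:
      // Mirrors the centralized failure check for unsupported dtypes.
      throw std::runtime_error("concat_mla_k: unsupported dtype");
  }
}
```

The old FP16-only macro is this switch minus the two FP8 cases, which is why FP8 inputs fell through to the error path.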

Changes

| File | Change |
|------|--------|
| include/flashinfer/utils.cuh | Add ld_na_global_s16 / st_na_global_s16 for 2-byte vectorized load and store (FP8 rope = 64 elem × 1 B ÷ 32 threads = 2 B/thread) |
| include/flashinfer/concat_mla.cuh | Add ConcatMLAVecTraits<DType> template for compile-time vector type selection (BF16/FP16 → int2/int, FP8 → int/short) with if constexpr dispatch |
| csrc/concat_mla.cu | Add DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8 macro extending dispatch to FP8 E4M3/E5M2 |
| flashinfer/concat_ops.py | Update docstring to list supported FP8 dtypes |
| tests/utils/test_concat_mla.py | Add full pytest covering BF16, FP16, FP8-E4M3, FP8-E5M2 with bit-exact correctness checks |

Design

The key insight is that concat_mla_k is pure memory movement, so FP8 support only requires adjusting the vectorized load and store widths:

  • BF16/FP16 (2 B/elem): nope 128 elem × 2 B = 256 B → int2 (8 B/thread × 32 threads); rope 64 elem × 2 B = 128 B → int (4 B/thread × 32 threads)
  • FP8 (1 B/elem): nope 128 elem × 1 B = 128 B → int (4 B/thread × 32 threads); rope 64 elem × 1 B = 64 B → short (2 B/thread × 32 threads)

if constexpr ensures that no runtime overhead is added.
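The width selection above can be sketched as a host-checkable trait. This is illustrative only, not the actual FlashInfer code: Int2Stub and plain integer types stand in for CUDA's int2, nv_half, and __nv_fp8_e4m3, since the widths follow purely from sizeof(DType).

```cpp
#include <cstdint>
#include <type_traits>

struct Int2Stub { int x, y; };  // 8-byte stand-in for CUDA's int2

// Per-dtype vector traits: because the kernel is pure memory movement
// (128 nope + 64 rope elements per head, 32 threads per warp), the vector
// type is fully determined by the element size.
template <typename DType>
struct ConcatMLAVecTraitsSketch {
  static_assert(sizeof(DType) == 1 || sizeof(DType) == 2,
                "only 1 B (FP8) and 2 B (FP16/BF16) element types");
  // nope: 128 elem / 32 threads = 4 elem/thread -> 8 B (2 B/elem) or 4 B (1 B/elem)
  using NopeVec = std::conditional_t<sizeof(DType) == 2, Int2Stub, int>;
  // rope: 64 elem / 32 threads = 2 elem/thread -> 4 B (2 B/elem) or 2 B (1 B/elem)
  using RopeVec = std::conditional_t<sizeof(DType) == 2, int, short>;

  static constexpr int nope_bytes_per_thread = sizeof(NopeVec);
  static constexpr int rope_bytes_per_thread = sizeof(RopeVec);
};
```

With 2-byte stand-ins the trait yields the int2/int pair from the first bullet, and with 1-byte stand-ins the int/short pair from the second.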

Benchmark Results

End-to-end on GB300 (DeepSeek-R1-0528-FP4, DP=4, chunked prefill, ISL=128K, 16 requests):

| Metric | BF16 (baseline) | FP8 (this fix) | Delta |
|--------|-----------------|----------------|-------|
| Median TTFT | 42.0 s | 30.1 s | -28.3% |
| Mean TTFT | 41.7 s | 30.5 s | -27.0% |
| P99 TTFT | 43.8 s | 33.5 s | -23.5% |
| Token throughput | 12,069 tok/s | 16,524 tok/s | +37.0% |

Test Plan

  • Unit test: pytest tests/utils/test_concat_mla.py, bit-exact correctness for all four dtypes
  • E2E crash check: ISL=128K with use_prefill_query_quantization=true, all succeed
  • Performance: FP8 prefill -27% Median TTFT vs BF16 at long context
  • No regression: BF16 baseline all succeed with identical perf to stock flashinfer

Summary by CodeRabbit

  • New Features
    • Added support for two FP8 formats alongside FP16 and BF16 in the concat operation.
  • Documentation
    • Updated docs to list supported dtypes and clarify compile-time dtype dispatch semantics.
  • Refactor
    • Generalized vector and memory access handling to uniformly support additional low-precision dtypes.
  • Tests
    • Added comprehensive tests for BF16/FP16/FP8 correctness, edge cases, strided inputs, and dtype-mismatch checks.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 21, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 73b3523c-883d-4b2a-b32d-30c027337e90

📥 Commits

Reviewing files that changed from the base of the PR and between 9a95c0c and 896d3ac.

📒 Files selected for processing (3)
  • csrc/concat_mla.cu
  • csrc/tvm_ffi_utils.h
  • include/flashinfer/concat_mla.cuh
🚧 Files skipped from review as they are similar to previous changes (2)
  • csrc/concat_mla.cu
  • include/flashinfer/concat_mla.cuh

📝 Walkthrough

Walkthrough

Adds FP8 (E4M3, E5M2) plus BF16 support to concat_mla: introduces dtype-specific vector traits and 16-bit non-atomic PTX helpers, replaces the FP16-only dispatch with a new FP16+FP8 dispatch macro, updates docstrings, and adds tests covering BF16/FP16/FP8 and edge cases.

Changes

Cohort / File(s) Summary
Dispatch & Kernel
csrc/concat_mla.cu, include/flashinfer/concat_mla.cuh
Switched dtype dispatch to DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8; added ConcatMLAVecTraits<DType> specializations for nv_half, nv_bfloat16, __nv_fp8_e4m3, __nv_fp8_e5m2; refactored kernel to use trait-driven vector types and load_*/store_* helpers.
CUDA utilities
include/flashinfer/utils.cuh
Added ld_na_global_s16 and st_na_global_s16 device helpers implementing 16-bit non-atomic global loads/stores with cache-streaming PTX variants.
TVM FFI dispatch macro
csrc/tvm_ffi_utils.h
Introduced DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8 macro dispatching F16, BF16, FP8 E4M3, FP8 E5M2 and centralized default failure check.
Python docs & API usage
flashinfer/concat_ops.py
Updated concat_mla_k docstring to list supported dtypes (bfloat16, float16, fp8_e4m3, fp8_e5m2) and adjusted optimization description to reflect compile-time dtype dispatch.
Tests
tests/utils/test_concat_mla.py
Added tests validating correctness across BF16/FP16/FP8 (with SM guard for FP8), zero-token case, strided/non-contiguous inputs, and dtype-mismatch error behavior.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested labels

run-ci, op: attention

Suggested reviewers

  • yzh119
  • aleozlx
  • nvmbreughe
  • jimmyzho
  • kahyunnam
  • djmmoss

Poem

🐰
Tiny formats hop, a kernel's cheer,
Traits line up so data's clear.
BF16, half, and FP8 two,
I stitch the tokens—bits slip through.
Hop, load, store—concat says woo!

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|------------|--------|-------------|
| Title check | ✅ Passed | The PR title clearly describes the main changes: enabling FP8 support in concat_mla_k and refactoring type dispatch, which aligns with the substantial code changes across multiple files and the primary objective of fixing FP8 support. |
| Description check | ✅ Passed | The description comprehensively covers motivation, design rationale, changes, benchmark results, and test plan. While the template expects pre-commit and testing checklist items, the author provides detailed technical content that effectively communicates the PR's purpose and validates the implementation. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for FP8 data types to the concat_mla_k kernel by introducing a ConcatMLAVecTraits template for compile-time dispatch of vector types and memory instructions. It also includes new 16-bit global load/store utilities and comprehensive unit tests for correctness across various dtypes and input configurations. Feedback suggests refactoring the kernel's dispatch logic into the traits structure to improve maintainability and reduce code repetition.

Comment thread include/flashinfer/concat_mla.cuh Outdated
concat_mla_k previously only supported BF16/FP16, causing crashes when
vLLM's chunked prefill path passes FP8 quantized K tensors.
Changes:
- utils.cuh: add ld_na_global_s16/st_na_global_s16 for 2-byte vectorized access
- concat_mla.cuh: add ConcatMLAVecTraits template for compile-time type dispatch
  (BF16/FP16 -> int2/int, FP8 -> int/short), using if constexpr for zero overhead
- concat_mla.cu: add DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8 macro for FP8 dispatch
- concat_ops.py: update docstring to list FP8 dtypes
- test_concat_mla.py: add comprehensive pytest for all supported dtypes
Tested end-to-end on GB300 with DeepSeek-R1-0528-FP4 (ISL=128K, 16 prompts):
- Without fix: RuntimeError "k and k_rope must have the same dtype"
- With fix: 16/16 success, Median TTFT -27% vs BF16 baseline
Signed-off-by: Albert Cheng <albecheng@nvidia.com>
@qiching qiching force-pushed the fix/concat-mla-k-fp8-support branch from 3bafeb4 to 4d93ec2 Compare April 21, 2026 03:54
@qiching qiching marked this pull request as ready for review April 21, 2026 04:17
@yzh119
Collaborator

yzh119 commented Apr 21, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !576 has been created, and the CI pipeline #49059087 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
csrc/concat_mla.cu (1)

101-105: ⚠️ Potential issue | 🟠 Major

Reject strides the vector kernel cannot represent.

The binding only checks last-dim contiguity, but the kernel later turns element strides into vector strides with >> 2 / >> 1. Views with last-dim stride 1 but head/token strides like 129 pass validation and then get silently truncated, reading/writing the wrong rows. Add launcher checks for the vector-stride assumptions before dispatch.

Suggested validation before dispatch
   int64_t k_stride_0 = k.stride(0);
   int k_stride_1 = k.stride(1);
   int64_t k_nope_stride_0 = k_nope.stride(0);
   int k_nope_stride_1 = k_nope.stride(1);
   int64_t k_rope_stride_0 = k_rope.stride(0);
+
+  // ConcatMLAKKernel reinterprets nope rows as 4-element vectors and rope rows
+  // as 2-element vectors. Reject strided views that would truncate in vector
+  // stride arithmetic.
+  TVM_FFI_ICHECK_EQ(k_stride_0 % 4, 0) << "k token stride must be divisible by 4";
+  TVM_FFI_ICHECK_EQ(k_stride_1 % 4, 0) << "k head stride must be divisible by 4";
+  TVM_FFI_ICHECK_EQ(k_nope_stride_0 % 4, 0)
+      << "k_nope token stride must be divisible by 4";
+  TVM_FFI_ICHECK_EQ(k_nope_stride_1 % 4, 0)
+      << "k_nope head stride must be divisible by 4";
+  TVM_FFI_ICHECK_EQ(k_rope_stride_0 % 2, 0)
+      << "k_rope token stride must be divisible by 2";
 
   ffi::CUDADeviceGuard device_guard(k.device().device_id);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/concat_mla.cu` around lines 101 - 105, The launcher currently only
checks last-dim contiguity but must also reject inputs whose element strides
cannot be represented by the kernel's vectorized strides; before calling
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8 and launching ConcatMLAK, add validation
that k_stride_1, k_nope_stride_1, and k_rope_stride_1 are divisible by the
vector lane width (the shifts used in the kernel, e.g. >>2 or >>1) and that the
resulting vector strides (k_stride_1 >> N, etc.) fit into the kernel's expected
integer range; if any fail, return/error out with a clear message instead of
dispatching. Ensure you reference the same stride variables (k_stride_0,
k_stride_1, k_nope_stride_0, k_nope_stride_1, k_rope_stride_0, k_rope_stride_1)
and perform these checks immediately before invoking
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8/ConcatMLAK.
include/flashinfer/concat_mla.cuh (1)

193-210: ⚠️ Potential issue | 🟠 Major

Guard the final iteration to avoid reading uninitialized next.

On the last unrolled iteration (when i = HEAD_CHUNK_SIZE - 1), the condition i + 1 < HEAD_CHUNK_SIZE is false, so next is never assigned but cur = next still reads it. This is undefined behavior. Add the guard to prevent the read:

Suggested fix
     nope_src += nope_src_stride_v;
     nope_dst += nope_dst_stride_v;
     rope_dst += rope_dst_stride_v;
 
-    cur = next;
+    if (i + 1 < HEAD_CHUNK_SIZE) {
+      cur = next;
+    }
   }
 }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/concat_mla.cuh` around lines 193 - 210, The loop assigns
cur = next at the end of every iteration but next is not set on the final
unrolled iteration, causing undefined reads; fix by guarding that assignment or
initializing next: either initialize NopeVec next = cur; before the loop or
change the tail behavior so cur = next only happens when (i + 1 <
HEAD_CHUNK_SIZE). Update the loop in concat_mla.cuh (the software-pipelined
for-loop using HEAD_CHUNK_SIZE, next, cur, nope_src/nope_src_stride_v and
Traits::load_nope/Traits::store_nope) to ensure next is always valid before
assigning it to cur.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@csrc/concat_mla.cu`:
- Around line 101-105: The launcher currently only checks last-dim contiguity
but must also reject inputs whose element strides cannot be represented by the
kernel's vectorized strides; before calling
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8 and launching ConcatMLAK, add validation
that k_stride_1, k_nope_stride_1, and k_rope_stride_1 are divisible by the
vector lane width (the shifts used in the kernel, e.g. >>2 or >>1) and that the
resulting vector strides (k_stride_1 >> N, etc.) fit into the kernel's expected
integer range; if any fail, return/error out with a clear message instead of
dispatching. Ensure you reference the same stride variables (k_stride_0,
k_stride_1, k_nope_stride_0, k_nope_stride_1, k_rope_stride_0, k_rope_stride_1)
and perform these checks immediately before invoking
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8/ConcatMLAK.

In `@include/flashinfer/concat_mla.cuh`:
- Around line 193-210: The loop assigns cur = next at the end of every iteration
but next is not set on the final unrolled iteration, causing undefined reads;
fix by guarding that assignment or initializing next: either initialize NopeVec
next = cur; before the loop or change the tail behavior so cur = next only
happens when (i + 1 < HEAD_CHUNK_SIZE). Update the loop in concat_mla.cuh (the
software-pipelined for-loop using HEAD_CHUNK_SIZE, next, cur,
nope_src/nope_src_stride_v and Traits::load_nope/Traits::store_nope) to ensure
next is always valid before assigning it to cur.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5fc9a351-e174-4438-b30a-5a6a9c55104c

📥 Commits

Reviewing files that changed from the base of the PR and between 8a9970b and 4d93ec2.

📒 Files selected for processing (5)
  • csrc/concat_mla.cu
  • flashinfer/concat_ops.py
  • include/flashinfer/concat_mla.cuh
  • include/flashinfer/utils.cuh
  • tests/utils/test_concat_mla.py

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
include/flashinfer/concat_mla.cuh (1)

46-120: Consider collapsing the four traits specializations into two (or one if constexpr block).

The nv_half and nv_bfloat16 specializations are byte-for-byte identical, and the __nv_fp8_e4m3 / __nv_fp8_e5m2 specializations are likewise identical. Since the traits only depend on sizeof(DType), one generic primary template (or a shared base keyed on element size) would remove ~40 lines of duplication and matches the "if constexpr for compile-time vector type selection" design mentioned in the PR description. Something along these lines:

♻️ Example consolidation
template <typename DType>
struct ConcatMLAVecTraits {
  static_assert(sizeof(DType) == 1 || sizeof(DType) == 2,
                "ConcatMLAVecTraits only supports 1B (FP8) or 2B (FP16/BF16) DTypes");
  using NopeVec = std::conditional_t<sizeof(DType) == 2, int2, int>;
  using RopeVec = std::conditional_t<sizeof(DType) == 2, int, short>;

  static __forceinline__ __device__ NopeVec load_nope(const NopeVec* ptr) {
    if constexpr (sizeof(DType) == 2) return ld_na_global_v2(ptr);
    else                              return ld_na_global_v1(ptr);
  }
  static __forceinline__ __device__ RopeVec load_rope(const RopeVec* ptr) {
    if constexpr (sizeof(DType) == 2) return ld_na_global_v1(ptr);
    else                              return ld_na_global_s16(ptr);
  }
  static __forceinline__ __device__ void store_nope(NopeVec* ptr, NopeVec v) {
    if constexpr (sizeof(DType) == 2) st_na_global_v2(ptr, v);
    else                              st_na_global_v1(ptr, v);
  }
  static __forceinline__ __device__ void store_rope(RopeVec* ptr, RopeVec v) {
    if constexpr (sizeof(DType) == 2) st_na_global_v1(ptr, v);
    else                              st_na_global_s16(ptr, v);
  }
};

Not blocking — purely a DRY/maintainability win; the static asserts in ConcatMLAKKernel already guard against mismatches.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/concat_mla.cuh` around lines 46 - 120, The four
nearly-identical specializations of ConcatMLAVecTraits (for nv_half,
nv_bfloat16, __nv_fp8_e4m3, __nv_fp8_e5m2) should be collapsed into a single
templated ConcatMLAVecTraits<DType> that selects NopeVec/RopeVec and the correct
ld_/st_ helpers via sizeof(DType) and if constexpr; update the type aliases
(NopeVec, RopeVec) and the four methods (load_nope, load_rope, store_nope,
store_rope) to use std::conditional_t and if constexpr to call ld_na_global_v2
vs ld_na_global_v1 and ld_na_global_s16 accordingly, and add a
static_assert(sizeof(DType)==1 || sizeof(DType)==2) to preserve the existing
guard used by ConcatMLAKernel.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/tvm_ffi_utils.h`:
- Around line 169-181: The macro DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8 has
misaligned trailing backslashes on the
_DISPATCH_CASE_F16/_DISPATCH_CASE_BF16/_DISPATCH_CASE_FP8_* lines; run
clang-format (or manually adjust whitespace) on csrc/tvm_ffi_utils.h to
normalize indentation so the trailing '\' of those _DISPATCH_CASE_* lines aligns
with the other lines in the macro block, then re-run pre-commit to ensure the
file is formatted consistently.

---

Nitpick comments:
In `@include/flashinfer/concat_mla.cuh`:
- Around line 46-120: The four nearly-identical specializations of
ConcatMLAVecTraits (for nv_half, nv_bfloat16, __nv_fp8_e4m3, __nv_fp8_e5m2)
should be collapsed into a single templated ConcatMLAVecTraits<DType> that
selects NopeVec/RopeVec and the correct ld_/st_ helpers via sizeof(DType) and if
constexpr; update the type aliases (NopeVec, RopeVec) and the four methods
(load_nope, load_rope, store_nope, store_rope) to use std::conditional_t and if
constexpr to call ld_na_global_v2 vs ld_na_global_v1 and ld_na_global_s16
accordingly, and add a static_assert(sizeof(DType)==1 || sizeof(DType)==2) to
preserve the existing guard used by ConcatMLAKernel.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1fb03bda-3139-4553-9c9a-f8e0644e7e05

📥 Commits

Reviewing files that changed from the base of the PR and between 4d93ec2 and 9a95c0c.

📒 Files selected for processing (3)
  • csrc/concat_mla.cu
  • csrc/tvm_ffi_utils.h
  • include/flashinfer/concat_mla.cuh

Comment thread csrc/tvm_ffi_utils.h
…fier order

- Move DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8 from concat_mla.cu to
  tvm_ffi_utils.h so it can be reused by other operators
- Normalize static member function qualifier order to
  `static __forceinline__ __device__` (consistent with codebase style)
Signed-off-by: Albert Cheng <albecheng@nvidia.com>
@qiching qiching force-pushed the fix/concat-mla-k-fp8-support branch from 9a95c0c to 896d3ac Compare April 21, 2026 17:37