83 changes: 80 additions & 3 deletions benchmarks/routines/attention.py
@@ -22,6 +22,7 @@
if not is_lib_missing:
raise
from flashinfer.fp4_quantization import nvfp4_quantize_paged_kv_cache
from flashinfer.prefill import trtllm_fmha_v2_prefill
from flashinfer.testing.utils import (
attention_tb_per_sec_with_actual_seq_lens,
attention_tflops_per_sec_with_actual_seq_lens,
@@ -111,6 +112,7 @@ def parse_attention_args(line, parser):
"cutlass",
"trtllm-gen",
"trtllm-native",
"trtllm-fmha-v2",
"trtllm-gen-native", # Deprecated, will be removed in future
"cute-dsl",
],
@@ -936,6 +938,9 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
remove_trtllm_native = True
if remove_trtllm_native:
backends.remove("trtllm-native")
if "trtllm-fmha-v2" in backends and is_nvfp4_kv:
print("[INFO] trtllm-fmha-v2 backend does not support NVFP4. Skipping.")
backends.remove("trtllm-fmha-v2")

if "cutlass" in backends:
print("[INFO] CUTLASS backend does not support prefill. Skipping.")
@@ -1072,7 +1077,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
.to(device)
)

# Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
# Page-based indptr for FlashInfer paged attention (cumulative page counts)
kv_indptr = (
torch.cat(
[
@@ -1086,6 +1091,17 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
.int()
.to(device)
)
# Token-based indptr for TRT-LLM backends (cumulative token counts)
kv_token_indptr = (
torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(actual_seq_lens_kv_device.flatten(), dim=0),
]
)
.int()
.to(device)
)
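For intuition, a minimal sketch (toy numbers, not part of the diff) of the two indptr conventions used above, assuming page_size = 16 and two sequences with 5 and 40 KV tokens; the FlashInfer paged wrapper consumes the page-based indptr, while trtllm-fmha-v2 takes the token-based one via cum_seq_lens_kv:

    import torch

    page_size = 16
    seq_lens = torch.tensor([5, 40])
    # Page-based: one entry per occupied page per sequence.
    pages_per_seq = (seq_lens + page_size - 1) // page_size                                  # [1, 3]
    page_indptr = torch.cat([torch.zeros(1, dtype=torch.long), pages_per_seq.cumsum(0)])     # [0, 1, 4]
    # Token-based: cumulative token counts per sequence.
    token_indptr = torch.cat([torch.zeros(1, dtype=torch.long), seq_lens.cumsum(0)])         # [0, 5, 45]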
kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32)
for i in range(len(kv_indptr) - 1):
start_idx = kv_indptr[i]
@@ -1158,6 +1174,16 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
v_quantized, _ = to_float8(v_data, kv_dtype)
kv_cache = torch.cat([k_quantized, v_quantized], dim=1)

_fmha_v2_bmm2_scale = v_scale if v_scale is not None else 1.0

# Ensure trtllm-fmha-v2 sees contiguous HND-physical paged KV cache.
# Skip if kv_cache is not a plain Tensor (e.g., NVFP4 packed tuple).
# backend filter further down also drops trtllm-fmha-v2 in that case.
if "trtllm-fmha-v2" in backends and isinstance(kv_cache, torch.Tensor):
_fmha_v2_kv_cache = kv_cache.contiguous()
else:
_fmha_v2_kv_cache = kv_cache

# Prepare wrappers (after FP8 conversion so we have correct dtypes)
backend_wrappers = {}
resolved_backends = {}
@@ -1304,6 +1330,25 @@ def run_backend_wrapper(
v_scale=v_scale_tensor,
o_data_type=o_data_type,
)[0]
elif backend == "trtllm-fmha-v2":
_q_scale = q_scale if q_scale is not None else 1.0
_k_scale = k_scale if k_scale is not None else 1.0
return trtllm_fmha_v2_prefill(
qkv=(q, _fmha_v2_kv_cache),
input_layout="Q_PAGED_KV_HND",
workspace_buffer=workspace_buffer,
seq_lens=actual_seq_lens_kv_device.flatten(),
max_q_len=s_qo,
max_kv_len=s_kv,
bmm1_scale=_q_scale * _k_scale * scale,
bmm2_scale=_fmha_v2_bmm2_scale,
batch_size=batch_size,
cum_seq_lens_q=qo_indptr,
cum_seq_lens_kv=kv_token_indptr,
block_tables=block_tables,
mask_mode="causal" if causal else "padding",
out_dtype=o_data_type,
)
else:
print(f"[ERROR] Backend {backend} not supported")
return None
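A quick numeric check of the scale folding used in the trtllm-fmha-v2 branch above, assuming the usual per-tensor FP8 convention x ≈ x_fp8 * scale and that `scale` is the 1/sqrt(head_dim) softmax scale; the to_float8 helper below is a simplified stand-in, not the one defined in this file:

    import torch

    def to_float8(x, dtype=torch.float8_e4m3fn):
        # Simplified per-tensor quantization: x ≈ x_fp8 * s
        finfo = torch.finfo(dtype)
        s = x.abs().max() / finfo.max
        return (x / s).clamp(finfo.min, finfo.max).to(dtype), s

    d = 64
    q, k = torch.randn(8, d), torch.randn(8, d)
    q8, q_scale = to_float8(q)
    k8, k_scale = to_float8(k)
    softmax_scale = d ** -0.5

    ref = (q @ k.T) * softmax_scale
    # bmm1_scale = q_scale * k_scale * softmax_scale dequantizes QK^T in one factor.
    folded = (q8.float() @ k8.float().T) * (q_scale * k_scale * softmax_scale)
    print((ref - folded).abs().max())   # small FP8 rounding error
    # bmm2_scale = v_scale plays the same role for the P @ V output.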
@@ -1366,9 +1411,15 @@ def run_backend_wrapper(
tested_outputs = list(outputs.values())

# In cases where FA2 is not available, try to find an alternative reference
# Priority: cudnn > cudnn-native > trtllm-gen > trtllm-native
# Priority: cudnn > cudnn-native > trtllm-gen > trtllm-native > trtllm-fmha-v2
if run_refcheck and not has_reference_output and len(tested_backends) > 1:
reference_priority = ["cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"]
reference_priority = [
"cudnn",
"cudnn-native",
"trtllm-gen",
"trtllm-native",
"trtllm-fmha-v2",
]
for candidate in reference_priority:
if candidate in tested_backends:
has_reference_output = True
@@ -1598,6 +1649,12 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
remove_trtllm_native = True
if remove_trtllm_native:
backends.remove("trtllm-native")
if "trtllm-fmha-v2" in backends and q_dtype == torch.float8_e4m3fn:
print(
"[INFO] trtllm-fmha-v2 backend does not support FP8 e4m3 with "
"SEPARATE_Q_K_V layout. Skipping."
)
backends.remove("trtllm-fmha-v2")

if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
@@ -1836,6 +1893,8 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
k = (k / k_scale).to(kv_dtype)
v = (v / v_scale).to(kv_dtype)

_fmha_v2_bmm2_scale = v_scale if v_scale is not None else 1.0

trtllm_out = None
if "trtllm-native" in backends or "cute-dsl" in backends:
# cute-dsl varlen kernel uses negative pointer offsets on output,
@@ -1944,6 +2003,24 @@ def run_backend_wrapper(
return_lse=True,
out=trtllm_out,
)[0]
elif backend == "trtllm-fmha-v2":
_q_scale = q_scale if q_scale is not None else 1.0
_k_scale = k_scale if k_scale is not None else 1.0
return trtllm_fmha_v2_prefill(
qkv=(q, k, v),
input_layout="SEPARATE_Q_K_V",
workspace_buffer=workspace_buffer,
seq_lens=actual_seq_lens_kv_device.flatten(),
max_q_len=s_qo,
max_kv_len=s_kv,
bmm1_scale=_q_scale * _k_scale * scale,
bmm2_scale=_fmha_v2_bmm2_scale,
batch_size=batch_size,
cum_seq_lens_q=qo_indptr,
cum_seq_lens_kv=kv_indptr,
mask_mode="causal" if causal else "padding",
out_dtype=out_dtype,
)
Comment on lines +2006 to +2023 (Contributor):

⚠️ Potential issue | 🟠 Major

Filter out FP8 queries before the ragged FMHA-v2 call.

trtllm_fmha_v2_prefill(input_layout="SEPARATE_Q_K_V") explicitly rejects torch.float8_e4m3fn queries, but this routine still allows that configuration to reach the new branch. That turns a benchmark flag combination into a runtime ValueError instead of a clean backend skip.

πŸ›‘οΈ Suggested guard
     if "trtllm-native" in backends:
         remove_trtllm_native = False
         if not (head_dim_qk == 192 and head_dim_vo == 128) and not (
             head_dim_qk == 128 and head_dim_vo == 128
         ):
             print(
                 "[INFO] trtllm-native backend requires head_dim_qk == 192 and head_dim_vo == 128 or head_dim_qk == 128 and head_dim_vo == 128. Skipping."
             )
             remove_trtllm_native = True
         if remove_trtllm_native:
             backends.remove("trtllm-native")
+    if "trtllm-fmha-v2" in backends and q_dtype == torch.float8_e4m3fn:
+        print(
+            "[INFO] trtllm-fmha-v2 does not support FP8 query with SEPARATE_Q_K_V. Skipping."
+        )
+        backends.remove("trtllm-fmha-v2")
πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1947 - 1964, The
trtllm-fmha-v2 branch must filter out FP8 query tensors before calling
trtllm_fmha_v2_prefill because that function rejects torch.float8_e4m3fn and
currently causes a runtime ValueError; add a guard at the start of the backend
== "trtllm-fmha-v2" block that checks q.dtype (and/or k.dtype/v.dtype as
appropriate) for torch.float8_e4m3fn and short-circuits the branch (e.g., return
None or otherwise skip this backend) so the benchmark treats this combination as
an unsupported backend instead of letting trtllm_fmha_v2_prefill raise. Ensure
you reference the existing variables q, k, v and the call to
trtllm_fmha_v2_prefill when inserting the guard.

else:
print(f"[ERROR] Backend {backend} not supported")
return None
10 changes: 6 additions & 4 deletions benchmarks/routines/flashinfer_benchmark_utils.py
@@ -318,25 +318,27 @@ def dtype_str_to_torch_dtype(dtype_str):
},
"BatchPrefillWithPagedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_batch_context_with_kv_cache
# NOTE: trtllm-fmha-v2 calls trtllm_fmha_v2_prefill
# NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache
"7.5": [],
"8.0": ["fa2", "auto", "cudnn", "cudnn-native"],
"8.6": ["fa2", "auto", "cudnn", "cudnn-native"],
"8.9": ["fa2", "auto", "cudnn", "cudnn-native"],
"9.0": ["fa2", "fa3", "auto", "cudnn", "cudnn-native"],
"9.0": ["fa2", "fa3", "auto", "cudnn", "cudnn-native", "trtllm-fmha-v2"],
"10.0": ["fa2", "auto", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
"10.3": ["fa2", "auto", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
"12.0": ["fa2", "auto", "cudnn", "cudnn-native"],
"12.0": ["fa2", "auto", "cudnn", "cudnn-native", "trtllm-fmha-v2"],
"12.1": ["fa2", "auto", "cudnn", "cudnn-native"],
},
"BatchPrefillWithRaggedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_ragged_attention_deepseek
# NOTE: trtllm-fmha-v2 calls trtllm_fmha_v2_prefill
# NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache
"7.5": [],
"8.0": ["fa2", "cudnn", "cudnn-native"],
"8.6": ["fa2", "cudnn", "cudnn-native"],
"8.9": ["fa2", "cudnn", "cudnn-native"],
"9.0": ["fa2", "fa3", "cudnn", "cudnn-native"],
"9.0": ["fa2", "fa3", "cudnn", "cudnn-native", "trtllm-fmha-v2"],
"10.0": [
"fa2",
"cudnn",
@@ -353,7 +355,7 @@ def dtype_str_to_torch_dtype(dtype_str):
"cute-dsl",
"trtllm-native",
],
"12.0": ["fa2", "cudnn", "cudnn-native"],
"12.0": ["fa2", "cudnn", "cudnn-native", "trtllm-fmha-v2"],
"12.1": ["fa2", "cudnn", "cudnn-native"],
},
"BatchMLAPagedAttentionWrapper": {
34 changes: 26 additions & 8 deletions csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h
@@ -863,11 +863,21 @@ struct Gmem_tile_paged_kv {
paged_kv_log2_block_size_(params.paged_kv_cache.mTokensPerBlockLog2),
paged_kv_block_pool_ptr_(reinterpret_cast<char*>(params.paged_kv_cache.mPoolPtr)),
paged_kv_global_block_offsets_(params.paged_kv_cache.mBlockOffsets),
params_kv_block_size_in_bytes_(params.paged_kv_cache.mBytesPerBlock) {
// Handle Paged KV with shape [S, Dh], by offsetting it to the target batch.
int32_t const paged_kv_block_offset =
(binfo.bidb * 2 + qkv_offset - 1) * params.paged_kv_cache.mMaxBlocksPerSeq;
paged_kv_global_block_offsets_ += paged_kv_block_offset;
params_kv_block_size_in_bytes_(params.paged_kv_cache.mBytesPerBlock),
uses_shared_paged_kv_idx_(params.paged_kv_cache.mUsesSharedPagedKvIdx),
kv_type_(qkv_offset - 1) {
// Handle Paged KV by offsetting the block offsets pointer to the target batch.
if (uses_shared_paged_kv_idx_) {
// Shared [B, M] layout: one set of page indices for both K and V.
// The kv_type_ (0=K, 1=V) is applied at load time via page_idx*2+kv_type_.
int32_t const paged_kv_block_offset = binfo.bidb * params.paged_kv_cache.mMaxBlocksPerSeq;
paged_kv_global_block_offsets_ += paged_kv_block_offset;
} else {
// Pre-expanded [B, 2, M] layout: separate K and V offset arrays.
int32_t const paged_kv_block_offset =
(binfo.bidb * 2 + qkv_offset - 1) * params.paged_kv_cache.mMaxBlocksPerSeq;
paged_kv_global_block_offsets_ += paged_kv_block_offset;
}

// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
@@ -915,9 +925,13 @@ struct Gmem_tile_paged_kv {
for (int ii = 0; ii < LDGS; ++ii) {
int row_idx = row_ + ii * (int)ROWS_PER_LDG;
int paged_kv_block_idx = (row_idx >> paged_kv_log2_block_size_);
char const* local_kv_ptr = reinterpret_cast<char*>(
paged_kv_block_pool_ptr_ +
params_kv_block_size_in_bytes_ * paged_kv_global_block_offsets_[paged_kv_block_idx]);
int32_t pool_idx = paged_kv_global_block_offsets_[paged_kv_block_idx];
// Shared page index: transform logical page β†’ interleaved K/V pool offset.
if (uses_shared_paged_kv_idx_) {
pool_idx = pool_idx * 2 + kv_type_;
}
char const* local_kv_ptr = reinterpret_cast<char*>(paged_kv_block_pool_ptr_ +
params_kv_block_size_in_bytes_ * pool_idx);

// Predicates.
// TODO: do we need to make sure row_idx < ROWS ?
@@ -958,6 +972,10 @@ struct Gmem_tile_paged_kv {
int32_t* paged_kv_global_block_offsets_;
// The paged block size.
int paged_kv_log2_block_size_;
// Whether block offsets use shared [B,M] format (true) vs pre-expanded [B,2,M] (false).
bool uses_shared_paged_kv_idx_;
// 0 for K, 1 for V. Used with shared page indices to compute pool offset = page_idx*2+kv_type_.
int kv_type_;
// The register to store predicates.
uint32_t preds_[PRED_REGS];
// The fetch registers.
21 changes: 16 additions & 5 deletions csrc/fmha_v2/fmha/paged_kv_cache.h
@@ -31,17 +31,27 @@ struct Kv_block_array {
// E.g. for mTokensPerBlock 64, mTokensPerBlockLog2 equals to 6
int32_t mTokensPerBlockLog2;
// Table maps logical block idx to the data pointer of k/v cache block pool
// Shape [B, W, 2, M], where 2 is table for K and V,
// B is current number of sequences
// W is beam width
// M is Max number of blocks per sequence
//
// When mUsesSharedPagedKvIdx is false (default, TRT-LLM native format):
// Shape [B, W, 2, M], where 2 is table for K and V,
// B is current number of sequences, W is beam width,
// M is max number of blocks per sequence.
//
// When mUsesSharedPagedKvIdx is true (FlashInfer interleaved KV pool format):
// Shape [B, M] containing logical page indices.
// K and V share the same index; the kernel computes pool offsets on-the-fly as:
// K pool offset = page_idx * 2
// V pool offset = page_idx * 2 + 1

// Size of KV cache blocks in bytes (H*D*T*sizeof(DataType))
int32_t mBytesPerBlock;
// Pointer to beginning of pool.
void* mPoolPtr;
// Pointer to block offsets.
PtrType* mBlockOffsets;
// When true, mBlockOffsets is [B, M] with shared K/V page indices that need
// the *2/+1 transform, instead of pre-expanded [B, 2, M] separate offsets.
bool mUsesSharedPagedKvIdx;

Kv_block_array() = default;

@@ -52,7 +62,8 @@ struct Kv_block_array {
mTokensPerBlock(tokensPerBlock),
mBytesPerBlock{bytesPerBlock},
mPoolPtr{poolPtr},
mBlockOffsets{nullptr} {
mBlockOffsets{nullptr},
mUsesSharedPagedKvIdx{false} {
float const tokensPerBlockSeqLog2 = log2(mTokensPerBlock);
mTokensPerBlockLog2 = static_cast<int>(tokensPerBlockSeqLog2);
}
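To make the two block-offset layouts concrete, here is a small host-side sketch (toy values, not taken from the diff) of how a token row maps to K/V blocks in an interleaved pool when mUsesSharedPagedKvIdx is set, and what the equivalent pre-expanded table would hold for that same pool:

    import torch

    tokens_per_block = 16
    tokens_per_block_log2 = tokens_per_block.bit_length() - 1          # mTokensPerBlockLog2

    # Shared [B, M] layout: logical page ids, one table for both K and V.
    block_tables = torch.tensor([[5, 9, 0],
                                 [2, 7, 1]], dtype=torch.int32)        # B=2, M=3

    b, token_row = 0, 37                                               # sequence 0, token 37
    page_in_seq = token_row >> tokens_per_block_log2                   # row_idx >> log2_block_size -> 2
    page_idx = int(block_tables[b, page_in_seq])                       # logical page id 0
    k_block, v_block = page_idx * 2, page_idx * 2 + 1                  # interleaved pool offsets

    # For this interleaved pool, a pre-expanded [B, 2, M] table would simply store
    # the transformed offsets as separate K and V rows.
    expanded = torch.stack([block_tables * 2, block_tables * 2 + 1], dim=1)   # shape [2, 2, 3]
    assert int(expanded[b, 0, page_in_seq]) == k_block
    assert int(expanded[b, 1, page_in_seq]) == v_block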
31 changes: 21 additions & 10 deletions csrc/fmha_v2/fmha/warpspec/dma.h
@@ -364,8 +364,10 @@ struct DMA {
cudaTmaDesc const* desc_k = &params.tma_desc_k;
cudaTmaDesc const* desc_v = &params.tma_desc_v;

bool const shared_kv_idx = params.paged_kv_cache.mUsesSharedPagedKvIdx;
int32_t const* paged_block_offsets =
params.paged_kv_cache.mBlockOffsets + bidb * 2 * params.paged_kv_cache.mMaxBlocksPerSeq;
params.paged_kv_cache.mBlockOffsets +
bidb * (shared_kv_idx ? 1 : 2) * params.paged_kv_cache.mMaxBlocksPerSeq;

if (SCHEDULING_MODE == 0) {
// split work across M
@@ -416,11 +418,12 @@ struct DMA {
int bar_id;
// Load paged kv input.
if constexpr (PAGED_KV_INPUT) {
bar_id = load_paged_kv(bidh_kv, kv_step_idx * STEP_KV, num_valid_kv_blocks,
params.paged_kv_cache.mTokensPerBlockLog2,
params.blocks_per_tma_load, params.blocks_per_tma_load_log2,
params.paged_kv_cache.mMaxBlocksPerSeq, paged_block_offsets,
desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
bar_id =
load_paged_kv(bidh_kv, kv_step_idx * STEP_KV, num_valid_kv_blocks,
params.paged_kv_cache.mTokensPerBlockLog2,
params.blocks_per_tma_load, params.blocks_per_tma_load_log2,
params.paged_kv_cache.mMaxBlocksPerSeq, paged_block_offsets,
shared_kv_idx, desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
} else {
bar_id = load_kv(bidh_kv, kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k, cbw_v,
cbw_v_scratch);
@@ -545,7 +548,7 @@ struct DMA {
int num_valid_kv_blocks, int tokens_per_block_log2,
int blocks_per_tma_load, int blocks_per_tma_load_log2,
int max_blocks_per_sequence,
int32_t const* paged_block_offsets,
int32_t const* paged_block_offsets, bool shared_kv_idx,
cudaTmaDesc const* desc_k, cudaTmaDesc const* desc_v,
Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v,
BufferWriterScratch& cbw_v_scratch) {
@@ -562,9 +565,17 @@
for (int bi = 0; bi < blocks_per_tma_load; ++bi) {
int const bounded_block_idx = min(num_valid_kv_blocks - 1, paged_kv_block_idx + bi);

const int32_t k_paged_block_offset = paged_block_offsets[bounded_block_idx];
const int32_t v_paged_block_offset =
paged_block_offsets[max_blocks_per_sequence + bounded_block_idx];
int32_t k_paged_block_offset, v_paged_block_offset;
if (shared_kv_idx) {
// Shared [B, M] layout: transform logical page index to interleaved pool offsets.
int32_t page_idx = paged_block_offsets[bounded_block_idx];
k_paged_block_offset = page_idx * 2;
v_paged_block_offset = page_idx * 2 + 1;
} else {
// Pre-expanded [B, 2, M] layout: K and V offsets are stored separately.
k_paged_block_offset = paged_block_offsets[bounded_block_idx];
v_paged_block_offset = paged_block_offsets[max_blocks_per_sequence + bounded_block_idx];
}

#pragma unroll
for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) {
10 changes: 9 additions & 1 deletion csrc/fmha_v2/fmha/warpspec/epilogue.h
@@ -1131,9 +1131,17 @@ struct Tile_o_epilogue<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
enum { EXP2F_OPTIMIZATION = Base::EXP2F_OPTIMIZATION };

// Ctor.
// DIVERGENCE from upstream TRT-LLM: read the host-encoded scalar
// params.scale_bmm2 (uint32 set_alpha output) instead of dereferencing
// params.scale_bmm2_d. This makes the epilogue safe under CUDA graph capture
// (the value travels through the kernel-arg buffer, captured by value) and
// lets the binding stop allocating a workspace slot + issuing a
// pageable-host cudaMemcpyAsync. Other epilogues (hopper/, gmem_tile_o*.h)
// already read params.scale_bmm2 directly or via a `d ? *d : scale_bmm2`
// fallback, so this change is local to the SM90 WS path.
template <typename Params, typename Block_info>
inline __device__ Tile_o_epilogue(Params const& params, Block_info& block_info)
: Base(params, block_info), scale_bmm2_(*params.scale_bmm2_d) {}
: Base(params, block_info), scale_bmm2_(params.scale_bmm2) {}

// Add the attention sink to the global sum.
inline __device__ void add_attention_sink(float& sum, float max) {
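The by-value capture behaviour that the epilogue comment relies on can be sketched at the PyTorch level; this is an illustrative analogy, not the kernel code, and it assumes a CUDA device. A scalar that enters the capture as a plain value is frozen into the graph (like params.scale_bmm2 travelling through the kernel-arg buffer), while device memory is re-read on every replay (like the old *scale_bmm2_d), and no host-side copy needs to be issued during capture:

    import torch

    x = torch.randn(8, device="cuda")
    scale_buf = torch.ones(1, device="cuda")     # analogue of the old *scale_bmm2_d
    out = torch.empty_like(x)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out.copy_(x * 0.5)                       # 0.5 is captured by value, like params.scale_bmm2
        out.mul_(scale_buf)                      # read from device memory at replay time

    g.replay()
    scale_buf.fill_(2.0)                         # affects the next replay; the baked-in 0.5 cannot change
    g.replay()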
4 changes: 2 additions & 2 deletions csrc/fmha_v2_jit_binding.cu
@@ -32,8 +32,8 @@ void fmha_v2_run(ffi::TensorView q, ffi::TensorView k, ffi::TensorView v, ffi::T
int max_q_len, int max_kv_len, int batch_size, const std::string& mask_mode_str,
float scale_softmax, float scale_bmm1, float scale_bmm2, int window_left,
int chunked_attention_size, bool has_alibi, float softcapping_scale,
float skip_softmax_threshold_scale_factor, ffi::TensorView scale_bmm2_d,
Optional<ffi::TensorView> softmax_stats, Optional<ffi::TensorView> sinks);
float skip_softmax_threshold_scale_factor, Optional<ffi::TensorView> softmax_stats,
Optional<ffi::TensorView> sinks);

// FMHAv2 attention operator
TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, fmha_v2_run);