diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py
index 2fdc9013ac..9befc19ce6 100644
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -22,6 +22,7 @@
     if not is_lib_missing:
         raise
 from flashinfer.fp4_quantization import nvfp4_quantize_paged_kv_cache
+from flashinfer.prefill import trtllm_fmha_v2_prefill
 from flashinfer.testing.utils import (
     attention_tb_per_sec_with_actual_seq_lens,
     attention_tflops_per_sec_with_actual_seq_lens,
@@ -111,6 +112,7 @@ def parse_attention_args(line, parser):
             "cutlass",
             "trtllm-gen",
             "trtllm-native",
+            "trtllm-fmha-v2",
             "trtllm-gen-native",  # Deprecated, will be removed in future
             "cute-dsl",
         ],
@@ -936,6 +938,9 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
             remove_trtllm_native = True
         if remove_trtllm_native:
             backends.remove("trtllm-native")
+    if "trtllm-fmha-v2" in backends and is_nvfp4_kv:
+        print("[INFO] trtllm-fmha-v2 backend does not support NVFP4. Skipping.")
+        backends.remove("trtllm-fmha-v2")
 
     if "cutlass" in backends:
         print("[INFO] CUTLASS backend does not support prefill. Skipping.")
@@ -1072,7 +1077,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         .to(device)
     )
 
-    # Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
+    # Page-based indptr for FlashInfer paged attention (cumulative page counts)
     kv_indptr = (
         torch.cat(
             [
@@ -1086,6 +1091,17 @@
         .int()
         .to(device)
     )
+    # Token-based indptr for TRT-LLM backends (cumulative token counts)
+    kv_token_indptr = (
+        torch.cat(
+            [
+                torch.tensor([0], device=device),
+                torch.cumsum(actual_seq_lens_kv_device.flatten(), dim=0),
+            ]
+        )
+        .int()
+        .to(device)
+    )
     kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32)
     for i in range(len(kv_indptr) - 1):
         start_idx = kv_indptr[i]
@@ -1158,6 +1174,16 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
         v_quantized, _ = to_float8(v_data, kv_dtype)
         kv_cache = torch.cat([k_quantized, v_quantized], dim=1)
 
+    _fmha_v2_bmm2_scale = v_scale if v_scale is not None else 1.0
+
+    # Ensure trtllm-fmha-v2 sees a contiguous HND-physical paged KV cache.
+    # Skip if kv_cache is not a plain Tensor (e.g., an NVFP4 packed tuple); the
+    # backend filter further down also drops trtllm-fmha-v2 in that case.
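+    # ("HND-physical" means [pages, 2, num_kv_heads, page_size, head_dim],
+    # the shape trtllm_fmha_v2_prefill checks for the Q_PAGED_KV_HND layout.)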
+ if "trtllm-fmha-v2" in backends and isinstance(kv_cache, torch.Tensor): + _fmha_v2_kv_cache = kv_cache.contiguous() + else: + _fmha_v2_kv_cache = kv_cache + # Prepare wrappers (after FP8 conversion so we have correct dtypes) backend_wrappers = {} resolved_backends = {} @@ -1304,6 +1330,25 @@ def run_backend_wrapper( v_scale=v_scale_tensor, o_data_type=o_data_type, )[0] + elif backend == "trtllm-fmha-v2": + _q_scale = q_scale if q_scale is not None else 1.0 + _k_scale = k_scale if k_scale is not None else 1.0 + return trtllm_fmha_v2_prefill( + qkv=(q, _fmha_v2_kv_cache), + input_layout="Q_PAGED_KV_HND", + workspace_buffer=workspace_buffer, + seq_lens=actual_seq_lens_kv_device.flatten(), + max_q_len=s_qo, + max_kv_len=s_kv, + bmm1_scale=_q_scale * _k_scale * scale, + bmm2_scale=_fmha_v2_bmm2_scale, + batch_size=batch_size, + cum_seq_lens_q=qo_indptr, + cum_seq_lens_kv=kv_token_indptr, + block_tables=block_tables, + mask_mode="causal" if causal else "padding", + out_dtype=o_data_type, + ) else: print(f"[ERROR] Backend {backend} not supported") return None @@ -1366,9 +1411,15 @@ def run_backend_wrapper( tested_outputs = list(outputs.values()) # When cases where FA2 is not available, try to find an alternative reference - # Priority: cudnn > cudnn-native > trtllm-gen > trtllm-native + # Priority: cudnn > cudnn-native > trtllm-gen > trtllm-native > trtllm-fmha-v2 if run_refcheck and not has_reference_output and len(tested_backends) > 1: - reference_priority = ["cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"] + reference_priority = [ + "cudnn", + "cudnn-native", + "trtllm-gen", + "trtllm-native", + "trtllm-fmha-v2", + ] for candidate in reference_priority: if candidate in tested_backends: has_reference_output = True @@ -1598,6 +1649,12 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): remove_trtllm_native = True if remove_trtllm_native: backends.remove("trtllm-native") + if "trtllm-fmha-v2" in backends and q_dtype == torch.float8_e4m3fn: + print( + "[INFO] trtllm-fmha-v2 backend does not support FP8 e4m3 with " + "SEPARATE_Q_K_V layout. Skipping." + ) + backends.remove("trtllm-fmha-v2") if len(backends) == 0: print("[ERROR] No backends to test. 
Exiting.") @@ -1836,6 +1893,8 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): k = (k / k_scale).to(kv_dtype) v = (v / v_scale).to(kv_dtype) + _fmha_v2_bmm2_scale = v_scale if v_scale is not None else 1.0 + trtllm_out = None if "trtllm-native" in backends or "cute-dsl" in backends: # cute-dsl varlen kernel uses negative pointer offsets on output, @@ -1944,6 +2003,24 @@ def run_backend_wrapper( return_lse=True, out=trtllm_out, )[0] + elif backend == "trtllm-fmha-v2": + _q_scale = q_scale if q_scale is not None else 1.0 + _k_scale = k_scale if k_scale is not None else 1.0 + return trtllm_fmha_v2_prefill( + qkv=(q, k, v), + input_layout="SEPARATE_Q_K_V", + workspace_buffer=workspace_buffer, + seq_lens=actual_seq_lens_kv_device.flatten(), + max_q_len=s_qo, + max_kv_len=s_kv, + bmm1_scale=_q_scale * _k_scale * scale, + bmm2_scale=_fmha_v2_bmm2_scale, + batch_size=batch_size, + cum_seq_lens_q=qo_indptr, + cum_seq_lens_kv=kv_indptr, + mask_mode="causal" if causal else "padding", + out_dtype=out_dtype, + ) else: print(f"[ERROR] Backend {backend} not supported") return None diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 2bf9916e61..cc3ef65b2e 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -318,25 +318,27 @@ def dtype_str_to_torch_dtype(dtype_str): }, "BatchPrefillWithPagedKVCacheWrapper": { # NOTE: trtllm-native calls trtllm_batch_context_with_kv_cache + # NOTE: trtllm-fmha-v2 calls trtllm_fmha_v2_prefill # NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache "7.5": [], "8.0": ["fa2", "auto", "cudnn", "cudnn-native"], "8.6": ["fa2", "auto", "cudnn", "cudnn-native"], "8.9": ["fa2", "auto", "cudnn", "cudnn-native"], - "9.0": ["fa2", "fa3", "auto", "cudnn", "cudnn-native"], + "9.0": ["fa2", "fa3", "auto", "cudnn", "cudnn-native", "trtllm-fmha-v2"], "10.0": ["fa2", "auto", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"], "10.3": ["fa2", "auto", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"], - "12.0": ["fa2", "auto", "cudnn", "cudnn-native"], + "12.0": ["fa2", "auto", "cudnn", "cudnn-native", "trtllm-fmha-v2"], "12.1": ["fa2", "auto", "cudnn", "cudnn-native"], }, "BatchPrefillWithRaggedKVCacheWrapper": { # NOTE: trtllm-native calls trtllm_ragged_attention_deepseek + # NOTE: trtllm-fmha-v2 calls trtllm_fmha_v2_prefill # NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache "7.5": [], "8.0": ["fa2", "cudnn", "cudnn-native"], "8.6": ["fa2", "cudnn", "cudnn-native"], "8.9": ["fa2", "cudnn", "cudnn-native"], - "9.0": ["fa2", "fa3", "cudnn", "cudnn-native"], + "9.0": ["fa2", "fa3", "cudnn", "cudnn-native", "trtllm-fmha-v2"], "10.0": [ "fa2", "cudnn", @@ -353,7 +355,7 @@ def dtype_str_to_torch_dtype(dtype_str): "cute-dsl", "trtllm-native", ], - "12.0": ["fa2", "cudnn", "cudnn-native"], + "12.0": ["fa2", "cudnn", "cudnn-native", "trtllm-fmha-v2"], "12.1": ["fa2", "cudnn", "cudnn-native"], }, "BatchMLAPagedAttentionWrapper": { diff --git a/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h b/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h index 00797d0a01..e875ad978b 100644 --- a/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h +++ b/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h @@ -863,11 +863,21 @@ struct Gmem_tile_paged_kv { paged_kv_log2_block_size_(params.paged_kv_cache.mTokensPerBlockLog2), paged_kv_block_pool_ptr_(reinterpret_cast(params.paged_kv_cache.mPoolPtr)), paged_kv_global_block_offsets_(params.paged_kv_cache.mBlockOffsets), - 
         "12.1": ["fa2", "auto", "cudnn", "cudnn-native"],
     },
     "BatchPrefillWithRaggedKVCacheWrapper": {
         # NOTE: trtllm-native calls trtllm_ragged_attention_deepseek
+        # NOTE: trtllm-fmha-v2 calls trtllm_fmha_v2_prefill
         # NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache
         "7.5": [],
         "8.0": ["fa2", "cudnn", "cudnn-native"],
         "8.6": ["fa2", "cudnn", "cudnn-native"],
         "8.9": ["fa2", "cudnn", "cudnn-native"],
-        "9.0": ["fa2", "fa3", "cudnn", "cudnn-native"],
+        "9.0": ["fa2", "fa3", "cudnn", "cudnn-native", "trtllm-fmha-v2"],
         "10.0": [
             "fa2",
             "cudnn",
@@ -353,7 +355,7 @@ def dtype_str_to_torch_dtype(dtype_str):
             "cute-dsl",
             "trtllm-native",
         ],
-        "12.0": ["fa2", "cudnn", "cudnn-native"],
+        "12.0": ["fa2", "cudnn", "cudnn-native", "trtllm-fmha-v2"],
         "12.1": ["fa2", "cudnn", "cudnn-native"],
     },
     "BatchMLAPagedAttentionWrapper": {
diff --git a/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h b/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h
index 00797d0a01..e875ad978b 100644
--- a/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h
+++ b/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h
@@ -863,11 +863,21 @@ struct Gmem_tile_paged_kv {
       paged_kv_log2_block_size_(params.paged_kv_cache.mTokensPerBlockLog2),
       paged_kv_block_pool_ptr_(reinterpret_cast<int64_t>(params.paged_kv_cache.mPoolPtr)),
       paged_kv_global_block_offsets_(params.paged_kv_cache.mBlockOffsets),
-      params_kv_block_size_in_bytes_(params.paged_kv_cache.mBytesPerBlock) {
-    // Handle Paged KV with shape [S, Dh], by offsetting it to the target batch.
-    int32_t const paged_kv_block_offset =
-        (binfo.bidb * 2 + qkv_offset - 1) * params.paged_kv_cache.mMaxBlocksPerSeq;
-    paged_kv_global_block_offsets_ += paged_kv_block_offset;
+      params_kv_block_size_in_bytes_(params.paged_kv_cache.mBytesPerBlock),
+      uses_shared_paged_kv_idx_(params.paged_kv_cache.mUsesSharedPagedKvIdx),
+      kv_type_(qkv_offset - 1) {
+    // Handle Paged KV by offsetting the block offsets pointer to the target batch.
+    if (uses_shared_paged_kv_idx_) {
+      // Shared [B, M] layout: one set of page indices for both K and V.
+      // The kv_type_ (0=K, 1=V) is applied at load time via page_idx * 2 + kv_type_.
+      int32_t const paged_kv_block_offset = binfo.bidb * params.paged_kv_cache.mMaxBlocksPerSeq;
+      paged_kv_global_block_offsets_ += paged_kv_block_offset;
+    } else {
+      // Pre-expanded [B, 2, M] layout: separate K and V offset arrays.
+      int32_t const paged_kv_block_offset =
+          (binfo.bidb * 2 + qkv_offset - 1) * params.paged_kv_cache.mMaxBlocksPerSeq;
+      paged_kv_global_block_offsets_ += paged_kv_block_offset;
+    }
 
     // Compute the position in the sequence (within the CTA for the moment).
     int row = tidx / THREADS_PER_ROW;
@@ -915,9 +925,13 @@
     for (int ii = 0; ii < LDGS; ++ii) {
       int row_idx = row_ + ii * (int)ROWS_PER_LDG;
       int paged_kv_block_idx = (row_idx >> paged_kv_log2_block_size_);
-      char const* local_kv_ptr = reinterpret_cast<char const*>(
-          paged_kv_block_pool_ptr_ +
-          params_kv_block_size_in_bytes_ * paged_kv_global_block_offsets_[paged_kv_block_idx]);
+      int32_t pool_idx = paged_kv_global_block_offsets_[paged_kv_block_idx];
+      // Shared page index: transform logical page → interleaved K/V pool offset.
+      if (uses_shared_paged_kv_idx_) {
+        pool_idx = pool_idx * 2 + kv_type_;
+      }
+      char const* local_kv_ptr = reinterpret_cast<char const*>(
+          paged_kv_block_pool_ptr_ + params_kv_block_size_in_bytes_ * pool_idx);
 
       // Predicates.
       // TODO: do we need to make sure row_idx < ROWS ?
@@ -958,6 +972,10 @@
   int32_t* paged_kv_global_block_offsets_;
   // The paged block size.
   int paged_kv_log2_block_size_;
+  // Whether block offsets use the shared [B, M] format (true) vs pre-expanded [B, 2, M] (false).
+  bool uses_shared_paged_kv_idx_;
+  // 0 for K, 1 for V. Used with shared page indices to compute pool offset = page_idx * 2 + kv_type_.
+  int kv_type_;
   // The register to store predicates.
   uint32_t preds_[PRED_REGS];
   // The fetch registers.
diff --git a/csrc/fmha_v2/fmha/paged_kv_cache.h b/csrc/fmha_v2/fmha/paged_kv_cache.h
index a8e13a61d0..dab71ef975 100644
--- a/csrc/fmha_v2/fmha/paged_kv_cache.h
+++ b/csrc/fmha_v2/fmha/paged_kv_cache.h
@@ -31,10 +31,17 @@ struct Kv_block_array {
  // E.g. for mTokensPerBlock 64, mTokensPerBlockLog2 equals to 6
  int32_t mTokensPerBlockLog2;
  // Table maps logical block idx to the data pointer of k/v cache block pool
-  // Shape [B, W, 2, M], where 2 is table for K and V,
-  // B is current number of sequences
-  // W is beam width
-  // M is Max number of blocks per sequence
+  //
+  // When mUsesSharedPagedKvIdx is false (default, TRT-LLM native format):
+  //   Shape [B, W, 2, M], where 2 is table for K and V,
+  //   B is current number of sequences, W is beam width,
+  //   M is max number of blocks per sequence.
+  //
+  // When mUsesSharedPagedKvIdx is true (FlashInfer interleaved KV pool format):
+  //   Shape [B, M] containing logical page indices.
+  //   K and V share the same index; the kernel computes pool offsets on-the-fly as:
+  //     K pool offset = page_idx * 2
+  //     V pool offset = page_idx * 2 + 1
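+  //   (For example, logical page 7 maps to pool block 14 for K and 15 for V.)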
 
   // Size of KV cache blocks in bytes (H*D*T*sizeof(DataType))
   int32_t mBytesPerBlock;
@@ -42,6 +49,9 @@
   void* mPoolPtr;
   // Pointer to block offsets.
   PtrType* mBlockOffsets;
+  // When true, mBlockOffsets is [B, M] with shared K/V page indices that need
+  // the *2 / +1 transform, instead of pre-expanded [B, 2, M] separate offsets.
+  bool mUsesSharedPagedKvIdx;
 
   Kv_block_array() = default;
 
@@ -52,7 +62,8 @@
         mTokensPerBlock(tokensPerBlock),
         mBytesPerBlock{bytesPerBlock},
         mPoolPtr{poolPtr},
-        mBlockOffsets{nullptr} {
+        mBlockOffsets{nullptr},
+        mUsesSharedPagedKvIdx{false} {
     float const tokensPerBlockSeqLog2 = log2(mTokensPerBlock);
     mTokensPerBlockLog2 = static_cast<int32_t>(tokensPerBlockSeqLog2);
   }
diff --git a/csrc/fmha_v2/fmha/warpspec/dma.h b/csrc/fmha_v2/fmha/warpspec/dma.h
index 6934087270..bfcc91c16b 100644
--- a/csrc/fmha_v2/fmha/warpspec/dma.h
+++ b/csrc/fmha_v2/fmha/warpspec/dma.h
@@ -364,8 +364,10 @@ struct DMA {
     cudaTmaDesc const* desc_k = &params.tma_desc_k;
     cudaTmaDesc const* desc_v = &params.tma_desc_v;
 
+    bool const shared_kv_idx = params.paged_kv_cache.mUsesSharedPagedKvIdx;
     int32_t const* paged_block_offsets =
-        params.paged_kv_cache.mBlockOffsets + bidb * 2 * params.paged_kv_cache.mMaxBlocksPerSeq;
+        params.paged_kv_cache.mBlockOffsets +
+        bidb * (shared_kv_idx ? 1 : 2) * params.paged_kv_cache.mMaxBlocksPerSeq;
 
     if (SCHEDULING_MODE == 0) {
       // split work across M
@@ -416,11 +418,12 @@
         int bar_id;
         // Load paged kv input.
         if constexpr (PAGED_KV_INPUT) {
-          bar_id = load_paged_kv(bidh_kv, kv_step_idx * STEP_KV, num_valid_kv_blocks,
-                                 params.paged_kv_cache.mTokensPerBlockLog2,
-                                 params.blocks_per_tma_load, params.blocks_per_tma_load_log2,
-                                 params.paged_kv_cache.mMaxBlocksPerSeq, paged_block_offsets,
-                                 desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
+          bar_id =
+              load_paged_kv(bidh_kv, kv_step_idx * STEP_KV, num_valid_kv_blocks,
+                            params.paged_kv_cache.mTokensPerBlockLog2,
+                            params.blocks_per_tma_load, params.blocks_per_tma_load_log2,
+                            params.paged_kv_cache.mMaxBlocksPerSeq, paged_block_offsets,
+                            shared_kv_idx, desc_k, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch);
         } else {
           bar_id = load_kv(bidh_kv, kv_step_idx * STEP_KV, desc_k, desc_v, shared, cbw_k, cbw_v,
                            cbw_v_scratch);
@@ -545,7 +548,7 @@
                             int num_valid_kv_blocks, int tokens_per_block_log2,
                             int blocks_per_tma_load, int blocks_per_tma_load_log2,
                             int max_blocks_per_sequence,
-                            int32_t const* paged_block_offsets,
+                            int32_t const* paged_block_offsets, bool shared_kv_idx,
                             cudaTmaDesc const* desc_k, cudaTmaDesc const* desc_v, Shared* shared,
                             BufferWriter& cbw_k, BufferWriter& cbw_v,
                             BufferWriterScratch& cbw_v_scratch) {
@@ -562,9 +565,17 @@
      for (int bi = 0; bi < blocks_per_tma_load; ++bi) {
        int const bounded_block_idx = min(num_valid_kv_blocks - 1, paged_kv_block_idx + bi);
 
-        const int32_t k_paged_block_offset = paged_block_offsets[bounded_block_idx];
-        const int32_t v_paged_block_offset =
-            paged_block_offsets[max_blocks_per_sequence + bounded_block_idx];
+        int32_t k_paged_block_offset, v_paged_block_offset;
+        if (shared_kv_idx) {
+          // Shared [B, M] layout: transform logical page index to interleaved pool offsets.
+          int32_t page_idx = paged_block_offsets[bounded_block_idx];
+          k_paged_block_offset = page_idx * 2;
+          v_paged_block_offset = page_idx * 2 + 1;
+        } else {
+          // Pre-expanded [B, 2, M] layout: K and V offsets are stored separately.
+          k_paged_block_offset = paged_block_offsets[bounded_block_idx];
+          v_paged_block_offset = paged_block_offsets[max_blocks_per_sequence + bounded_block_idx];
+        }
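+        // (Interleaved pool: block 2*p holds K and block 2*p+1 holds V for
+        // logical page p, matching the Gmem_tile_paged_kv path.)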
 
 #pragma unroll
         for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) {
diff --git a/csrc/fmha_v2/fmha/warpspec/epilogue.h b/csrc/fmha_v2/fmha/warpspec/epilogue.h
index 40248a51cc..254c1810dc 100644
--- a/csrc/fmha_v2/fmha/warpspec/epilogue.h
+++ b/csrc/fmha_v2/fmha/warpspec/epilogue.h
@@ -1131,9 +1131,17 @@ struct Tile_o_epilogue
   enum { EXP2F_OPTIMIZATION = Base::EXP2F_OPTIMIZATION };
 
   // Ctor.
+  // DIVERGENCE from upstream TRT-LLM: read the host-encoded scalar
+  // params.scale_bmm2 (uint32 set_alpha output) instead of dereferencing
+  // params.scale_bmm2_d. This makes the epilogue safe under CUDA graph capture
+  // (the value travels through the kernel-arg buffer, captured by value) and
+  // lets the binding stop allocating a workspace slot + issuing a
+  // pageable-host cudaMemcpyAsync. Other epilogues (hopper/, gmem_tile_o*.h)
+  // already read params.scale_bmm2 directly or via a `d ? *d : scale_bmm2`
+  // fallback, so this change is local to the SM90 WS path.
   template <typename Params, typename Block_info>
   inline __device__ Tile_o_epilogue(Params const& params, Block_info& block_info)
-      : Base(params, block_info), scale_bmm2_(*params.scale_bmm2_d) {}
+      : Base(params, block_info), scale_bmm2_(params.scale_bmm2) {}
 
   // Add the attention sink to the global sum.
   inline __device__ void add_attention_sink(float& sum, float max) {
diff --git a/csrc/fmha_v2_jit_binding.cu b/csrc/fmha_v2_jit_binding.cu
index f0eb500eec..d94d9bd208 100644
--- a/csrc/fmha_v2_jit_binding.cu
+++ b/csrc/fmha_v2_jit_binding.cu
@@ -32,8 +32,8 @@ void fmha_v2_run(ffi::TensorView q, ffi::TensorView k, ffi::TensorView v, ffi::T
                  int max_q_len, int max_kv_len, int batch_size, const std::string& mask_mode_str,
                  float scale_softmax, float scale_bmm1, float scale_bmm2, int window_left,
                  int chunked_attention_size, bool has_alibi, float softcapping_scale,
-                 float skip_softmax_threshold_scale_factor, ffi::TensorView scale_bmm2_d,
-                 Optional<ffi::TensorView> softmax_stats, Optional<ffi::TensorView> sinks);
+                 float skip_softmax_threshold_scale_factor, Optional<ffi::TensorView> softmax_stats,
+                 Optional<ffi::TensorView> sinks);
 
 // FMHAv2 attention operator
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, fmha_v2_run);
diff --git a/csrc/fmha_v2_run.cu b/csrc/fmha_v2_run.cu
index 3dfff9b967..7fb3647f0d 100644
--- a/csrc/fmha_v2_run.cu
+++ b/csrc/fmha_v2_run.cu
@@ -198,8 +198,6 @@ static inline void set_params(
   }
   set_alpha(params.scale_softmax, scale_softmax, scale_softmax_type);
   set_alpha(params.scale_bmm2, scale_bmm2, scale_type2);
-  // NOTE: scale_bmm2_d is now pre-populated from Python to avoid cudaMemcpy synchronization.
-  // The Python side calls create_scale_bmm2_d_tensor() which replicates set_alpha logic.
   params.scale_bmm2_d = reinterpret_cast<uint32_t*>(scale_bmm2_d);
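+  // (scale_bmm2_d may be nullptr now; per the note further down in
+  // fmha_v2_run, sites that consult it fall back to the encoded params.scale_bmm2.)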
 
   params.softcapping_scale_bmm1 = softcapping_scale_bmm1;
@@ -330,7 +328,6 @@ void fmha_v2_run(
     const std::string& mask_mode_str, float scale_softmax, float scale_bmm1, float scale_bmm2,
     int window_left, int chunked_attention_size, bool has_alibi, float softcapping_scale,
     float skip_softmax_threshold_scale_factor,
-    ffi::TensorView scale_bmm2_d,  // Pre-populated scale_bmm2 on device [1] int32
     Optional<ffi::TensorView> softmax_stats,  // Optional [batch, s_q, num_heads, 2] for (max, sum)
     Optional<ffi::TensorView> sinks) {
   Attention_input_layout input_layout = string_to_input_layout(input_layout_str);
@@ -520,17 +517,10 @@ void fmha_v2_run(
           ? allocator.aligned_alloc(packed_mask_size_in_bytes, 128, "packed_mask_d")
           : nullptr;
 
-  // NOTE: scale_bmm2_d is now passed as a pre-populated tensor from Python
-  // to avoid cudaMemcpy synchronization in set_params().
-
   // Softmax stats: stores (max, sum) per token, 2 floats per (b, s_q, h)
-  // Write directly to user-provided tensor when available, otherwise use workspace.
+  // Written directly to the user-provided tensor when available; otherwise left nullptr.
   void* softmax_stats_ptr = nullptr;
   if (softmax_stats.has_value()) {
     softmax_stats_ptr = softmax_stats.value().data_ptr();
-  } else {
-    const size_t softmax_stats_size = 2 * sizeof(float) * b * s_q * h;
-    softmax_stats_ptr = allocator.aligned_alloc(softmax_stats_size, 128, "softmax_stats_d");
   }
 
   void* attention_sinks_d = sinks.has_value() ? sinks.value().data_ptr() : nullptr;
@@ -543,8 +533,8 @@
   void* kv_cache_pool_ptr = nullptr;
   int32_t* kv_cache_block_offsets_d = nullptr;
 
-  // For Q_PAGED_KV layout, block_tables is pre-expanded on the Python side from [B, M] to [B, 2, M]
-  // where [:, 0, :] contains K offsets and [:, 1, :] contains V offsets.
+  // For Q_PAGED_KV layout, block_tables has shape [B, M] containing logical page indices.
+  // The kernel transforms these to interleaved pool offsets (K=page*2, V=page*2+1) on-the-fly.
   int block_table_max_blocks = 0;
 
   switch (input_layout) {
@@ -565,10 +555,9 @@
       kv_cache_pool_ptr = k.data_ptr();
 
      if (maybe_block_tables.has_value()) {
-        // block_tables is pre-expanded on Python side with shape [B, 2, M]
-        // where M is max_blocks_per_sequence
+        // block_tables has shape [B, M] with logical page indices
        ffi::TensorView block_tables = maybe_block_tables.value();
-        block_table_max_blocks = block_tables.shape()[2];  // shape is [B, 2, M]
+        block_table_max_blocks = block_tables.shape()[1];  // shape is [B, M]
        kv_cache_block_offsets_d = static_cast<int32_t*>(block_tables.data_ptr());
      }
    } break;
 
    default:
      break;
  }
 
+  // The host-encoded scale_type2 uint32 lives in params.scale_bmm2 (set by
+  // set_alpha) and travels into kernels through the launch arg buffer — every
+  // epilogue reads it from there. params.scale_bmm2_d is left nullptr; the
+  // few non-WS sites that consult it use a `d ? *d : scale_bmm2` fallback.
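+  // (set_alpha encodes by input type, as the removed Python helper documented:
+  // FP16/BF16 scales occupy the low 16 bits of the uint32, while FP8/INT8
+  // store the full FP32 bit pattern.)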
 
   bert::Fused_multihead_attention_params_v2 params_v2;
   set_params(params_v2, launch_params, data_type, acc_type, output_dtype, input_layout, b, s_q, s,
              h, h_kv, d, dv, total, 1, sliding_window_size, chunked_attention_size,
@@ -585,14 +578,15 @@
              kv_cache_block_offsets_d, packed_mask_d, nullptr, attention_sinks_d,
              static_cast<int32_t*>(cum_seq_lens_kv.data_ptr()),
              static_cast<int32_t*>(cum_seq_lens_q.data_ptr()), o.data_ptr(), nullptr, nullptr,
-             softmax_stats_ptr, scale_bmm2_d.data_ptr(), scale_bmm1, scale_softmax, scale_bmm2,
+             softmax_stats_ptr, /*scale_bmm2_d=*/nullptr, scale_bmm1, scale_softmax, scale_bmm2,
              softcapping_scale_bmm1, false, false, false, has_alibi,
              skip_softmax_threshold_scale_factor);
 
-  // For Q_PAGED_KV layout, override mMaxBlocksPerSeq to match the actual block_tables stride
-  // that we used when expanding the block offsets from [B, M] to [B, 2, M]
+  // For Q_PAGED_KV layout, override mMaxBlocksPerSeq to match the actual block_tables stride,
+  // and enable shared page index mode so the kernel transforms page_idx → pool offsets on-the-fly.
   if (input_layout == Attention_input_layout::Q_PAGED_KV && block_table_max_blocks > 0) {
     params_v2.paged_kv_cache.mMaxBlocksPerSeq = block_table_max_blocks;
+    params_v2.paged_kv_cache.mUsesSharedPagedKvIdx = true;
   }
 
   // Total number of Q tokens is needed to set TMA desc on the host.
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index d491dd35d9..39f5a6e6fc 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -101,45 +101,6 @@ def _split_scale_param(scale):
     return None, float(scale)
 
 
-def _create_scale_bmm2_d_tensor(
-    scale_bmm2: float, data_dtype: torch.dtype, device: torch.device
-) -> torch.Tensor:
-    """Create a scale_bmm2_d tensor with the correct bit pattern for the TRT-LLM FMHAv2 kernel.
-
-    This function replicates the C++ set_alpha logic for scale_type2 to avoid
-    cudaMemcpy synchronization in the kernel. The scale value is converted to
-    the appropriate floating-point format and stored as int32 bits on device.
-
-    The scale_type2 logic (from C++):
-    - FP16 input -> scale stored as FP16 bits in lower 16 bits of uint32
-    - BF16 input -> scale stored as BF16 bits in lower 16 bits of uint32
-    - Other (FP8, INT8, etc.) -> scale stored as FP32 bits in uint32
-
-    Args:
-        scale_bmm2: The scale value for BMM2 (typically 1.0)
-        data_dtype: The input tensor dtype (determines scale_type2)
-        device: The target device for the tensor
-
-    Returns:
-        A 1-element int32 tensor on device containing the scale bits
-    """
-    if data_dtype == torch.float16:
-        # Create int32 buffer on device, write FP16 value to lower 16 bits via view
-        result = torch.zeros(1, dtype=torch.int32, device=device)
-        result.view(torch.float16)[0] = scale_bmm2
-        return result
-    elif data_dtype == torch.bfloat16:
-        # Create int32 buffer on device, write BF16 value to lower 16 bits via view
-        result = torch.zeros(1, dtype=torch.int32, device=device)
-        result.view(torch.bfloat16)[0] = scale_bmm2
-        return result
-    else:
-        # FP8, INT8, etc. use FP32 accumulation - create FP32 tensor and view as int32
-        return torch.tensor([scale_bmm2], dtype=torch.float32, device=device).view(
-            torch.int32
-        )
-
-
 @functools.cache
 def get_fmha_module(
     dtype_q: torch.dtype,
@@ -4442,6 +4403,10 @@ def trtllm_fmha_v2_prefill(
         # TODO: implement native NHD support in the kernel to avoid this transpose
         kv_cache = paged_kv.transpose(-3, -2).contiguous()
         k_cache, v_cache = kv_cache.unbind(dim=1)
+        if block_tables is None:
+            raise ValueError(
+                "block_tables is required for Q_PAGED_KV_NHD input layout."
+            )
     elif input_layout == "Q_PAGED_KV_HND":
         assert isinstance(qkv, tuple)
         query, paged_kv = qkv[0], qkv[1]
             f"Q_PAGED_KV_HND expects paged_KV shape [pages, 2, num_kv_heads, page_size, head_dim], got {tuple(paged_kv.shape)}"
         )
         k_cache, v_cache = paged_kv.unbind(dim=1)
+        if block_tables is None:
+            raise ValueError(
+                "block_tables is required for Q_PAGED_KV_HND input layout."
+            )
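+        # (No transpose is needed in this branch: HND is already the kernel's
+        # physical layout, unlike the Q_PAGED_KV_NHD branch above.)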
     elif input_layout == "SEPARATE_Q_K_V":
         assert isinstance(qkv, tuple)
         query, k_cache, v_cache = qkv[0], qkv[1], qkv[2]
@@ -4526,12 +4495,6 @@
             device=query.device,
         )
 
-    # Handle scale parameters
-    scale_bmm1 = float(bmm1_scale)
-    scale_bmm2 = float(bmm2_scale)
-
-    # Softmax scale: 1.0 for FP8, 0.0 (auto-detect) for FP16/BF16
-    # C++ kernel auto-sets to 1.0 for FP16/E4M3 when 0.0 is passed
     is_e4m3 = (
         query.dtype == torch.float8_e4m3fn if hasattr(torch, "float8_e4m3fn") else False
     )
@@ -4552,7 +4515,9 @@
                 "num_kv_heads=4, head_dim=256 is not supported for "
                 f"{input_layout} layout due to a known issue."
             )
-    scale_softmax = 1.0 if is_e4m3 else 1.0
+    # Always pass 1.0: the C++ auto-detect (scale_softmax == 0.0) handles FP16/INT8/E4M3
+    # but has no branch for BF16, where 0.0 would zero out the softmax output.
+    scale_softmax = 1.0
     softcapping_scale = (
         logits_soft_cap_scale if logits_soft_cap_scale is not None else 0.0
    )
@@ -4564,8 +4529,7 @@
     )
 
     # Allocate LSE tensor if saving softmax stats
-    # Kernel writes in ragged (flat
-    # ) format: [total_q_tokens, num_qo_heads, 2]
+    # Kernel writes in ragged (flat) format: [total_q_tokens, num_qo_heads, 2]
     # total_q_tokens == query.shape[0] for all ragged layouts
     lse = None
     if save_softmax_stats:
         lse = torch.empty(
             (query.shape[0], num_qo_heads, 2),
             dtype=torch.float32,
             device=query.device,
         )
 
-    # For Q_PAGED_KV layout, expand block_tables from [B, M] to [B, 2, M]
-    # TRT-LLM kernel expects separate K and V block offset arrays.
-    # FlashInfer layout: K for page i is at block index 2*i, V at 2*i+1
-    expanded_block_tables = None
-    if block_tables is not None and input_layout.lower().startswith("q_paged_kv"):
-        # K offsets = page_idx * 2 (even blocks)
-        # V offsets = page_idx * 2 + 1 (odd blocks)
-        expanded_block_tables = torch.stack(
-            [block_tables * 2, block_tables * 2 + 1], dim=1
-        ).contiguous()  # [B, 2, M]
-
-    scale_bmm2_d = _create_scale_bmm2_d_tensor(scale_bmm2, query.dtype, query.device)
-
     module.run(
         query,  # Q tensor
         k_cache,  # K tensor
         v_cache,  # V tensor
         out,  # Output tensor
         workspace_buffer,  # Workspace buffer
         workspace_buffer.numel()
         * workspace_buffer.element_size(),  # Workspace buffer size in bytes
-        expanded_block_tables,  # Expanded block tables [B, 2, M] or None
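+        # block_tables is [B, M] logical page indices; the kernel derives the
+        # interleaved K/V pool offsets (2*i and 2*i+1) on device.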
+        block_tables,
         page_size,
         seq_lens,  # Sequence length for kv_cache
         cum_seq_lens_q,  # Cumulative sequence length for query
@@ -4606,15 +4559,14 @@
         batch_size,  # Batch size
         mask_mode.lower(),  # Attention mask type
         scale_softmax,  # Softmax scale
-        scale_bmm1,  # BMM1 scale
-        scale_bmm2,  # BMM2 scale (float, still needed for set_alpha in C++)
+        bmm1_scale,  # BMM1 scale
+        bmm2_scale,  # BMM2 scale (host float; encoded by C++ set_alpha)
         window_left,  # Window left
         chunked_attention_size,  # Chunked attention size
         pos_encoding_mode is not None
         and pos_encoding_mode.lower() == "alibi",  # Alibi mode
         softcapping_scale,  # Softcapping scale (0.0 = disabled)
         skip_softmax_threshold_scale_factor,  # threshold_scale_factor for skip-softmax (0.0 = disable)
-        scale_bmm2_d,  # Pre-populated scale_bmm2 on device (avoids cudaMemcpy)
         lse,  # Optional LSE tensor (None if not saving softmax stats)
         sinks,  # Optional sinks tensor
     )
diff --git a/tests/attention/test_fmha_v2_prefill.py b/tests/attention/test_fmha_v2_prefill.py
index ba551f1233..673b6dcd14 100644
--- a/tests/attention/test_fmha_v2_prefill.py
+++ b/tests/attention/test_fmha_v2_prefill.py
@@ -859,7 +859,7 @@ def test_trtllm_fmha_v2_prefill(
     [
         (torch.float16, torch.float16),
         (torch.bfloat16, torch.bfloat16),
-        (torch.float8_e4m3fn, torch.float16),
+        (torch.float8_e4m3fn, torch.bfloat16),
     ],
 )
 @pytest.mark.parametrize(