diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu
index 3d5e8956e8..41fce44b59 100644
--- a/csrc/trtllm_fmha_kernel_launcher.cu
+++ b/csrc/trtllm_fmha_kernel_launcher.cu
@@ -76,16 +76,16 @@ class TllmGenFmhaRunnerCache {
 void trtllm_paged_attention_launcher(
     void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache,
     void* workspace_buffer, int* block_tables, int* seq_lens, int* cum_seq_lens_q,
-    int* cum_seq_lens_kv, float* attention_sinks, Data_type q_data_type, Data_type kv_data_type,
-    Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len,
-    int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
-    int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens,
-    int64_t q_stride_heads, int64_t kv_stride_keys_values, int64_t kv_stride_heads,
-    int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
-    const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
-    int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q,
-    int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl, int64_t workspace_size,
-    cudaStream_t stream) {
+    int* cum_seq_lens_kv, float* attention_sinks, float* lse, Data_type q_data_type,
+    Data_type kv_data_type, Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size,
+    int64_t max_q_len, int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads,
+    int64_t num_kv_heads, int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size,
+    int64_t q_stride_tokens, int64_t q_stride_heads, int64_t kv_stride_keys_values,
+    int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
+    double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr,
+    const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size,
+    int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k,
+    int64_t sm_count, bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
     err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -142,6 +142,7 @@ void trtllm_paged_attention_launcher(
   runner_params.mSumOfSeqLensQ = sum_seq_q;
   runner_params.ptrAttentionSinks = attention_sinks;
   runner_params.enable_pdl = enable_pdl;
+  runner_params.lsePtr = lse;

   // The sparse MLA parameters.
   runner_params.mSparseMla = sparse_mla_top_k > 0;
@@ -158,6 +159,19 @@ void trtllm_paged_attention_launcher(
     runner_params.cumSeqLensQPtr = cum_seq_lens_q;
     runner_params.cumSeqLensKvPtr = cum_seq_lens_kv;
+
+    // Allocate softmax stats buffer for LSE if requested
+    if (lse != nullptr) {
+      size_t max_batch_size = 8192;
+      size_t max_num_qo_heads = 256;
+      size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8);
+      runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc(
+          num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
+      runner_params.softmaxStatsPtr = float_allocator.aligned_alloc(
+          sizeof(float2) * num_qo_heads * sum_seq_q, 16, "trtllm_gen_softmax_workspace");
+      runner_params.multiCtasKvScratchPtr =
+          float_allocator.aligned_alloc(0, 16, "trtllm_gen_scratch_workspace");
+    }
   } else {
     // Generation.
     // Note that kernel names are still labeled as using a dense mask even when maskType is
@@ -177,10 +191,14 @@ void trtllm_paged_attention_launcher(
     size_t max_num_qo_heads = 256;  // todo(Yingyi): get from dlfw, in total 8MB
     size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8);
     // max 8MB, should align to 16 bytes
-    // semaphores be at the first 8MB of workspace buffer: counter | scratch
-    // todo(Yingyi): add softmax buffer later for lse return
+    // semaphores be at the first 8MB of workspace buffer: counter | softmax | scratch
     runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc(
         num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
+    // Allocate softmax stats buffer for LSE if requested
+    if (lse != nullptr) {
+      runner_params.softmaxStatsPtr = float_allocator.aligned_alloc(
+          sizeof(float2) * num_qo_heads * sum_seq_q, 16, "trtllm_gen_softmax_workspace");
+    }
     // scratch takes the rest of the workspace buffer
     runner_params.multiCtasKvScratchPtr =
         float_allocator.aligned_alloc(0, 16, "trtllm_gen_scratch_workspace");
@@ -227,7 +245,7 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
     int64_t batch_size, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count,
     bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
-    Optional<TensorView> cum_seq_lens_q) {
+    Optional<TensorView> cum_seq_lens_q, Optional<TensorView> lse) {
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
   auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
   TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
@@ -277,6 +295,11 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
         << "attention_sinks must be a float tensor";
     attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
   }
+  float* lse_ptr = nullptr;
+  if (lse.has_value()) {
+    TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
+    lse_ptr = static_cast<float*>(lse.value().data_ptr());
+  }
   auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
   auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
   auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<TensorView>();
@@ -300,10 +323,10 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
       out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
       workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
       static_cast<int*>(seq_lens.data_ptr()), cum_seq_lens_q_ptr,
-      /*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
-      TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool,
-      num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, q_stride_tokens,
-      q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
+      /*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type,
+      o_data_type, TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len,
+      num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
+      q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
       max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
       bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
       sparse_mla_top_k, sm_count, enable_pdl, workspace_size, stream);
@@ -316,7 +339,8 @@ void trtllm_paged_attention_context(
     Variant<double, TensorView> bmm1_scale, Variant<double, TensorView> bmm2_scale, double o_sf_scale,
     int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size, int64_t window_left,
     TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
-    bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
+    bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
+    Optional<TensorView> lse) {
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
   auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
   auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
@@ -358,6 +382,11 @@ void trtllm_paged_attention_context(
         << "attention_sinks must be a float tensor";
     attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
   }
+  float* lse_ptr = nullptr;
+  if (lse.has_value()) {
+    TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
+    lse_ptr = static_cast<float*>(lse.value().data_ptr());
+  }

   auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
   auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
@@ -385,7 +414,7 @@ void trtllm_paged_attention_context(
       static_cast<int*>(seq_lens.data_ptr()),
       /*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()),
       /*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr,
-      q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
+      lse_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
       max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
       head_dim_o, page_size, q_stride_tokens, q_stride_heads, kv_stride_keys_values,
       kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value,
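Before the Python wrappers, a quick sanity check of the workspace layout the launcher now carves out (counter | softmax | scratch). The sketch below mirrors the sizing arithmetic from the hunks above in plain Python; the 8192/256 caps, the round_up to 8, and the one-float2-per-(token, head) stats entry come from the diff, while `num_qo_heads` and `sum_seq_q` are made-up example values.

```python
# Hedged sketch: workspace bytes consumed ahead of the scratch region when LSE
# is requested, mirroring the sizing in trtllm_paged_attention_launcher above.

def round_up(x: int, m: int) -> int:
    return (x + m - 1) // m * m

MAX_BATCH_SIZE = 8192    # caps hard-coded in the launcher
MAX_NUM_QO_HEADS = 256
SIZEOF_UINT32 = 4
SIZEOF_FLOAT2 = 8        # one (float, float) softmax-stats entry per (token, head)

num_qo_heads, sum_seq_q = 32, 4096  # illustrative request shape, not from the diff

counter_bytes = round_up(MAX_BATCH_SIZE * MAX_NUM_QO_HEADS, 8) * SIZEOF_UINT32
softmax_bytes = SIZEOF_FLOAT2 * num_qo_heads * sum_seq_q

print(f"counter: {counter_bytes >> 20} MiB")        # 8 MiB, matching the in-code comment
print(f"softmax stats: {softmax_bytes >> 10} KiB")  # 1024 KiB for this shape
```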
diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index a879caf338..2352543922 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -2124,7 +2124,9 @@ def trtllm_batch_decode_with_kv_cache(
     mask: Optional[torch.Tensor] = None,
     max_q_len: Optional[int] = None,
     cum_seq_lens_q: Optional[torch.Tensor] = None,
-) -> Union[torch.Tensor, FP4Tensor]:
+    return_lse: bool = False,
+    lse: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]]:
     """
     Parameters
     ----------
@@ -2207,10 +2209,21 @@ def trtllm_batch_decode_with_kv_cache(
         Only supported by trtllm-gen backend.
         Must be provided together with ``max_q_len``.
         When None, all requests use uniform query length specified by ``q_len_per_req``.
+    return_lse : bool = False
+        Whether to return Log-Sum-Exp (LSE) values.
+        Only supported by trtllm-gen backend. XQA backend does not support LSE return.
+
+    lse : Optional[torch.Tensor] = None
+        LSE tensor to write into. If not provided and return_lse is True, a new tensor will be allocated.
+        Shape should be ``[num_tokens, num_heads]``, dtype: ``torch.float32``.
+
     Returns
     -------
-    out : Union[torch.Tensor, FP4Tensor]
-        output torch.Tensor or FP4Tensor.
+    out : Union[torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]]
+        If :attr:`return_lse` is ``False``, the attention output (torch.Tensor or FP4Tensor).
+        If :attr:`return_lse` is ``True``, a tuple of two tensors:
+        - The attention output (torch.Tensor or FP4Tensor)
+        - The LSE tensor (torch.Tensor with dtype float32)
     """

     enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
@@ -2240,6 +2253,8 @@
             raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size")
         if max_q_len is not None or cum_seq_lens_q is not None:
             raise ValueError("xqa backend does not support cum_seq_lens_q")
+        if return_lse:
+            raise ValueError("xqa backend does not support return_lse")

         # Handle out and out_dtype
         if out_dtype is None:
@@ -2363,6 +2378,12 @@
             assert max_q_len is not None
             batch_size = cum_seq_lens_q.size(0) - 1

+        # Allocate LSE tensor if return_lse is True and lse is not provided
+        if return_lse and lse is None:
+            lse = torch.empty(
+                query.shape[0], query.shape[1], device=query.device, dtype=torch.float32
+            )
+
         run_func(
             out,
             out_scale_factor,
@@ -2387,13 +2408,15 @@
             workspace_buffer.numel() * workspace_buffer.element_size(),
             sinks,
             cum_seq_lens_q,
+            lse if return_lse else None,
         )
-        return (
+        out_tensor = (
             out
             if out_dtype != "nvfp4"
             else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
         )
+        return (out_tensor, lse) if return_lse else out_tensor
     else:
         raise KeyError(f"Backend {backend} not supported")
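A hedged usage sketch of the decode-side change. Only ``return_lse`` and ``lse`` come from this diff; the positional arguments, the ``max_seq_len`` keyword, and the HND paged-cache layout are assumptions based on the existing ``trtllm_batch_decode_with_kv_cache`` interface, and all shapes are illustrative. It needs a GPU on which the trtllm-gen backend is selected (the xqa path now raises, per the hunk above).

```python
# Hedged sketch: calling trtllm_batch_decode_with_kv_cache with return_lse=True.
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache

batch_size, num_qo_heads, num_kv_heads, head_dim = 4, 32, 8, 128
page_size, blocks_per_seq = 16, 2
num_pages = batch_size * blocks_per_seq
seq_len = page_size * blocks_per_seq

query = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda", dtype=torch.bfloat16)
# Paged KV cache, HND layout: [num_pages, 2 (K/V), num_kv_heads, page_size, head_dim]
kv_cache = torch.randn(
    num_pages, 2, num_kv_heads, page_size, head_dim, device="cuda", dtype=torch.bfloat16
)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
block_tables = torch.arange(num_pages, dtype=torch.int32, device="cuda").reshape(
    batch_size, blocks_per_seq
)
seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")

# With return_lse=True the call returns (out, lse); lse is [num_tokens, num_heads] fp32.
out, lse = trtllm_batch_decode_with_kv_cache(
    query, kv_cache, workspace_buffer, block_tables, seq_lens,
    max_seq_len=seq_len, return_lse=True,
)

# Alternatively, pass a caller-owned buffer via lse= to skip the allocation.
lse_buf = torch.empty(batch_size, num_qo_heads, dtype=torch.float32, device="cuda")
out, lse = trtllm_batch_decode_with_kv_cache(
    query, kv_cache, workspace_buffer, block_tables, seq_lens,
    max_seq_len=seq_len, return_lse=True, lse=lse_buf,
)
```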
diff --git a/flashinfer/mla.py b/flashinfer/mla.py
index 83415521c3..7c88ca922e 100644
--- a/flashinfer/mla.py
+++ b/flashinfer/mla.py
@@ -528,7 +528,9 @@ def trtllm_batch_decode_with_kv_cache_mla(
     sinks: Optional[List[torch.Tensor]] = None,
     enable_pdl: bool = None,
     backend: str = "auto",
-) -> torch.Tensor:
+    return_lse: bool = False,
+    lse: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     Parameters
     ----------
@@ -554,6 +556,22 @@
         For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend.
         For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend.

+    return_lse : bool = False
+        Whether to return Log-Sum-Exp (LSE) values.
+        Only supported by trtllm-gen backend. XQA backend does not support LSE return.
+
+    lse : Optional[torch.Tensor] = None
+        LSE tensor to write into. If not provided and return_lse is True, a new tensor will be allocated.
+        Shape should be ``[batch_size * q_len_per_request, num_heads]``, dtype: ``torch.float32``.
+
+    Returns
+    -------
+    out : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+        If :attr:`return_lse` is ``False``, the attention output tensor.
+        If :attr:`return_lse` is ``True``, a tuple of two tensors:
+        - The attention output tensor
+        - The LSE tensor (torch.Tensor with dtype float32)
+
     Note
     ----
     In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
@@ -597,6 +615,8 @@
             raise ValueError(
                 f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}"
             )
+        if return_lse:
+            raise ValueError("XQA MLA backend does not support return_lse")
         return xqa_batch_decode_with_kv_cache_mla(
             query,
             kv_cache,
@@ -654,6 +674,13 @@
         max_q_len = query.size(1)
         query = query.flatten(0, 1)  # [B*S, H, D]

+        # Allocate LSE tensor if return_lse is True and lse is not provided
+        if return_lse and lse is None:
+            num_qo_heads = query.shape[1]
+            lse = torch.empty(
+                query.shape[0], num_qo_heads, device=query.device, dtype=torch.float32
+            )
+
         run_func(
             out,
             None,  # fp4 output not supported in wrapper api yet.
@@ -678,9 +705,10 @@
             workspace_buffer.numel() * workspace_buffer.element_size(),
             sinks,
             None,  # cum_seq_lens_q
+            lse if return_lse else None,
         )

-        return out
+        return (out, lse) if return_lse else out
     else:
         raise ValueError(f"Backend {backend} not supported")
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index bfdcfd9048..39c9e8082b 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -3487,7 +3487,9 @@ def trtllm_batch_context_with_kv_cache(
     kv_layout: str = "HND",
     enable_pdl: Optional[bool] = None,
     sinks: Optional[List[torch.Tensor]] = None,
-) -> Union[torch.Tensor, FP4Tensor]:
+    return_lse: bool = False,
+    lse: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]]:
     """
     Parameters
     ----------
@@ -3540,10 +3542,20 @@
     sinks : Optional[List[torch.Tensor]] = None
         additional value per head in the denominator of the softmax.

+    return_lse : bool = False
+        Whether to return Log-Sum-Exp (LSE) values.
+
+    lse : Optional[torch.Tensor] = None
+        LSE tensor to write into. If not provided and return_lse is True, a new tensor will be allocated.
+        Shape should be ``[num_tokens, num_heads]``, dtype: ``torch.float32``.
+
     Returns
     -------
-    out: Union[torch.Tensor, FP4Tensor]
-        output torch.Tensor or FP4Tensor.
+    out: Union[torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]]
+        If :attr:`return_lse` is ``False``, the attention output (torch.Tensor or FP4Tensor).
+        If :attr:`return_lse` is ``True``, a tuple of two tensors:
+        - The attention output (torch.Tensor or FP4Tensor)
+        - The LSE tensor (torch.Tensor with dtype float32)
     """

     if enable_pdl is None:
@@ -3649,6 +3661,13 @@
     if isinstance(bmm2_scale, torch.Tensor):
         assert bmm2_scale.dtype == torch.float32
     workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
+
+    # Allocate LSE tensor if return_lse is True and lse is not provided
+    if return_lse and lse is None:
+        lse = torch.empty(
+            query.shape[0], query.shape[1], device=query.device, dtype=torch.float32
+        )
+
     run_func(
         out,
         out_scale_factor,
@@ -3673,12 +3692,14 @@
         enable_pdl,
         workspace_size,
         sinks,
+        lse if return_lse else None,
     )
-    return (
+    out_tensor = (
         out
         if out_dtype != "nvfp4"
         else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
    )
+    return (out_tensor, lse) if return_lse else out_tensor


 @flashinfer_api
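Finally, since all three wrappers return LSE flattened to ``[num_tokens, num_heads]``, here is a small sketch of splitting it back into per-request chunks for the variable-length (``cum_seq_lens_q``) case. The boundary values are hypothetical; only the flattened layout and the ``cum_seq_lens_q`` semantics come from the docstrings above.

```python
# Hedged sketch: recovering per-request [q_len_i, num_heads] LSE slices from the
# flattened [num_tokens, num_heads] tensor using cum_seq_lens_q boundaries.
import torch

num_heads = 8
cum_seq_lens_q = torch.tensor([0, 3, 7, 12])  # three requests with q_len 3, 4, 5
lse = torch.randn(int(cum_seq_lens_q[-1]), num_heads)  # as returned with return_lse=True

per_request = [
    lse[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1]]
    for i in range(cum_seq_lens_q.numel() - 1)
]
print([tuple(t.shape) for t in per_request])  # [(3, 8), (4, 8), (5, 8)]
```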