67 changes: 48 additions & 19 deletions csrc/trtllm_fmha_kernel_launcher.cu
@@ -76,16 +76,16 @@ class TllmGenFmhaRunnerCache {
void trtllm_paged_attention_launcher(
void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache,
void* workspace_buffer, int* block_tables, int* seq_lens, int* cum_seq_lens_q,
- int* cum_seq_lens_kv, float* attention_sinks, Data_type q_data_type, Data_type kv_data_type,
- Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len,
- int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
- int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens,
- int64_t q_stride_heads, int64_t kv_stride_keys_values, int64_t kv_stride_heads,
- int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
- const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
- int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q,
- int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl, int64_t workspace_size,
- cudaStream_t stream) {
+ int* cum_seq_lens_kv, float* attention_sinks, float* lse, Data_type q_data_type,
+ Data_type kv_data_type, Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size,
+ int64_t max_q_len, int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads,
+ int64_t num_kv_heads, int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size,
+ int64_t q_stride_tokens, int64_t q_stride_heads, int64_t kv_stride_keys_values,
+ int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
+ double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr,
+ const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size,
+ int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k,
+ int64_t sm_count, bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -142,6 +142,7 @@ void trtllm_paged_attention_launcher(
runner_params.mSumOfSeqLensQ = sum_seq_q;
runner_params.ptrAttentionSinks = attention_sinks;
runner_params.enable_pdl = enable_pdl;
+ runner_params.lsePtr = lse;

// The sparse MLA parameters.
runner_params.mSparseMla = sparse_mla_top_k > 0;
@@ -158,6 +159,19 @@

runner_params.cumSeqLensQPtr = cum_seq_lens_q;
runner_params.cumSeqLensKvPtr = cum_seq_lens_kv;

+ // Allocate softmax stats buffer for LSE if requested
+ if (lse != nullptr) {
+ size_t max_batch_size = 8192;
+ size_t max_num_qo_heads = 256;
+ size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8);
+ runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
+ num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
+ runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
+ sizeof(float2) * num_qo_heads * sum_seq_q, 16, "trtllm_gen_softmax_workspace");
+ runner_params.multiCtasKvScratchPtr =
+ float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
+ }
Reviewer comment on lines +164 to +174 (from a Contributor; severity: medium):

This block for allocating LSE-related buffers has a couple of areas for improvement:

  1. Hardcoded values: max_batch_size and max_num_qo_heads are hardcoded. This can lead to inefficient workspace memory allocation. It's better to use the dynamic batch_size and num_qo_heads values from the function arguments to make the allocation more precise.

  2. Unclear allocations: The preceding comment mentions allocating the softmaxStatsPtr, but this block also allocates multiCtasKvCounterPtr and multiCtasKvScratchPtr. These seem related to multi-CTA mode, which is disabled for context attention. If these are not needed for LSE computation in context mode, they should be removed. If they are needed, the comment should be updated for clarity.

The suggestion below addresses the hardcoded values.

    if (lse != nullptr) {
      size_t num_semaphores = round_up(batch_size * num_qo_heads, 8);
      runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
          num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
      runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
          sizeof(float2) * num_qo_heads * sum_seq_q, 16, "trtllm_gen_softmax_workspace");
      runner_params.multiCtasKvScratchPtr =
          float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
    }

} else {
// Generation.
// Note that kernel names are still labeled as using a dense mask even when maskType is
@@ -177,10 +191,14 @@
size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB
size_t num_semaphores =
round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes
- // semaphores be at the first 8MB of workspace buffer: counter | scratch
- // todo(Yingyi): add softmax buffer later for lse return
+ // semaphores be at the first 8MB of workspace buffer: counter | softmax | scratch
runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
+ // Allocate softmax stats buffer for LSE if requested
+ if (lse != nullptr) {
+ runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
+ sizeof(float2) * num_qo_heads * sum_seq_q, 16, "trtllm_gen_softmax_workspace");
+ }
// scratch takes the rest of the workspace buffer
runner_params.multiCtasKvScratchPtr =
float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
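Aside (not part of the diff): the arithmetic behind the workspace carving above. The hardcoded counter bounds give exactly the 8MB the todo comment mentions; the per-entry float2 layout, presumably a running max and a running sum, is an assumption about TRT-LLM's softmax stats.

```python
# Sketch of the workspace regions carved out in the generation path:
# counter | softmax (optional) | scratch.
def round_up(x: int, m: int) -> int:
    return (x + m - 1) // m * m

max_batch_size = 8192     # hardcoded bound from the diff
max_num_qo_heads = 256    # hardcoded bound from the diff

# Counter region: one uint32 semaphore per (batch, head) slot.
num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8)
counter_bytes = num_semaphores * 4
assert counter_bytes == 8 * 1024 * 1024   # the "8MB" in the todo comment

# Softmax-stats region, allocated only when an LSE output is requested:
# one float2 (8 bytes; presumably running max and sum) per head per token.
num_qo_heads, sum_seq_q = 32, 4096        # example runtime values
softmax_bytes = 8 * num_qo_heads * sum_seq_q

# Everything after these two regions serves as multi-CTA scratch.
print(counter_bytes, softmax_bytes)
```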
@@ -227,7 +245,7 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
int64_t batch_size, int64_t window_left,
int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl,
int64_t workspace_size, Optional<TensorView> attention_sinks,
- Optional<TensorView> cum_seq_lens_q) {
+ Optional<TensorView> cum_seq_lens_q, Optional<TensorView> lse) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
@@ -277,6 +295,11 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
<< "attention_sinks must be a float tensor";
attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
}
+ float* lse_ptr = nullptr;
+ if (lse.has_value()) {
+ TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
+ lse_ptr = static_cast<float*>(lse.value().data_ptr());
+ }
auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
@@ -300,10 +323,10 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
static_cast<int*>(seq_lens.data_ptr()), cum_seq_lens_q_ptr,
- /*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
- TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool,
- num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, q_stride_tokens,
- q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
+ /*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type,
+ o_data_type, TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len,
+ num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
+ q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
sparse_mla_top_k, sm_count, enable_pdl, workspace_size, stream);
@@ -316,7 +339,8 @@ void trtllm_paged_attention_context(
Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
- bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
+ bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
+ Optional<TensorView> lse) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
@@ -358,6 +382,11 @@ void trtllm_paged_attention_context(
<< "attention_sinks must be a float tensor";
attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
}
+ float* lse_ptr = nullptr;
+ if (lse.has_value()) {
+ TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
+ lse_ptr = static_cast<float*>(lse.value().data_ptr());
+ }

auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
@@ -385,7 +414,7 @@ void trtllm_paged_attention_context(
static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()),
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr,
- q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
+ lse_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
head_dim_o, page_size, q_stride_tokens, q_stride_heads, kv_stride_keys_values,
kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value,
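Stepping back from the launcher changes: the LSE values threaded through `lsePtr`/`softmaxStatsPtr` are the per-query-token, per-head log-sum-exp of the attention logits. Their standard use is merging partial attention outputs computed over disjoint KV chunks (e.g. chunked prefill or cross-device attention). A minimal sketch of that merge, assuming natural-log LSE (the kernel's base convention should be verified):

```python
import torch

def merge_partial_attention(out_a, lse_a, out_b, lse_b):
    # out_*: [num_tokens, num_heads, head_dim] partial outputs over
    # disjoint KV chunks; lse_*: [num_tokens, num_heads] matching LSEs.
    lse = torch.logaddexp(lse_a, lse_b)         # combined log-sum-exp
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)  # renormalization weight A
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)  # renormalization weight B
    return w_a * out_a + w_b * out_b, lse
```

The same identity is what split-KV reductions rely on internally, which is consistent with the softmax stats sharing the workspace with the multi-CTA counters above.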
31 changes: 27 additions & 4 deletions flashinfer/decode.py
@@ -2124,7 +2124,9 @@ def trtllm_batch_decode_with_kv_cache(
mask: Optional[torch.Tensor] = None,
max_q_len: Optional[int] = None,
cum_seq_lens_q: Optional[torch.Tensor] = None,
- ) -> Union[torch.Tensor, FP4Tensor]:
+ return_lse: bool = False,
+ lse: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]]:
"""
Parameters
----------
@@ -2207,10 +2209,21 @@ def trtllm_batch_decode_with_kv_cache(
Only supported by trtllm-gen backend. Must be provided together with ``max_q_len``.
When None, all requests use uniform query length specified by ``q_len_per_req``.

+ return_lse : bool = False
+ Whether to return Log-Sum-Exp (LSE) values.
+ Only supported by trtllm-gen backend. XQA backend does not support LSE return.
+
+ lse : Optional[torch.Tensor] = None
+ LSE tensor to write into. If not provided and return_lse is True, a new tensor will be allocated.
+ Shape should be ``[num_tokens, num_heads]``, dtype: ``torch.float32``.

Returns
-------
- out : Union[torch.Tensor, FP4Tensor]
- output torch.Tensor or FP4Tensor.
+ out : Union[torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]]
+ If :attr:`return_lse` is ``False``, the attention output (torch.Tensor or FP4Tensor).
+ If :attr:`return_lse` is ``True``, a tuple of two tensors:
+ - The attention output (torch.Tensor or FP4Tensor)
+ - The LSE tensor (torch.Tensor with dtype float32)
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl

@@ -2240,6 +2253,8 @@ def trtllm_batch_decode_with_kv_cache(
raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size")
if max_q_len is not None or cum_seq_lens_q is not None:
raise ValueError("xqa backend does not support cum_seq_lens_q")
+ if return_lse:
+ raise ValueError("xqa backend does not support return_lse")

# Handle out and out_dtype
if out_dtype is None:
@@ -2363,6 +2378,12 @@ def trtllm_batch_decode_with_kv_cache(
assert max_q_len is not None
batch_size = cum_seq_lens_q.size(0) - 1

+ # Allocate LSE tensor if return_lse is True and lse is not provided
+ if return_lse and lse is None:
+ lse = torch.empty(
+ query.shape[0], query.shape[1], device=query.device, dtype=torch.float32
+ )

run_func(
out,
out_scale_factor,
@@ -2387,13 +2408,15 @@ def trtllm_batch_decode_with_kv_cache(
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
cum_seq_lens_q,
+ lse if return_lse else None,
)

- return (
+ out_tensor = (
out
if out_dtype != "nvfp4"
else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
)
+ return (out_tensor, lse) if return_lse else out_tensor
else:
raise KeyError(f"Backend {backend} not supported")
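A usage sketch for the extended decode API. The leading positional arguments mirror the existing signature as I read it, and everything else the call requires passes through untouched, so treat the argument list as illustrative rather than authoritative:

```python
from flashinfer.decode import trtllm_batch_decode_with_kv_cache

def decode_with_lse(query, kv_cache, workspace_buffer, block_tables,
                    seq_lens, max_seq_len, **extra):
    # Inputs are prepared exactly as for the existing API; return_lse=True
    # switches the return value from `out` to the tuple (out, lse), with
    # lse of shape [num_tokens, num_heads] and dtype float32.
    out, lse = trtllm_batch_decode_with_kv_cache(
        query, kv_cache, workspace_buffer, block_tables, seq_lens,
        max_seq_len, return_lse=True, **extra
    )
    return out, lse
```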

32 changes: 30 additions & 2 deletions flashinfer/mla.py
@@ -528,7 +528,9 @@ def trtllm_batch_decode_with_kv_cache_mla(
sinks: Optional[List[torch.Tensor]] = None,
enable_pdl: bool = None,
backend: str = "auto",
- ) -> torch.Tensor:
+ return_lse: bool = False,
+ lse: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Parameters
----------
@@ -554,6 +556,22 @@
For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend.
For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend.

+ return_lse : bool = False
+ Whether to return Log-Sum-Exp (LSE) values.
+ Only supported by trtllm-gen backend. XQA backend does not support LSE return.
+
+ lse : Optional[torch.Tensor] = None
+ LSE tensor to write into. If not provided and return_lse is True, a new tensor will be allocated.
+ Shape should be ``[batch_size * q_len_per_request, num_heads]``, dtype: ``torch.float32``.
+
+ Returns
+ -------
+ out : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+ If :attr:`return_lse` is ``False``, the attention output tensor.
+ If :attr:`return_lse` is ``True``, a tuple of two tensors:
+ - The attention output tensor
+ - The LSE tensor (torch.Tensor with dtype float32)

Note
----
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
@@ -597,6 +615,8 @@
raise ValueError(
f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}"
)
+ if return_lse:
+ raise ValueError("XQA MLA backend does not support return_lse")
return xqa_batch_decode_with_kv_cache_mla(
query,
kv_cache,
@@ -654,6 +674,13 @@
max_q_len = query.size(1)
query = query.flatten(0, 1) # [B*S, H, D]

+ # Allocate LSE tensor if return_lse is True and lse is not provided
+ if return_lse and lse is None:
+ num_qo_heads = query.shape[1]
+ lse = torch.empty(
+ query.shape[0], num_qo_heads, device=query.device, dtype=torch.float32
+ )

run_func(
out,
None, # fp4 output not supported in wrapper api yet.
@@ -678,9 +705,10 @@
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
None, # cum_seq_lens_q
+ lse if return_lse else None,
)

- return out
+ return (out, lse) if return_lse else out
else:
raise ValueError(f"Backend {backend} not supported")
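For the MLA path, the new `lse` argument additionally lets the caller own the buffer. A sketch under the same caveats as above (only the first three positional arguments are assumed; the remaining MLA arguments pass through unchanged):

```python
import torch
from flashinfer.mla import trtllm_batch_decode_with_kv_cache_mla

def mla_decode_with_lse(query, kv_cache, workspace_buffer, **mla_kwargs):
    # query: [batch_size, q_len_per_request, num_heads, head_dim].
    # Preallocate the flattened [B*S, H] float32 buffer so the wrapper
    # writes into it instead of allocating internally via torch.empty.
    b, s, h, _ = query.shape
    lse = torch.empty(b * s, h, device=query.device, dtype=torch.float32)
    out, lse = trtllm_batch_decode_with_kv_cache_mla(
        query, kv_cache, workspace_buffer,
        return_lse=True, lse=lse, **mla_kwargs
    )
    return out, lse
```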

29 changes: 25 additions & 4 deletions flashinfer/prefill.py
@@ -3487,7 +3487,9 @@ def trtllm_batch_context_with_kv_cache(
kv_layout: str = "HND",
enable_pdl: Optional[bool] = None,
sinks: Optional[List[torch.Tensor]] = None,
- ) -> Union[torch.Tensor, FP4Tensor]:
+ return_lse: bool = False,
+ lse: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]]:
"""
Parameters
----------
@@ -3540,10 +3542,20 @@
sinks : Optional[List[torch.Tensor]] = None
additional value per head in the denominator of the softmax.

+ return_lse : bool = False
+ Whether to return Log-Sum-Exp (LSE) values.
+
+ lse : Optional[torch.Tensor] = None
+ LSE tensor to write into. If not provided and return_lse is True, a new tensor will be allocated.
+ Shape should be ``[num_tokens, num_heads]``, dtype: ``torch.float32``.

Returns
-------
- out: Union[torch.Tensor, FP4Tensor]
- output torch.Tensor or FP4Tensor.
+ out: Union[torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]]
+ If :attr:`return_lse` is ``False``, the attention output (torch.Tensor or FP4Tensor).
+ If :attr:`return_lse` is ``True``, a tuple of two tensors:
+ - The attention output (torch.Tensor or FP4Tensor)
+ - The LSE tensor (torch.Tensor with dtype float32)
"""

if enable_pdl is None:
@@ -3649,6 +3661,13 @@
if isinstance(bmm2_scale, torch.Tensor):
assert bmm2_scale.dtype == torch.float32
workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()

+ # Allocate LSE tensor if return_lse is True and lse is not provided
+ if return_lse and lse is None:
+ lse = torch.empty(
+ query.shape[0], query.shape[1], device=query.device, dtype=torch.float32
+ )

run_func(
out,
out_scale_factor,
@@ -3673,12 +3692,14 @@
enable_pdl,
workspace_size,
sinks,
+ lse if return_lse else None,
)
- return (
+ out_tensor = (
out
if out_dtype != "nvfp4"
else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
)
+ return (out_tensor, lse) if return_lse else out_tensor


@flashinfer_api
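Finally, the context (prefill) path produces one LSE per packed query token, matching the ragged layout described by `cum_seq_lens_q`. A sketch, again assuming only the leading positional arguments:

```python
import torch
from flashinfer.prefill import trtllm_batch_context_with_kv_cache

def context_with_lse(query, kv_cache, workspace_buffer, **ctx_kwargs):
    # query: [num_tokens, num_heads, head_dim], packed across the batch;
    # ctx_kwargs carries the remaining required arguments of the existing
    # API (block tables, sequence lengths, cumulative lengths, scales, ...).
    out, lse = trtllm_batch_context_with_kv_cache(
        query, kv_cache, workspace_buffer, return_lse=True, **ctx_kwargs
    )
    # One LSE per packed query token and head, aligned with `query` rows.
    assert lse.shape == (query.shape[0], query.shape[1])
    assert lse.dtype == torch.float32
    return out, lse
```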