24 changes: 18 additions & 6 deletions flashinfer/mla.py
@@ -75,8 +75,14 @@ def _check_trtllm_gen_mla_shape(
):
if query.ndim != 4:
raise ValueError(f"Expected query.ndim == 4, got {query.ndim}")
if kv_cache.ndim != 4:
raise ValueError(f"Expected kv_cache.ndim == 4, got {kv_cache.ndim}")

# Support both 3D and 4D kv_cache for backward compatibility
if kv_cache.ndim == 3:
# [num_pages, page_size, head_dim_ckv + head_dim_kpe] -> [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]
kv_cache = kv_cache.unsqueeze(1)
elif kv_cache.ndim != 4:
raise ValueError(f"Expected kv_cache.ndim == 3 or 4, got {kv_cache.ndim}")
Comment on lines +80 to +84
Contributor


medium

This logic is correct. For improved clarity and to separate validation from transformation, you could first check if the dimension is valid and then perform the normalization. This makes the intent of each block of code clearer.

Suggested change
if kv_cache.ndim == 3:
# [num_pages, page_size, head_dim_ckv + head_dim_kpe] -> [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]
kv_cache = kv_cache.unsqueeze(1)
elif kv_cache.ndim != 4:
raise ValueError(f"Expected kv_cache.ndim == 3 or 4, got {kv_cache.ndim}")
if kv_cache.ndim not in (3, 4):
raise ValueError(f"Expected kv_cache.ndim to be 3 or 4, got {kv_cache.ndim}")
if kv_cache.ndim == 3:
# [num_pages, page_size, head_dim_ckv + head_dim_kpe] -> [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]
kv_cache = kv_cache.unsqueeze(1)
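
For reference, a minimal standalone sketch of the suggested validate-then-normalize pattern (the helper name here is hypothetical, 576 = 512 + 64 assumes a kpe head dim of 64, and the real helper's other shape checks are omitted):

import torch

def normalize_mla_kv_cache(kv_cache: torch.Tensor) -> torch.Tensor:
    # Validate first: anything other than 3D or 4D is rejected up front.
    if kv_cache.ndim not in (3, 4):
        raise ValueError(f"Expected kv_cache.ndim to be 3 or 4, got {kv_cache.ndim}")
    # Then normalize: lift the legacy 3D layout
    # [num_pages, page_size, d] to [num_pages, 1, page_size, d].
    if kv_cache.ndim == 3:
        kv_cache = kv_cache.unsqueeze(1)
    return kv_cache

# Both accepted layouts come out as the same 4D shape.
assert normalize_mla_kv_cache(torch.empty(8, 64, 576)).shape == (8, 1, 64, 576)
assert normalize_mla_kv_cache(torch.empty(8, 1, 64, 576)).shape == (8, 1, 64, 576)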


if qk_nope_head_dim != 128:
raise ValueError(f"Expected qk_nope_head_dim == 128, got {qk_nope_head_dim}")
if kv_lora_rank != 512:
@@ -112,6 +118,8 @@ def _check_trtllm_gen_mla_shape(
f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}"
)

return kv_cache
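
Note that unsqueeze is not an in-place operation, which is why the helper must now return the cache and both call sites below reassign it; a quick demonstration (shape hypothetical):

import torch

kv = torch.empty(8, 64, 576)
kv.unsqueeze(1)        # returns a new 4D view; kv itself is unchanged
assert kv.ndim == 3
kv = kv.unsqueeze(1)   # callers must capture the result
assert kv.shape == (8, 1, 64, 576)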


@functools.cache
def get_trtllm_gen_fmha_module():
@@ -533,7 +541,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
Parameters
----------
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be the concatenation of q_nope and q_rope; q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe] or [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe], should be the concatenation of ckv_cache and kpe_cache. Both 3D and 4D formats are supported for backward compatibility.
workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use.
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
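
To make the documented layouts concrete, here is a shape-only sketch (all sizes are hypothetical except the required qk_nope_head_dim = 128 and kv_lora_rank = 512; a qk_rope_head_dim of 64 is assumed):

import torch

batch_size, q_len_per_request, num_heads = 2, 1, 128  # hypothetical
kv_lora_rank, qk_rope_head_dim = 512, 64              # 512 required; 64 assumed
num_pages, page_size = 16, 64                         # hypothetical

# query is q_nope and q_rope concatenated along the last dimension
q_nope = torch.randn(batch_size, q_len_per_request, num_heads, kv_lora_rank)
q_rope = torch.randn(batch_size, q_len_per_request, num_heads, qk_rope_head_dim)
query = torch.cat([q_nope, q_rope], dim=-1)  # head_dim_qk = 512 + 64 = 576

# kv_cache is ckv_cache and kpe_cache concatenated along the last dimension;
# both the legacy 3D layout and the 4D layout are accepted
kv_cache_3d = torch.randn(num_pages, page_size, kv_lora_rank + qk_rope_head_dim)
kv_cache_4d = kv_cache_3d.unsqueeze(1)  # [num_pages, 1, page_size, 576]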
@@ -620,13 +628,15 @@ def trtllm_batch_decode_with_kv_cache_mla(
run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
sm_count = get_device_sm_count(query.device)

# Extract block_size (works for both 3D and 4D)
block_size = kv_cache.size(-2)
if (
block_size != 32 and block_size != 64
): # todo(Yingyi): add support for more block sizes?
raise ValueError(f"Supported block_size are 32 and 64, got {block_size}")

_check_trtllm_gen_mla_shape(
# Validate and normalize to 4D
kv_cache = _check_trtllm_gen_mla_shape(
query,
kv_cache,
qk_nope_head_dim,
@@ -705,7 +715,7 @@ def xqa_batch_decode_with_kv_cache_mla(
"""
Parameters:
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be the concatenation of q_nope and q_rope; q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe] or [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe], should be the concatenation of ckv_cache and kpe_cache. Both 3D and 4D formats are supported for backward compatibility.
workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use.
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
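
Both decode entry points read the page size with kv_cache.size(-2); a quick check that this is layout-independent (shapes hypothetical, kpe head dim of 64 assumed):

import torch

d = 512 + 64                     # head_dim_ckv + head_dim_kpe
kv3 = torch.empty(8, 64, d)      # legacy 3D: [num_pages, page_size, d]
kv4 = torch.empty(8, 1, 64, d)   # 4D: [num_pages, 1, page_size, d]
# page_size is the second-to-last dimension in both layouts,
# so size(-2) needs no ndim branch.
assert kv3.size(-2) == kv4.size(-2) == 64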
@@ -736,6 +746,7 @@ def xqa_batch_decode_with_kv_cache_mla(
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
sm_count = get_device_sm_count(query.device)

# Extract block_size (works for both 3D and 4D)
block_size = kv_cache.size(-2)
q_len_per_request = query.size(1)
if q_len_per_request != 1:
@@ -749,7 +760,8 @@
if sinks is not None:
raise ValueError("XQA MLA does not support sinks")

_check_trtllm_gen_mla_shape(
# Validate and normalize to 4D
kv_cache = _check_trtllm_gen_mla_shape(
query,
kv_cache,
qk_nope_head_dim,