tests/v1/attention/test_mla_backends.py (3 additions, 4 deletions)

@@ -27,7 +27,7 @@
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.attention.backends.mla.common import QueryLenSupport
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
-from vllm.v1.kv_cache_interface import MLAAttentionSpec
+from vllm.v1.kv_cache_interface import FullAttentionSpec

 BACKENDS_TO_TEST = [
     AttentionBackendEnum.CUTLASS_MLA,
@@ -289,7 +289,7 @@ def get_kv_cache_spec(self, vllm_config):

 def run_attention_backend(
     backend: AttentionBackendEnum,
-    kv_cache_spec: MLAAttentionSpec,
+    kv_cache_spec: FullAttentionSpec,
     layer_names: list[str],
     vllm_config,
     device: torch.device,
@@ -740,15 +740,14 @@ def test_backend_correctness(
         kv_cache = kv_cache_per_block_size[block_size]

         # Create kv_cache_spec with the correct block_size for this backend
-        backend_kv_cache_spec = MLAAttentionSpec(
+        backend_kv_cache_spec = FullAttentionSpec(
             block_size=block_size,
             num_kv_heads=vllm_config.model_config.get_num_kv_heads(
                 vllm_config.parallel_config
             ),
             head_size=vllm_config.model_config.get_head_size(),
             dtype=vllm_config.model_config.dtype,
             sliding_window=vllm_config.model_config.get_sliding_window(),
-            cache_dtype_str=vllm_config.cache_config.cache_dtype,
         )

         backend_output = run_attention_backend(
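Note on the spec swap above: the test now constructs a plain FullAttentionSpec and no longer passes cache_dtype_str. Below is a minimal standalone sketch of that construction; FullAttentionSpecSketch is a hypothetical stand-in carrying only the fields the updated test passes (the real class is vllm.v1.kv_cache_interface.FullAttentionSpec and may have more fields).

from dataclasses import dataclass

import torch

# Hypothetical stand-in; mirrors only the fields exercised by the updated test.
@dataclass
class FullAttentionSpecSketch:
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    sliding_window: int | None = None

# Example MLA-style values (DeepSeek-style models expose one latent KV head,
# with head_size = kv_lora_rank + qk_rope_head_dim, e.g. 512 + 64 = 576).
spec = FullAttentionSpecSketch(
    block_size=16,
    num_kv_heads=1,
    head_size=576,
    dtype=torch.bfloat16,
)
print(spec)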
(unnamed file)

@@ -325,6 +325,7 @@ def flashinfer_trtllm_fp4_moe(
         local_expert_offset=layer.ep_rank * layer.local_num_experts,
         local_num_experts=layer.local_num_experts,
         routed_scaling_factor=None,
+        tile_tokens_dim=None,
         routing_method_type=routing_method_type,
         do_finalize=True,
     )[0]
vllm/v1/attention/backends/mla/common.py (14 additions, 112 deletions)

@@ -541,11 +541,6 @@ def __init__(
             metadata_cls if metadata_cls is not None else MLACommonMetadata
         )
         self.kv_cache_spec = kv_cache_spec
-        self.q_data_type = (
-            current_platform.fp8_dtype()
-            if (kv_cache_spec is not None and "fp8" in kv_cache_spec.cache_dtype_str)
-            else vllm_config.model_config.dtype
-        )
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
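Context for the deletion above: the builder previously derived the query dtype from the KV-cache dtype string, so queries ran in fp8 whenever the cache was an fp8 format; after this change the model dtype is used unconditionally. A standalone sketch of the removed selection logic (torch.float8_e4m3fn stands in for current_platform.fp8_dtype(), whose result is platform-dependent):

import torch

def select_q_dtype(cache_dtype_str: str | None, model_dtype: torch.dtype) -> torch.dtype:
    # Removed pattern: run the query in the platform fp8 dtype when the
    # KV cache is an fp8 format, otherwise use the model's compute dtype.
    if cache_dtype_str is not None and "fp8" in cache_dtype_str:
        return torch.float8_e4m3fn  # stand-in for current_platform.fp8_dtype()
    return model_dtype

print(select_q_dtype("fp8_e4m3", torch.bfloat16))  # torch.float8_e4m3fn
print(select_q_dtype("auto", torch.bfloat16))      # torch.bfloat16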
@@ -689,6 +684,7 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):

         # For main run, qo_indptr == kv_indptr
         kv_indptr = qo_indptr.clone()
+
         # Prepare main prefill
         self._fi_prefill_main.plan(
             qo_indptr=qo_indptr,
@@ -701,7 +697,7 @@
             sm_scale=self._global_hyperparameters.sm_scale,
             window_left=self._global_hyperparameters.window_left,
             logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
-            q_data_type=self.q_data_type,
+            q_data_type=self.model_config.dtype,
         )

         # Prepare context prefills
@@ -720,7 +716,7 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
             sm_scale=self._global_hyperparameters.sm_scale,
             window_left=self._global_hyperparameters.window_left,
             logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
-            q_data_type=self.q_data_type,
+            q_data_type=self.model_config.dtype,
         )

         prefill.prefill_main = self._fi_prefill_main
@@ -973,7 +969,6 @@ def build(
                 query_start_loc=prefill_query_start_loc,
                 max_query_len=max_query_len,
                 chunked_context=chunked_context_metadata,
-                q_data_type=self.q_data_type,
             )

             if self._use_cudnn_prefill:
@@ -1384,15 +1379,8 @@ def _flash_attn_varlen_diff_headdims(
         return attn_out

     def _run_prefill_new_tokens_fa(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running FlashAttention prefill new tokens", scope="local")
         return self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
@@ -1407,23 +1395,11 @@ def _run_prefill_new_tokens_fa(
         )

     def _run_prefill_new_tokens_fi(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running FlashInfer prefill new tokens", scope="local")
         assert isinstance(prefill, FlashInferPrefillMetadata)
         assert prefill.prefill_main is not None
-        if fp8_attention:
-            logger.debug_once("Running Flashinfer prefill in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)

         ret = prefill.prefill_main.run(
             q=q,
             k=k,
@@ -1436,18 +1412,10 @@
         return ret

     def _run_prefill_new_tokens_cudnn(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running Cudnn prefill new tokens", scope="local")
         assert isinstance(prefill, CudnnPrefillMetadata)
         assert prefill.query_seq_lens is not None
-        assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
         output, lse = cudnn_batch_prefill_with_kv_cache(
             q=q,
             k_cache=k,
@@ -1469,19 +1437,9 @@
         return output

     def _run_prefill_context_chunk_fa(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running FlashAttention prefill context chunk", scope="local")
         assert prefill.chunked_context is not None
-        assert fp8_attention is False, (
-            "FlashAttention prefill does not support fp8 attention"
-        )
         return self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
@@ -1496,22 +1454,10 @@ def _run_prefill_context_chunk_fa(
         )

     def _run_prefill_context_chunk_fi(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running FlashInfer prefill context chunk", scope="local")
         assert isinstance(prefill, FlashInferPrefillMetadata)
-        if fp8_attention:
-            logger.debug_once("Running FlashInfer prefill in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)

         attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
             q=q,
             k=k,
@@ -1523,20 +1469,12 @@ def _run_prefill_context_chunk_fi(
         return attn_out, lse.transpose(0, 1).contiguous()

     def _run_prefill_context_chunk_cudnn(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running Cudnn prefill context chunk", scope="local")
         assert isinstance(prefill, CudnnPrefillMetadata)
         assert prefill.chunked_context is not None
         assert prefill.chunked_context.seq_lens[chunk_idx] is not None
         assert prefill.query_seq_lens is not None
-        assert fp8_attention is False, "Cudnn prefill does not support fp8 attention"
         return cudnn_batch_prefill_with_kv_cache(
             q=q,
             k_cache=k,
@@ -1556,28 +1494,14 @@ def _run_prefill_context_chunk_cudnn(
         )

     def _run_prefill_new_tokens_trtllm_ragged(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        q,
-        k,
-        v,
-        return_softmax_lse,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
     ):
-        logger.debug_once("Running TRT-LLM ragged prefill new tokens", scope="local")
         """TRT-LLM ragged attention for new tokens (causal)."""
         from flashinfer.prefill import trtllm_ragged_attention_deepseek

         assert prefill.query_seq_lens is not None
         assert prefill.workspace_buffer is not None

-        if fp8_attention:
-            logger.debug_once("Running TRT-LLM ragged prefill in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)
-
         ret = trtllm_ragged_attention_deepseek(
             query=q,
             key=k,
@@ -1604,15 +1528,8 @@
         return ret

     def _run_prefill_context_chunk_trtllm_ragged(
-        self,
-        prefill: MLACommonPrefillMetadata,
-        chunk_idx: int,
-        q,
-        k,
-        v,
-        fp8_attention: bool,
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
-        logger.debug_once("Running TRT-LLM ragged prefill context chunk", scope="local")
         """TRT-LLM ragged attention for context chunks (non-causal)."""
         from flashinfer.prefill import trtllm_ragged_attention_deepseek

@@ -1629,13 +1546,6 @@
         )
         prefill.workspace_buffer.fill_(0)

-        if fp8_attention:
-            logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
-            fp8_dtype = current_platform.fp8_dtype()
-            q = q.to(fp8_dtype)
-            k = k.to(fp8_dtype)
-            v = v.to(fp8_dtype)
-
         attn_out, lse = trtllm_ragged_attention_deepseek(
             query=q,
             key=k,
@@ -1788,7 +1698,6 @@ def _compute_prefill_context(
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
-        fp8_attention: bool,
     ):
         assert attn_metadata.prefill is not None
         prefill_metadata = attn_metadata.prefill
@@ -1827,7 +1736,6 @@
                 q=q,
                 k=k,
                 v=v,
-                fp8_attention=fp8_attention,
             )

             if output is None:
@@ -1856,7 +1764,6 @@ def _context_parallel_compute_prefill_context(
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
         dcp_world_size: int,
-        fp8_attention: bool,
     ):
         assert k_scale is None, "DCP not support scaled kvcache now."
         assert attn_metadata.prefill is not None
@@ -1933,7 +1840,6 @@
                 q=q,
                 k=k,
                 v=v,
-                fp8_attention=fp8_attention,
             )

             if output is None:
@@ -1964,7 +1870,6 @@ def _forward_prefill(
         attn_metadata: MLACommonMetadata,
         k_scale: torch.Tensor,
         output: torch.Tensor,
-        fp8_attention: bool = False,
     ) -> None:
         # TODO (zyongye): Prefill function here
         assert attn_metadata.prefill is not None
@@ -1984,7 +1889,6 @@
             k=k,
             v=v,
             return_softmax_lse=has_context,
-            fp8_attention=fp8_attention,
         )

         if has_context:
@@ -1997,12 +1901,11 @@
                     attn_metadata,
                     k_scale=None,
                     dcp_world_size=self.dcp_world_size,
-                    fp8_attention=fp8_attention,
                 )
             )
         else:
             context_output, context_lse = self._compute_prefill_context(
-                q, kv_c_and_k_pe_cache, attn_metadata, k_scale, fp8_attention
+                q, kv_c_and_k_pe_cache, attn_metadata, k_scale
             )

         # unpad if necessary
@@ -2123,7 +2026,6 @@ def forward(
                 attn_metadata,
                 layer._k_scale,
                 output=output[num_decode_tokens:],
-                fp8_attention=fp8_attention,
             )

         if has_decode:
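Taken together, the common.py changes unwind one mechanism: each _fa/_fi/_cudnn/_trtllm_ragged prefill helper loses its fp8_attention flag, together with the per-backend q/k/v casts the flag guarded (the FlashAttention and cuDNN paths only asserted it was False). A standalone sketch of that removed cast pattern (hypothetical helper name; torch.float8_e4m3fn again stands in for current_platform.fp8_dtype()):

import torch

def cast_qkv_for_fp8_attention(q, k, v, fp8_attention: bool):
    # Removed pattern: FlashInfer and TRT-LLM ragged prefill paths cast
    # q/k/v to the platform fp8 dtype when fp8_attention was set.
    if fp8_attention:
        fp8_dtype = torch.float8_e4m3fn  # stand-in for current_platform.fp8_dtype()
        return q.to(fp8_dtype), k.to(fp8_dtype), v.to(fp8_dtype)
    return q, k, v

q = k = v = torch.randn(8, 16, 64, dtype=torch.bfloat16)
q8, _, _ = cast_qkv_for_fp8_attention(q, k, v, fp8_attention=True)
print(q8.dtype)  # torch.float8_e4m3fn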