Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions vllm_kunlun/ops/attention/flashmla.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def flash_mla_sparse_prefill(
sm_scale: float,
q_lod_xpu: torch.Tensor,
q_lod_cpu: torch.Tensor,
kv_lod_xpu: Optional[torch.Tensor] = None,
kv_lod_cpu: Optional[torch.Tensor] = None,
d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Expand All @@ -206,6 +208,8 @@ def flash_mla_sparse_prefill(
Invalid indices should be set to -1 or numbers >= s_kv
- sm_scale: float
- q_lod_xpu: [batch+1], int32, q的每个seq长度的累加信息, 长度为batch_num + 1 (为空则表示q定长).
- kv_lod_xpu: [batch+1], int32, kv的每个seq长度的累加信息(含context). None时fallback到q_lod.
- kv_lod_cpu: [batch+1], int32, kv的每个seq长度的累加信息(含context). None时fallback到q_lod.
- d_v: The dimension of value vectors. Can only be 512

Returns:
Expand All @@ -222,17 +226,23 @@ def flash_mla_sparse_prefill(
max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)

# For multi-turn conversations, kv_lod (total sequence lengths) differs
# from q_lod (new tokens only). Fall back to q_lod if not provided
# (single-turn where qlod == kvlod).
_kvlod_cpu = kv_lod_cpu if kv_lod_cpu is not None else q_lod_cpu
_kvlod_xpu = kv_lod_xpu if kv_lod_xpu is not None else q_lod_xpu

torch.ops._C.sparse_prefill_fwd_opt(
q=q,
kv=kv,
indices=indices,
qlod_cpu=q_lod_cpu,
qlod_xpu=q_lod_xpu,
kvlod_cpu=q_lod_cpu,
kvlod_xpu=q_lod_xpu,
kvlod_cpu=_kvlod_cpu,
kvlod_xpu=_kvlod_xpu,
sm_scale=sm_scale,
d_v=d_v,
is_causal=True, #aiak这个值为true,这是为啥
is_causal=True,
out=out,
max_logits=max_logits,
lse=lse,
Expand Down
6 changes: 5 additions & 1 deletion vllm_kunlun/ops/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ def int8_paged_mqa_logits(
num_blocks, block_size, _, _ = kv_cache_fp8.shape

kv_cache_fp8 = kv_cache_fp8.view(num_blocks, -1)
k_val = kv_cache_fp8[:, :block_size * D].view(torch.int8)
# NOTE: kv_cache_fp8[:, :block_size * D] 是非连续视图(stride[0] = block_size*(D+4),非 block_size*D)。
# I8_paged_mqa_logits kernel 假设 k_val 连续,内部按 physical_block_id * block_size * D
# 计算偏移;若不 contiguous,block_id>0 时会读到错误地址(scale字节混入K数据)。
# Perf: contiguous() 每次 decode step 会复制 ~num_blocks*block_size*D 字节(典型配置约80MB)。
k_val = kv_cache_fp8[:, :block_size * D].contiguous().view(torch.int8)
k_val = k_val.view(-1, block_size, 1, D)

block_indices = block_tables.flatten()
Expand Down
38 changes: 34 additions & 4 deletions vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ class MLASparsePrefillMetadata:
request_ids: torch.Tensor = None
query_start_loc: torch.Tensor = None
query_start_loc_cpu: torch.Tensor = None
# KV LOD: cumulative offsets of total seq_lens (context + query) per
# prefill request. Used by the sparse prefill kernel for correct
# causal masking in multi-turn conversations where kv_len != q_len.
kv_start_loc: torch.Tensor = None
kv_start_loc_cpu: torch.Tensor = None

@dataclass
class FlashMLASparseDecodeAndContextMetadata:
Expand Down Expand Up @@ -464,6 +469,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
(vllm_config.scheduler_config.max_num_batched_tokens, ),
dtype=torch.int32,
device=device)
# Pre-allocated buffer for kv_lod CPU tensor, reused across build calls.
self._kv_lod_cpu_buf = None
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
Expand Down Expand Up @@ -510,9 +517,30 @@ def build(self,
# For mixed batches, it will have -1 for decode and request_id for prefill
prefill_metadata = None
if num_prefills > 0:
# Compute kv_lod from seq_lens for multi-turn correctness.
# kv_lod = cumsum of total sequence lengths (context + query),
# which differs from q_lod (cumsum of query lengths only) when
# there is existing KV cache from prior turns.
prefill_seq_lens_cpu = common_attn_metadata.seq_lens_cpu[
num_decodes:]
if self._kv_lod_cpu_buf is None or self._kv_lod_cpu_buf.shape[
0] != num_prefills + 1:
self._kv_lod_cpu_buf = torch.zeros(num_prefills + 1,
dtype=torch.int32,
device="cpu")
kv_lod_cpu = self._kv_lod_cpu_buf
kv_lod_cpu.zero_()
kv_lod_cpu[1:] = prefill_seq_lens_cpu.to(
torch.int32).cumsum(dim=0)
kv_lod_xpu = kv_lod_cpu.to(self.device)

q_start = common_attn_metadata.query_start_loc[num_decodes]
q_start_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes]
prefill_metadata = MLASparsePrefillMetadata(
query_start_loc = common_attn_metadata.query_start_loc[num_decodes:] - common_attn_metadata.query_start_loc[num_decodes], #因为prefiil、decode请求是分离,所以需要对q进行切分,故需调整该值
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes:] - common_attn_metadata.query_start_loc_cpu[num_decodes],
query_start_loc = common_attn_metadata.query_start_loc[num_decodes:] - q_start, #因为prefiil、decode请求是分离,所以需要对q进行切分,故需调整该值
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes:] - q_start_cpu,
kv_start_loc=kv_lod_xpu,
kv_start_loc_cpu=kv_lod_cpu,
)

decode_metadata = None
Expand Down Expand Up @@ -623,11 +651,13 @@ def _bf16_prefill(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
# NOTE: 只有prefill阶段attn_metadata.query_start_loc是符合klx算子需求的
_attn_out = flash_mla_sparse_prefill(
q=q,
kv=kv_c_and_k_pe_cache,
kv=kv_c_and_k_pe_cache,
indices=topk_indices,
sm_scale=self.softmax_scale,
q_lod_xpu=prefill_metadata.query_start_loc,
q_lod_cpu=prefill_metadata.query_start_loc_cpu
q_lod_cpu=prefill_metadata.query_start_loc_cpu,
kv_lod_xpu=prefill_metadata.kv_start_loc,
kv_lod_cpu=prefill_metadata.kv_start_loc_cpu,
)[0]
return _attn_out

Expand Down
8 changes: 8 additions & 0 deletions vllm_kunlun/vllm_utils_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2365,6 +2365,8 @@ def sparse_prefill_fwd_opt(
d_v: Optional[int] = -1,
is_causal: Optional[bool] = True,
use_xfa_boost: Optional[bool] = False,
side_stream: Optional[int] = -1,
p_out: Optional[torch.Tensor] = None,
) -> None:
kunlun_ops.sparse_prefill_fwd_opt(
q=q,
Expand All @@ -2373,6 +2375,7 @@ def sparse_prefill_fwd_opt(
out=out,
max_logits=max_logits,
lse=lse,
p_out=p_out,
sm_scale=sm_scale,
qlod_cpu=qlod_cpu,
qlod_xpu=qlod_xpu,
Expand Down Expand Up @@ -2401,6 +2404,8 @@ def sparse_prefill_fwd_opt_cuda(
d_v: Optional[int] = -1,
is_causal: Optional[bool] = True,
use_xfa_boost: Optional[bool] = False,
side_stream: Optional[int] = -1,
p_out: Optional[torch.Tensor] = None,
) -> None:
kunlun_ops.sparse_prefill_fwd_opt(
q=q,
Expand All @@ -2409,6 +2414,7 @@ def sparse_prefill_fwd_opt_cuda(
out=out,
max_logits=max_logits,
lse=lse,
p_out=p_out,
sm_scale=sm_scale,
qlod_cpu=qlod_cpu,
qlod_xpu=qlod_xpu,
Expand Down Expand Up @@ -2436,6 +2442,8 @@ def _fake_sparse_prefill_fwd_opt(
d_v: Optional[int] = -1,
is_causal: Optional[bool] = True,
use_xfa_boost: Optional[bool] = False,
side_stream: Optional[int] = -1,
p_out: Optional[torch.Tensor] = None,
) -> None:
return None

Expand Down
Loading