diff --git a/vllm_kunlun/ops/attention/flashmla.py b/vllm_kunlun/ops/attention/flashmla.py index b643e651..df4b45b7 100644 --- a/vllm_kunlun/ops/attention/flashmla.py +++ b/vllm_kunlun/ops/attention/flashmla.py @@ -194,6 +194,8 @@ def flash_mla_sparse_prefill( sm_scale: float, q_lod_xpu: torch.Tensor, q_lod_cpu: torch.Tensor, + kv_lod_xpu: Optional[torch.Tensor] = None, + kv_lod_cpu: Optional[torch.Tensor] = None, d_v: int = 512, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -206,6 +208,8 @@ def flash_mla_sparse_prefill( Invalid indices should be set to -1 or numbers >= s_kv - sm_scale: float - q_lod_xpu: [batch+1], int32, q的每个seq长度的累加信息, 长度为batch_num + 1 (为空则表示q定长). + - kv_lod_xpu: [batch+1], int32, kv的每个seq长度的累加信息(含context). None时fallback到q_lod. + - kv_lod_cpu: [batch+1], int32, kv的每个seq长度的累加信息(含context). None时fallback到q_lod. - d_v: The dimension of value vectors. Can only be 512 Returns: @@ -222,17 +226,23 @@ def flash_mla_sparse_prefill( max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device) lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device) + # For multi-turn conversations, kv_lod (total sequence lengths) differs + # from q_lod (new tokens only). Fall back to q_lod if not provided + # (single-turn where qlod == kvlod). + _kvlod_cpu = kv_lod_cpu if kv_lod_cpu is not None else q_lod_cpu + _kvlod_xpu = kv_lod_xpu if kv_lod_xpu is not None else q_lod_xpu + torch.ops._C.sparse_prefill_fwd_opt( q=q, kv=kv, indices=indices, qlod_cpu=q_lod_cpu, qlod_xpu=q_lod_xpu, - kvlod_cpu=q_lod_cpu, - kvlod_xpu=q_lod_xpu, + kvlod_cpu=_kvlod_cpu, + kvlod_xpu=_kvlod_xpu, sm_scale=sm_scale, d_v=d_v, - is_causal=True, #aiak这个值为true,这是为啥 + is_causal=True, out=out, max_logits=max_logits, lse=lse, diff --git a/vllm_kunlun/ops/deep_gemm.py b/vllm_kunlun/ops/deep_gemm.py index ffc7c90b..ef04bee3 100644 --- a/vllm_kunlun/ops/deep_gemm.py +++ b/vllm_kunlun/ops/deep_gemm.py @@ -88,7 +88,11 @@ def int8_paged_mqa_logits( num_blocks, block_size, _, _ = kv_cache_fp8.shape kv_cache_fp8 = kv_cache_fp8.view(num_blocks, -1) - k_val = kv_cache_fp8[:, :block_size * D].view(torch.int8) + # NOTE: kv_cache_fp8[:, :block_size * D] 是非连续视图(stride[0] = block_size*(D+4),非 block_size*D)。 + # I8_paged_mqa_logits kernel 假设 k_val 连续,内部按 physical_block_id * block_size * D + # 计算偏移;若不 contiguous,block_id>0 时会读到错误地址(scale字节混入K数据)。 + # Perf: contiguous() 每次 decode step 会复制 ~num_blocks*block_size*D 字节(典型配置约80MB)。 + k_val = kv_cache_fp8[:, :block_size * D].contiguous().view(torch.int8) k_val = k_val.view(-1, block_size, 1, D) block_indices = block_tables.flatten() diff --git a/vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py b/vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py index d5eea238..0a054b3d 100644 --- a/vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm_kunlun/v1/attention/backends/mla/flashmla_sparse.py @@ -116,6 +116,11 @@ class MLASparsePrefillMetadata: request_ids: torch.Tensor = None query_start_loc: torch.Tensor = None query_start_loc_cpu: torch.Tensor = None + # KV LOD: cumulative offsets of total seq_lens (context + query) per + # prefill request. Used by the sparse prefill kernel for correct + # causal masking in multi-turn conversations where kv_len != q_len. + kv_start_loc: torch.Tensor = None + kv_start_loc_cpu: torch.Tensor = None @dataclass class FlashMLASparseDecodeAndContextMetadata: @@ -464,6 +469,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], (vllm_config.scheduler_config.max_num_batched_tokens, ), dtype=torch.int32, device=device) + # Pre-allocated buffer for kv_lod CPU tensor, reused across build calls. + self._kv_lod_cpu_buf = None def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, @@ -510,9 +517,30 @@ def build(self, # For mixed batches, it will have -1 for decode and request_id for prefill prefill_metadata = None if num_prefills > 0: + # Compute kv_lod from seq_lens for multi-turn correctness. + # kv_lod = cumsum of total sequence lengths (context + query), + # which differs from q_lod (cumsum of query lengths only) when + # there is existing KV cache from prior turns. + prefill_seq_lens_cpu = common_attn_metadata.seq_lens_cpu[ + num_decodes:] + if self._kv_lod_cpu_buf is None or self._kv_lod_cpu_buf.shape[ + 0] != num_prefills + 1: + self._kv_lod_cpu_buf = torch.zeros(num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_lod_cpu = self._kv_lod_cpu_buf + kv_lod_cpu.zero_() + kv_lod_cpu[1:] = prefill_seq_lens_cpu.to( + torch.int32).cumsum(dim=0) + kv_lod_xpu = kv_lod_cpu.to(self.device) + + q_start = common_attn_metadata.query_start_loc[num_decodes] + q_start_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes] prefill_metadata = MLASparsePrefillMetadata( - query_start_loc = common_attn_metadata.query_start_loc[num_decodes:] - common_attn_metadata.query_start_loc[num_decodes], #因为prefiil、decode请求是分离,所以需要对q进行切分,故需调整该值 - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes:] - common_attn_metadata.query_start_loc_cpu[num_decodes], + query_start_loc = common_attn_metadata.query_start_loc[num_decodes:] - q_start, #因为prefiil、decode请求是分离,所以需要对q进行切分,故需调整该值 + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[num_decodes:] - q_start_cpu, + kv_start_loc=kv_lod_xpu, + kv_start_loc_cpu=kv_lod_cpu, ) decode_metadata = None @@ -623,11 +651,13 @@ def _bf16_prefill(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: # NOTE: 只有prefill阶段attn_metadata.query_start_loc是符合klx算子需求的 _attn_out = flash_mla_sparse_prefill( q=q, - kv=kv_c_and_k_pe_cache, + kv=kv_c_and_k_pe_cache, indices=topk_indices, sm_scale=self.softmax_scale, q_lod_xpu=prefill_metadata.query_start_loc, - q_lod_cpu=prefill_metadata.query_start_loc_cpu + q_lod_cpu=prefill_metadata.query_start_loc_cpu, + kv_lod_xpu=prefill_metadata.kv_start_loc, + kv_lod_cpu=prefill_metadata.kv_start_loc_cpu, )[0] return _attn_out diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py index b212d2e3..e96b39a0 100644 --- a/vllm_kunlun/vllm_utils_wrapper.py +++ b/vllm_kunlun/vllm_utils_wrapper.py @@ -2365,6 +2365,8 @@ def sparse_prefill_fwd_opt( d_v: Optional[int] = -1, is_causal: Optional[bool] = True, use_xfa_boost: Optional[bool] = False, + side_stream: Optional[int] = -1, + p_out: Optional[torch.Tensor] = None, ) -> None: kunlun_ops.sparse_prefill_fwd_opt( q=q, @@ -2373,6 +2375,7 @@ def sparse_prefill_fwd_opt( out=out, max_logits=max_logits, lse=lse, + p_out=p_out, sm_scale=sm_scale, qlod_cpu=qlod_cpu, qlod_xpu=qlod_xpu, @@ -2401,6 +2404,8 @@ def sparse_prefill_fwd_opt_cuda( d_v: Optional[int] = -1, is_causal: Optional[bool] = True, use_xfa_boost: Optional[bool] = False, + side_stream: Optional[int] = -1, + p_out: Optional[torch.Tensor] = None, ) -> None: kunlun_ops.sparse_prefill_fwd_opt( q=q, @@ -2409,6 +2414,7 @@ def sparse_prefill_fwd_opt_cuda( out=out, max_logits=max_logits, lse=lse, + p_out=p_out, sm_scale=sm_scale, qlod_cpu=qlod_cpu, qlod_xpu=qlod_xpu, @@ -2436,6 +2442,8 @@ def _fake_sparse_prefill_fwd_opt( d_v: Optional[int] = -1, is_causal: Optional[bool] = True, use_xfa_boost: Optional[bool] = False, + side_stream: Optional[int] = -1, + p_out: Optional[torch.Tensor] = None, ) -> None: return None