# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom Sparse Attention Indexer layers."""

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerMetadata,
)
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.worker.workspace import current_workspace_manager

if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)


def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor,
) -> torch.Tensor:
    # Careful: attn_metadata is None during a dummy (profiling) run.
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()

    # assert isinstance(attn_metadata, dict)
    if not isinstance(attn_metadata, dict):
        # Reserve workspace for the indexer during the profiling run
        current_workspace_manager().get_simultaneous(
            ((total_seq_lens, head_dim), torch.float8_e4m3fn),
            ((total_seq_lens, 4), torch.uint8),
        )
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

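    # Quantize the new indexer keys to fp8 and scatter them into the indexer
    # KV cache at the positions given by slot_mapping.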
    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )

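    # Mark every slot for this batch as unselected (-1); valid top-k indices
    # are filled in below.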
    topk_indices_buffer[: hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill

        # Get the full shared workspace buffers once (will allocate on first use)
        workspace_manager = current_workspace_manager()
        k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
            ((total_seq_lens, head_dim), fp8_dtype),
            ((total_seq_lens, 4), torch.uint8),
        )
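        # Prefill is processed chunk by chunk so that each chunk's gathered
        # fp8 keys fit in the shared workspace buffers.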
        for chunk in prefill_metadata.chunks:
            k_fp8 = k_fp8_full[: chunk.total_seq_lens]
            k_scale = k_scale_full[: chunk.total_seq_lens]
            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )

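            # MQA logits between this chunk's fp8 queries and the gathered fp8
            # keys; cu_seqlen_ks / cu_seqlen_ke bound, per query row, the range
            # of key positions it may attend to.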
            logits = fp8_mqa_logits(
                q_fp8[chunk.token_start : chunk.token_end],
                (k_fp8, k_scale.view(torch.float32).flatten()),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            num_rows = logits.shape[0]

            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
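            # Select the top-k key indices for each query row, writing directly
            # into the preallocated buffer slice.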
            torch.ops._C.top_k_per_row_prefill(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
                topk_tokens,
            )

    if has_decode:
        decode_metadata = attn_metadata.decode
        # The paged MQA logits kernel expects a kv_cache of shape
        # [num_block, block_size, n_head, head_dim]; we only have
        # [num_block, block_size, head_dim], so add a singleton head dim.
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # Pad in the edge case where a short chunked-prefill length is
            # smaller than decode_threshold, since prefill and decode are split
            # loosely by decode_threshold (currently 1 + number of speculative
            # tokens).
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )
        # TODO: move and optimize the logic below with Triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

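        # MQA logits for the decode queries, computed directly against the
        # paged fp8 indexer cache via the per-request block table.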
        logits = fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )

        num_rows = logits.shape[0]

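        # Top-k selection per decode row; rows follow the padded
        # [batch_size, next_n] layout of the decode queries.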
        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
            topk_tokens,
        )

        if decode_metadata.requires_padding:
            # If padded, unpack the top-k indices to drop the padding tokens.
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
        topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
            topk_indices
        )

    return topk_indices_buffer


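# Fake (meta) implementation used for tracing and the profiling run: it performs
# no work and simply returns the preallocated top-k indices buffer.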
def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    return topk_indices_buffer


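# Register the indexer as a custom op (torch.ops.vllm.sparse_attn_indexer) so
# torch.compile treats it as an opaque call; the fake impl supplies metadata
# during tracing.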
direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    mutates_args=["topk_indices_buffer"],
    fake_impl=sparse_attn_indexer_fake,
    dispatch_key=current_platform.dispatch_key,
)


@CustomOp.register("sparse_attn_indexer")
class SparseAttnIndexer(CustomOp):
    """Sparse Attention Indexer custom op layer.

    This layer is extracted as a separate custom op because it relies on heavy
    custom kernels such as `mqa_logits`, `paged_mqa_logits`, and
    `top_k_per_row`. These kernels may require a specific memory layout or
    implementation on each hardware backend to achieve optimal performance.

    For now, the default native path uses the CUDA backend. Other platforms may
    require adding the custom op name `sparse_attn_indexer` to `custom_ops` in
    `CompilationConfig` to enable the platform-specific path.
    """
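    # Illustrative sketch only (an assumption, not part of this module):
    # enabling the platform-specific path by name, assuming the usual
    # "+<op_name>" syntax accepted by CompilationConfig.custom_ops:
    #   compilation_config = CompilationConfig(custom_ops=["+sparse_attn_indexer"])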

    def __init__(
        self,
        k_cache,
        quant_block_size: int,
        scale_fmt: str,
        topk_tokens: int,
        head_dim: int,
        max_model_len: int,
        max_total_seq_len: int,
        topk_indices_buffer: torch.Tensor,
    ):
        super().__init__()
        self.k_cache = k_cache
        self.quant_block_size = quant_block_size
        self.scale_fmt = scale_fmt
        self.topk_tokens = topk_tokens
        self.head_dim = head_dim
        self.max_model_len = max_model_len
        self.max_total_seq_len = max_total_seq_len
        self.topk_indices_buffer = topk_indices_buffer

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        if current_platform.is_cuda():
            return self.forward_cuda(hidden_states, q_fp8, k, weights)
        elif current_platform.is_rocm():
            return self.forward_hip(hidden_states, q_fp8, k, weights)
        else:
            raise NotImplementedError(
                "SparseAttnIndexer native forward is only implemented for "
                "the CUDA and ROCm platforms."
            )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            q_fp8,
            k,
            weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )

    def forward_hip(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        if rocm_aiter_ops.is_enabled():
            return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q_fp8,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )
        else:
            raise RuntimeError(
                "Sparse attention indexer ROCm custom op requires ROCm "
                "Aiter ops to be enabled."
            )