diff --git a/atom/model_engine/block_manager.py b/atom/model_engine/block_manager.py index df6ebd2c..da97c10d 100644 --- a/atom/model_engine/block_manager.py +++ b/atom/model_engine/block_manager.py @@ -37,6 +37,7 @@ def __init__(self, config: Config): self.blocks: list[Block] = [Block(i) for i in range(num_blocks)] self.hash_to_block_id: dict[int, int] = dict() self.free_block_ids: deque[int] = deque(range(num_blocks)) + self.free_block_ids_set: set[int] = set(range(num_blocks)) self.used_block_ids: set[int] = set() self.enable_prefix_caching = config.enable_prefix_caching @@ -48,11 +49,23 @@ def compute_hash(cls, token_ids: list[int], prefix: int = -1): h.update(np.array(token_ids).tobytes()) return h.intdigest() + def _pop_free_block(self) -> int: + """Pop the next available free block id from the FIFO queue (lazy cleanup).""" + while self.free_block_ids: + block_id = self.free_block_ids.popleft() + if block_id in self.free_block_ids_set: + self.free_block_ids_set.discard(block_id) + return block_id + raise AssertionError("No free blocks available") + def _allocate_block(self, block_id: int) -> Block: block = self.blocks[block_id] assert block.ref_count == 0 + # Evict stale hash entry before resetting + if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id: + del self.hash_to_block_id[block.hash] block.reset() - self.free_block_ids.remove(block_id) + self.free_block_ids_set.discard(block_id) self.used_block_ids.add(block_id) return self.blocks[block_id] @@ -60,10 +73,28 @@ def _deallocate_block(self, block_id: int): assert self.blocks[block_id].ref_count == 0 self.used_block_ids.remove(block_id) self.free_block_ids.append(block_id) - # self.free_block_ids.appendleft(block_id) + self.free_block_ids_set.add(block_id) def can_allocate(self, seq: Sequence) -> bool: - return len(self.free_block_ids) >= seq.num_blocks + seq.num_mamba_blocks + if not self.enable_prefix_caching: + return len(self.free_block_ids_set) >= seq.num_blocks + 
seq.num_mamba_blocks + # Dry-run: count how many blocks would be cache hits + h = -1 + cache_miss = False + needed_free = 0 + for i in range(seq.num_blocks): + token_ids = seq.block(i) + h = ( + self.compute_hash(token_ids, h) + if len(token_ids) == self.block_size + else -1 + ) + block_id = self.hash_to_block_id.get(h, -1) + if block_id == -1 or self.blocks[block_id].token_ids != token_ids: + cache_miss = True + if cache_miss: + needed_free += 1 + return len(self.free_block_ids_set) >= needed_free def allocate(self, seq: Sequence): assert not seq.block_table @@ -82,7 +113,7 @@ def allocate(self, seq: Sequence): if block_id == -1 or self.blocks[block_id].token_ids != token_ids: cache_miss = True if cache_miss: - block_id = self.free_block_ids[0] + block_id = self._pop_free_block() block = self._allocate_block(block_id) else: seq.num_cached_tokens += self.block_size @@ -122,12 +153,17 @@ def deallocate(self, seq: Sequence): self._deallocate_block(block_id) seq.mamba_block_table.clear() - def can_append(self, seq: Sequence) -> bool: - return len(self.free_block_ids) >= (len(seq) % self.block_size == 1) + def can_append(self, seq: Sequence, num_new_tokens: int = 1) -> bool: + seq_len = len(seq) + current_blocks = len(seq.block_table) + needed_blocks = ( + seq_len + num_new_tokens + self.block_size - 1 + ) // self.block_size + new_blocks_needed = max(0, needed_blocks - current_blocks) + return len(self.free_block_ids_set) >= new_blocks_needed def may_append(self, seq: Sequence, num_new_tokens: int = 1): block_table = seq.block_table - last_block = self.blocks[block_table[-1]] seq_len = len(seq) # Check if we need to allocate a new block # When len(seq) % block_size == 1, we need a new block for the next token @@ -135,42 +171,9 @@ def may_append(self, seq: Sequence, num_new_tokens: int = 1): if 0 < seq_len % self.block_size <= num_new_tokens or self.block_size == 1: needed_blocks = (seq_len + self.block_size - 1) // self.block_size while len(block_table) < 
needed_blocks: - # For block_size == 1, we need to update hash for each new block - # For block_size > 1, the previous block should have hash != -1 (unless it's the first block) - if self.block_size == 1: - # Allocate new block and update hash immediately (like allocate does for full blocks) - block_id = self.free_block_ids[0] - block = self._allocate_block(block_id) - block_table.append(block_id) - token_ids = [seq[-1]] - prefix = ( - self.blocks[block_table[-2]].hash - if len(block_table) > 1 - else -1 - ) - h = self.compute_hash(token_ids, prefix) - block.update(h, token_ids) - self.hash_to_block_id[h] = block_id - else: - # For block_size > 1, we only allocate new block when needed - # The hash will be updated when the block becomes full - block_id = self.free_block_ids[0] - block = self._allocate_block(block_id) - block_table.append(block_id) - last_block = block - elif seq_len % self.block_size == 0: - # Last block is now full, update its hash (similar to allocate) - # TODO: fix hash - token_ids = seq.block(seq.num_blocks - 1) - if len(token_ids) == self.block_size: - prefix = ( - self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1 - ) - h = self.compute_hash(token_ids, prefix) - last_block.update(h, token_ids) - self.hash_to_block_id[h] = last_block.block_id - else: - pass - # Last block is not full and not at the boundary - # Hash remains -1 until block is full (consistent with allocate logic) - # assert last_block.hash == -1, last_block.block_id + # Decode-generated blocks: token not finalized yet (depends on + # sampling / speculative verification), so we cannot compute a + # correct hash here. Just allocate the block without hashing. 
+ block_id = self._pop_free_block() + self._allocate_block(block_id) + block_table.append(block_id) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 4c3f5a12..6f30de71 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -107,6 +107,70 @@ def _log(self) -> None: ) +class CacheStats: + """Tracks prefix caching hit statistics.""" + + __slots__ = ( + "_log_interval", + "total_requests", + "total_cached_tokens", + "total_full_tokens", + "_interval_requests", + "_interval_cached_tokens", + "_interval_full_tokens", + ) + + def __init__(self, log_interval: int = 100): + self._log_interval = log_interval + self.total_requests: int = 0 + self.total_cached_tokens: int = 0 + self.total_full_tokens: int = 0 + self._interval_requests: int = 0 + self._interval_cached_tokens: int = 0 + self._interval_full_tokens: int = 0 + + def update(self, num_cached_tokens: int, num_full_tokens: int) -> None: + """Record cache stats for one prefill sequence.""" + self.total_requests += 1 + self.total_cached_tokens += num_cached_tokens + self.total_full_tokens += num_full_tokens + self._interval_requests += 1 + self._interval_cached_tokens += num_cached_tokens + self._interval_full_tokens += num_full_tokens + + if self.total_requests % self._log_interval == 0: + self._log() + self._reset_interval() + + @property + def hit_rate(self) -> float: + if self.total_full_tokens == 0: + return 0.0 + return self.total_cached_tokens / self.total_full_tokens + + def _reset_interval(self) -> None: + self._interval_requests = 0 + self._interval_cached_tokens = 0 + self._interval_full_tokens = 0 + + def _log(self) -> None: + iv_rate = ( + self._interval_cached_tokens / self._interval_full_tokens + if self._interval_full_tokens > 0 + else 0.0 + ) + logger.info( + f"[Cache Stats Interval] Reqs: {self._interval_requests}, " + f"Cached/Total tokens: {self._interval_cached_tokens}/{self._interval_full_tokens}, " + f"Hit rate: {iv_rate:.2%}" + ) + 
logger.info( + f"[Cache Stats ] Reqs: {self.total_requests}, " + f"Cached/Total tokens: {self.total_cached_tokens}/{self.total_full_tokens}, " + f"Hit rate: {self.hit_rate:.2%}" + ) + + class ScheduledBatch: def __init__( self, @@ -246,6 +310,9 @@ def __init__(self, config: Config): self.spec_stats: Optional[SpecStats] = ( SpecStats(mtp_k=self.mtp_k) if self.use_spec else None ) + self.cache_stats: Optional[CacheStats] = ( + CacheStats() if config.enable_prefix_caching else None + ) def is_finished(self): return not self.waiting and not self.running @@ -286,6 +353,8 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: # Recalculate after allocation: prefix caching may have updated # seq.num_cached_tokens, reducing the actual number of new tokens. num_new_tokens = seq.num_tokens - seq.num_cached_tokens + if self.cache_stats: + self.cache_stats.update(seq.num_cached_tokens, seq.num_tokens) num_batched_tokens += num_new_tokens seq.status = SequenceStatus.RUNNING seq.type = SequenceType.PREFILL @@ -319,7 +388,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: num_seqs_decode = 0 while self.running and num_seqs_decode < self.max_num_seqs: seq = self.running.popleft() - while not self.block_manager.can_append(seq): + while not self.block_manager.can_append(seq, self.mtp_k + 1): if self.running: self.preempt(self.running.pop()) else: diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index b38ca0ee..c42a8c23 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -15,8 +15,13 @@ from .attention_mla import MLAModules +import logging + from atom.plugin.prepare import is_plugin_mode, is_vllm from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode +from atom.model_ops.base_attention import cp_mha_gather_cache + +logger = logging.getLogger("atom") @PagedAttentionImplDecoratorForPluginMode @@ -127,6 +132,14 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: 
ForwardContext): and self.q_norm is not None and self.k_norm is not None ): + # fused_qk_norm_rope_cache_quant_shuffle expects V cache layout + # [num_blocks, num_kv_heads, block_size//x, head_size, x], not [n, nh, hd, bs] + x = 16 // k_cache.element_size() + if k_cache.dim() == 5 and v_cache.dim() == 4: + n, nh, hd, bs = v_cache.shape + v_cache_shuffle = v_cache.view(n, nh, bs // x, hd, x) + else: + v_cache_shuffle = v_cache fused_qk_norm_rope_cache_quant_shuffle( qkv, num_heads_q=self.num_heads, @@ -140,7 +153,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): is_neox_style=self.rotary_emb.is_neox_style, pos_ids=position, k_cache=k_cache, - v_cache=v_cache, + v_cache=v_cache_shuffle, slot_mapping=attn_metadata.slot_mapping, kv_cache_dtype=( "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype @@ -212,8 +225,95 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): asm_layout=asm_layout, ) + # Prefix cache hit: gather cached KV from paged cache and concat with new tokens + if attn_metadata.has_cached: + q, k, v, k_cache, v_cache, k_scale, v_scale = ( + self._gather_prefix_and_concat_kv( + q, k, v, k_cache, v_cache, k_scale, v_scale, attn_metadata + ) + ) + return q, k, v, k_cache, v_cache, k_scale, v_scale + def _gather_prefix_and_concat_kv( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + attn_metadata, + ): + """ + When prefix cache hits, gather full KV (cached + new) from paged cache in + one pass. New tokens are already written by fused_qk_rope_reshape_and_cache. + Same flow as gather_kv_b_proj: write new first, then read cached+new together. + token_to_batch, seq_starts are built in prepare_prefill. 
+ """ + cu_seqlens_k = attn_metadata.cu_seqlens_k + total_tokens = cu_seqlens_k[-1].item() + token_to_batch = attn_metadata.token_to_batch + seq_starts = attn_metadata.seq_starts + + num_kv_heads = k.shape[1] + head_dim = k.shape[2] + device = k.device + dtype = k.dtype + + k_full = torch.empty( + (total_tokens, num_kv_heads, head_dim), dtype=dtype, device=device + ) + v_full = torch.empty( + (total_tokens, num_kv_heads, head_dim), dtype=dtype, device=device + ) + + # Convert cache for cp_mha_gather_cache + # fused_qk_norm_rope_cache_quant_shuffle: K [n, nh, hd//x, bs, x], V [n, nh, bs//x, hd, x] (SHUFFLE) + # fused_qk_rope_reshape_and_cache: K [n, nh, hd//x, bs, x], V [n, nh, hd, bs] -> NHD + if k_cache.dim() == 5: + x = 16 // k_cache.element_size() + n, nh, _, block_size, _ = k_cache.shape + if v_cache.dim() == 4: + # fused_qk_norm_rope_cache_quant_shuffle: V data in [n, nh, bs//x, hd, x] layout + use_shuffle = True + k_cache_gather = k_cache + v_cache_gather = v_cache.view(n, nh, block_size // x, head_dim, x) + else: + # fused_qk_rope_reshape_and_cache: V [n, nh, hd, bs] -> NHD + use_shuffle = False + k_cache_gather = ( + k_cache.permute(0, 3, 1, 2, 4) + .contiguous() + .view(n, block_size, nh, head_dim) + ) + v_cache_gather = v_cache.permute(0, 3, 1, 2).contiguous() + else: + use_shuffle = False + k_cache_gather = k_cache + v_cache_gather = v_cache + block_size = k_cache.shape[1] + + block_tables = attn_metadata.block_tables + cp_mha_gather_cache( + key_cache=k_cache_gather, + value_cache=v_cache_gather, + key=k_full, + value=v_full, + block_tables=block_tables, + k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=cu_seqlens_k, + token_to_batch=token_to_batch, + seq_starts=seq_starts, + dequant=self.kv_cache_dtype.startswith("fp8"), + kv_cache_layout="SHUFFLE" if use_shuffle else "NHD", + total_tokens=total_tokens, + ) + + return q, k_full, v_full, k_cache, v_cache, k_scale, v_scale + def paged_attention_triton( self, q, k, v, k_cache, v_cache, k_scale, 
v_scale, fwd_ctx: ForwardContext ): diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 9d6fe94d..5d62b158 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -33,7 +33,7 @@ from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as _aiter_triton_fp8_bmm, ) - +from aiter.ops.triton.gather_kv_b_proj import gather_kv_b_proj from atom.plugin import is_plugin_mode @@ -466,7 +466,7 @@ def _forward_prefill_mla( attn_metadata.token_to_seq_idxs, self.topk_indices_buffer[:B], attn_metadata.block_tables, - attn_metadata.cu_seqlens_q, + attn_metadata.cu_seqlens_k, NUM_TOPK_TOKENS=self.topk_indices_buffer.shape[1], ) paged_cu_seqlens_q = attn_metadata.sparse_cu_seqlens_q @@ -623,6 +623,12 @@ def forward_impl_server_mode( kv_cache = kv_cache_data[f"layer_{self.layer_num}"].k_cache if context.is_prefill and not use_prefill_mla: + use_prefix_cache = ( + attn_metadata.has_cached + and not is_rocm_aiter_fp4bmm_enabled() + and self.qk_nope_head_dim == self.v_head_dim + ) + prefill_q = self.q_proj(q, x_scale=q_scale).view( -1, self.num_heads, self.qk_head_dim ) @@ -639,9 +645,60 @@ def forward_impl_server_mode( scale=self._k_scale, ) - output = self._forward_prefill_mha( - prefill_q, k_nope, k_rope, kv_cache, attn_metadata - ) + if use_prefix_cache: + total_tokens = attn_metadata.cu_seqlens_k[-1].item() + output_dtype = ( + dtypes.fp8 if self.kv_cache_dtype.startswith("fp8") else self.dtype + ) + k_full = torch.empty( + ( + total_tokens, + self.num_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + ), + device=q.device, + dtype=output_dtype, + ) + v_full = torch.empty( + ( + total_tokens, + self.num_heads, + self.qk_nope_head_dim, + ), + device=q.device, + dtype=output_dtype, + ) + + gather_kv_b_proj( + kv_cache, + self._k_scale, + attn_metadata.kv_indptr, + 
attn_metadata.kv_indices, + attn_metadata.cu_seqlens_k, + self.kv_b_proj.weight, + self.kv_b_proj.weight_scale, + k_full, + v_full, + weight_preshuffle=True, + ) + output = flash_attn_varlen_func( + q=prefill_q, + k=k_full, + v=v_full, + cu_seqlens_q=attn_metadata.cu_seqlens_q, + cu_seqlens_k=attn_metadata.cu_seqlens_k, + max_seqlen_q=attn_metadata.max_seqlen_q, + max_seqlen_k=attn_metadata.max_seqlen_k, + min_seqlen_q=attn_metadata.min_seqlen_q, + dropout_p=attn_metadata.dropout_p, + softmax_scale=self.scale, + causal=True, + ) + output = self.o_proj(output.flatten(start_dim=-2)) + else: + output = self._forward_prefill_mha( + prefill_q, k_nope, k_rope, kv_cache, attn_metadata + ) else: q_nope, q_rope = self._q_proj_and_k_up_proj(q, x_scale=q_scale) diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index b51239fc..ab4eb3ba 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -264,13 +264,23 @@ def prepare_prefill(self, batch: ScheduledBatch): self.block_ratio, ) attn_metadata.block_tables = var["block_tables_converted"].gpu[:bs] - var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = ( - np.arange(sum_scheduled_tokens, dtype=np.int32) + 1 - ) counts = var["cu_seqlens_q"].np[1 : bs + 1] - var["cu_seqlens_q"].np[:bs] - var["cu_seqlen_ks"].np[:sum_scheduled_tokens] = np.repeat( - var["cu_seqlens_q"].np[:bs], counts - ) + if attn_metadata.has_cached: + # Full context (cached + new): use cu_seqlens_k for indexer + cu_seqlens_k_np = attn_metadata.cu_seqlens_k.cpu().numpy() + var["cu_seqlen_ks"].np[:sum_scheduled_tokens] = np.repeat( + cu_seqlens_k_np[:-1], counts + ) + var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = np.repeat( + cu_seqlens_k_np[1:], counts + ) + else: + var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = ( + np.arange(sum_scheduled_tokens, dtype=np.int32) + 1 + ) + var["cu_seqlen_ks"].np[:sum_scheduled_tokens] = np.repeat( + var["cu_seqlens_q"].np[:bs], counts + ) 
attn_metadata.cu_seqlen_ks = var["cu_seqlen_ks"].copy_to_gpu( sum_scheduled_tokens ) @@ -284,9 +294,11 @@ def prepare_prefill(self, batch: ScheduledBatch): :sum_scheduled_tokens ] + # Per-query req_id: token_id 0..sum_scheduled_tokens-1 maps to batch id. + # Use counts (new tokens per batch), not context_lens (full seq len). attn_metadata.token_to_seq_idxs = torch.repeat_interleave( torch.arange(bs, dtype=torch.int32, device=self.device), - attn_metadata.context_lens, + torch.tensor(counts, dtype=torch.int64, device=self.device), ) var["sparse_kv_indptr"].np[0] = 0 var["sparse_kv_indptr"].np[1 : sum_scheduled_tokens + 1] = np.cumsum( @@ -300,17 +312,33 @@ def prepare_prefill(self, batch: ScheduledBatch): sum_scheduled_tokens + 1 ) - if hasattr(self.model_runner, "drafter"): + if hasattr(self.model_runner, "drafter") or attn_metadata.has_cached: + # Populate kv_last_page_lens for full sequence (needed for MLA prefill with + # prefix cache; decode does the same) + if self.model_runner.block_size != 1: + var["kv_last_page_lens"].np[:bs] = np.asarray( + batch.last_block_num_tokens[:bs], dtype=np.int32 + ) + else: + var["kv_last_page_lens"].np[:bs] = 1 + var["kv_last_page_lens"].copy_to_gpu() + attn_metadata.kv_indices = var["kv_indices"].gpu attn_metadata.kv_indptr = var["kv_indptr"].gpu[: bs + 1] + attn_metadata.kv_indptr[0] = 0 attn_metadata.kv_indptr[1 : bs + 1] = torch.cumsum( attn_metadata.context_lens, 0 ) - if attn_metadata.block_tables is None: - self.prepare_block_tables(batch) - attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs) + attn_metadata.kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs] + + # kv_indices_generate_triton expects RAW block_tables (physical block ids, + # one per block_ratio tokens). When is_sparse, attn_metadata.block_tables + # may have been overwritten with block_tables_converted (slot per token). + # Always use raw block_tables for kv_indices. 
+ self.prepare_block_tables(batch) + block_tables_for_kv = var["block_tables"].copy_to_gpu(bs) kv_indices_generate_triton( - attn_metadata.block_tables, + block_tables_for_kv, attn_metadata.kv_indices, attn_metadata.kv_indptr, self.block_ratio, diff --git a/atom/model_ops/attentions/backends.py b/atom/model_ops/attentions/backends.py index a2226f61..247a512f 100644 --- a/atom/model_ops/attentions/backends.py +++ b/atom/model_ops/attentions/backends.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import logging from abc import ABC, abstractmethod from typing import Any, Dict, Generic, Optional, Type, TypeVar @@ -12,6 +13,7 @@ from atom.utils.forward_context import AttentionMetaData from torch import nn +logger = logging.getLogger("atom") T = TypeVar("T", bound="BroadcastableModelInput") @@ -114,6 +116,8 @@ def __init__(self, model_runner): ), "cu_seqlens_q": CpuGpuBuffer(self.max_bs + 1, **i32_kwargs), "cu_seqlens_k": CpuGpuBuffer(self.max_bs + 1, **i32_kwargs), + # seq_starts for cp_mha_gather_cache: always zeros (prefix at position 0) + "seq_starts": CpuGpuBuffer(self.max_bs, **i32_kwargs), } if self.block_ratio > 1: attn_metadata["block_tables_converted"] = CpuGpuBuffer( @@ -126,6 +130,8 @@ def __init__(self, model_runner): torch.arange(0, self.max_bs + 1, step=1, dtype=torch.int32) ) attn_metadata["cu_seqlens_q"].copy_to_gpu() + attn_metadata["seq_starts"].cpu.zero_() + attn_metadata["seq_starts"].copy_to_gpu() self.model_runner.forward_vars.update(attn_metadata) self.has_sliding_window = hasattr(hf_config, "sliding_window") @@ -141,18 +147,23 @@ def prepare_prefill(self, batch: ScheduledBatch): sum_scheduled_tokens = batch.total_tokens_num_prefill var = self.model_runner.forward_vars positions = [] + cu_seqlens_q = [0] cu_seqlens_k = [0] max_seqlen_q = 0 max_seqlen_k = 0 slot_mapping = [] + has_cached = False # seqs = list(batch.seqs.values()) # seqs = seqs[:bs] for i in range(bs): 
seqlen = batch.context_lens[i] cached_seqlen = batch.num_cached_tokens[i] + if cached_seqlen > 0: + has_cached = True positions.extend(list(range(cached_seqlen, seqlen))) seqlen_q = seqlen - cached_seqlen seqlen_k = seqlen + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) max_seqlen_q = max(seqlen_q, max_seqlen_q) max_seqlen_k = max(seqlen_k, max_seqlen_k) @@ -166,18 +177,30 @@ def prepare_prefill(self, batch: ScheduledBatch): ) // self.model_runner.block_size last_block_tokens = batch.last_block_num_tokens[i] block_table = batch.block_tables[i] - for i in range(num_cached_blocks, num_blocks): - start = block_table[i] * self.model_runner.block_size - if i != num_blocks - 1: + for blk_idx in range(num_cached_blocks, num_blocks): + start = block_table[blk_idx] * self.model_runner.block_size + if blk_idx != num_blocks - 1: end = start + self.model_runner.block_size else: end = start + last_block_tokens slot_mapping.extend(list(range(start, end))) - if cu_seqlens_k[-1] > batch.total_tokens_num: # prefix cache + if has_cached: self.prepare_block_tables(batch) + # Validate metadata consistency + assert ( + len(positions) == sum_scheduled_tokens + ), f"positions length {len(positions)} != sum_scheduled_tokens {sum_scheduled_tokens}" + if batch.block_tables: + assert ( + len(slot_mapping) == sum_scheduled_tokens + ), f"slot_mapping length {len(slot_mapping)} != sum_scheduled_tokens {sum_scheduled_tokens}" + assert ( + cu_seqlens_q[-1] == sum_scheduled_tokens + ), f"cu_seqlens_q[-1]={cu_seqlens_q[-1]} != sum_scheduled_tokens={sum_scheduled_tokens}" var["positions"].np[:sum_scheduled_tokens] = positions var["slot_mapping"].np[:sum_scheduled_tokens] = -1 var["slot_mapping"].np[: len(slot_mapping)] = slot_mapping + var["cu_seqlens_q"].np[: bs + 1] = cu_seqlens_q cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True) var["context_lens"].np[:bs] = batch.context_lens[:bs] min_seqlen_q = 0 @@ -187,6 +210,9 @@ 
def prepare_prefill(self, batch: ScheduledBatch): ("slot_mapping", sum_scheduled_tokens), ("context_lens", bs), ] + if has_cached: + vars_used.append(("block_tables", bs)) + vars_used.append(("seq_starts", bs)) ctx = {el: var[el].copy_to_gpu(num) for el, num in vars_used} if self.block_ratio > 1 and "block_tables" in ctx: @@ -197,12 +223,36 @@ def prepare_prefill(self, batch: ScheduledBatch): self.block_ratio, ) ctx["block_tables_converted"] = var["block_tables_converted"].gpu[:bs] + num_cached_tokens = None + token_to_batch = None + if has_cached: + num_cached_tokens = torch.tensor( + batch.num_cached_tokens[:bs], dtype=torch.int32, pin_memory=True + ).cuda(non_blocking=True) + if self.model_runner.rank == 0: + logger.info(f"{has_cached=}") + logger.info( + f"Prefill batch has {num_cached_tokens.sum().item()} cached tokens and of {sum_scheduled_tokens} total tokens" + ) + # Build metadata for cp_mha_gather_cache (full sequence: cached + new) + # token_to_batch: [0]*len0 + [1]*len1 + ... 
+ total_tokens = cu_seqlens_k[-1] + token_to_batch = torch.zeros( + total_tokens, dtype=torch.int32, device=self.device + ) + for i in range(bs): + start = cu_seqlens_k[i] + end = cu_seqlens_k[i + 1] + token_to_batch[start:end] = i attn_metadata = AttentionMetaData( cu_seqlens_k=cu_seqlens_k.cuda(non_blocking=True), max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, min_seqlen_q=min_seqlen_q, dropout_p=dropout_p, + has_cached=has_cached, + num_cached_tokens=num_cached_tokens, + token_to_batch=token_to_batch, **ctx, ) positions = var["positions"].copy_to_gpu(sum_scheduled_tokens) diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index 47caa6a3..f3a41b3c 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -152,13 +152,18 @@ def cp_mha_gather_cache( assert k_scales is not None and v_scales is not None head_dim = key.shape[2] x = 16 // key_cache.element_size() - # For k cache layout: [num_blocks, num_heads, page_size, head_dim] - assert head_dim == key_cache.shape[3], ( - "We assume your kv cache layout is [num_blocks, " - "page_size, num_heads, head_dim], but got otherwise" - ) - page_size = key_cache.shape[1] - num_heads = key_cache.shape[2] + if kv_cache_layout == "NHD": + # K: [num_blocks, page_size, num_heads, head_dim] + assert head_dim == key_cache.shape[3] + page_size = key_cache.shape[1] + num_heads = key_cache.shape[2] + else: + # SHUFFLE: K [num_blocks, num_heads, head_dim//x, page_size, x] + assert ( + key_cache.dim() == 5 and head_dim == key_cache.shape[2] * key_cache.shape[4] + ) + page_size = key_cache.shape[3] + num_heads = key_cache.shape[1] grid = lambda meta: (total_tokens, num_heads) # noqa: E731 cp_mha_gather_cache_kernel[grid]( diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index d1da9f05..f98480be 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -967,11 +967,15 @@ def sparse_attn_indexer( return weights prefill_metadata = 
attn_metadata num_prefills = context.batch_size - total_seq_lens = hidden_states.shape[0] - k_fp8 = torch.empty( - [total_seq_lens, head_dim], device=k.device, dtype=dtypes.fp8 + num_tokens = hidden_states.shape[0] + # When has_cached, gather full KV (cached + new) for indexer top-k + total_kv = ( + prefill_metadata.cu_seqlens_k[-1].item() + if prefill_metadata.has_cached + else num_tokens ) - k_scale = torch.empty([total_seq_lens, 1], device=k.device, dtype=torch.float32) + k_fp8 = torch.empty([total_kv, head_dim], device=k.device, dtype=dtypes.fp8) + k_scale = torch.empty([total_kv, 1], device=k.device, dtype=torch.float32) if prefill_metadata.block_tables.shape[0] < num_prefills: new_shape = (num_prefills, prefill_metadata.block_tables.shape[1]) prefill_metadata.block_tables = torch.full( @@ -985,12 +989,14 @@ def sparse_attn_indexer( k_fp8, k_scale.view(dtypes.fp8), prefill_metadata.block_tables, - prefill_metadata.cu_seqlens_q, - # num_prefills, + ( + prefill_metadata.cu_seqlens_k + if prefill_metadata.has_cached + else prefill_metadata.cu_seqlens_q + ), ) cu_seqlen_ks = prefill_metadata.cu_seqlen_ks cu_seqlen_ke = prefill_metadata.cu_seqlen_ke - num_tokens = hidden_states.shape[0] logits = fp8_mqa_logits( Q=q_fp8[num_decode_tokens:num_tokens], KV=k_fp8, diff --git a/atom/utils/forward_context.py b/atom/utils/forward_context.py index 9b374620..2f1a637b 100644 --- a/atom/utils/forward_context.py +++ b/atom/utils/forward_context.py @@ -189,6 +189,12 @@ class AttentionMetaData: block_tables_converted: Optional[torch.Tensor] = None + # for prefix cache + has_cached: bool = False + num_cached_tokens: Optional[torch.Tensor] = None + token_to_batch: Optional[torch.Tensor] = None + seq_starts: Optional[torch.Tensor] = None + # only used for plugin mode to store the metadata for attn plugin_metadata: Optional["MetadataForPluginMode"] = None @@ -219,7 +225,15 @@ def __init__( sparse_cu_seqlens_q: Optional[torch.Tensor] = None, token_to_seq_idxs: Optional[torch.Tensor] = 
None, plugin_metadata: Optional["MetadataForPluginMode"] = None, + has_cached: bool = False, + num_cached_tokens: Optional[torch.Tensor] = None, + token_to_batch: Optional[torch.Tensor] = None, + seq_starts: Optional[torch.Tensor] = None, ): + self.has_cached = has_cached + self.num_cached_tokens = num_cached_tokens + self.token_to_batch = token_to_batch + self.seq_starts = seq_starts self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k self.max_seqlen_q = max_seqlen_q diff --git a/tests/test_block_manager.py b/tests/test_block_manager.py index 9dfb5d48..60a6c106 100644 --- a/tests/test_block_manager.py +++ b/tests/test_block_manager.py @@ -173,3 +173,175 @@ def test_block_size_1(self, seq_factory): seq.append_token(3) bm.may_append(seq) assert len(seq.block_table) == 3 + + +# ── Prefix caching: can_allocate with cache hits ───────────────────────── + + +class TestCanAllocateWithPrefixCaching: + def test_can_allocate_accounts_for_cache_hits(self, seq_factory): + """With 3 blocks total, allocate 2-block seq, deallocate, then a new + 2-block seq sharing block 1 should need only 1 free block.""" + cfg = MockConfig( + num_kvcache_blocks=3, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + bm.deallocate(s1) # blocks freed, hashes retained + + # Use up 2 of the 3 free blocks + filler = seq_factory([50, 51, 52, 53, 60, 61, 62, 63]) + bm.allocate(filler) + # Only 1 free block left; s2 needs 2 blocks but first is cached + s2 = seq_factory([1, 2, 3, 4, 9, 10, 11, 12]) + assert bm.can_allocate(s2) + + def test_can_allocate_no_false_positive(self, seq_factory): + """can_allocate should return False when even with cache hits + there aren't enough free blocks.""" + cfg = MockConfig( + num_kvcache_blocks=2, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + # 0 free blocks; new seq 
shares prefix but needs 1 new block + s2 = seq_factory([1, 2, 3, 4, 9, 10, 11, 12]) + assert not bm.can_allocate(s2) + + +# ── Hash table cleanup ─────────────────────────────────────────────────── + + +class TestHashTableCleanup: + def test_stale_hash_entries_evicted_on_reuse(self, seq_factory): + """When a cached block is reused for a different hash, the old + hash_to_block_id entry should be cleaned up.""" + cfg = MockConfig( + num_kvcache_blocks=2, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + h1 = bm.blocks[s1.block_table[0]].hash + bm.deallocate(s1) + + # Allocate with completely different tokens — should overwrite blocks + s2 = seq_factory([90, 91, 92, 93, 94, 95, 96, 97]) + bm.allocate(s2) + # Old hash should no longer point to a valid block + assert bm.hash_to_block_id.get(h1) != s2.block_table[0] + + def test_hash_table_bounded_growth(self, seq_factory): + """hash_to_block_id should not grow beyond num_kvcache_blocks.""" + cfg = MockConfig( + num_kvcache_blocks=4, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + for i in range(20): + tokens = list(range(i * 4, i * 4 + 4)) + seq = seq_factory(tokens) + if bm.can_allocate(seq): + bm.allocate(seq) + bm.deallocate(seq) + assert len(bm.hash_to_block_id) <= cfg.num_kvcache_blocks + + +# ── can_append with multi-token decode (speculative decoding) ──────────── + + +class TestCanAppendMultiToken: + def test_can_append_multi_token_within_block(self, block_manager, seq_factory): + """Appending 3 tokens that stay within the current block.""" + seq = seq_factory([1]) + block_manager.allocate(seq) + seq.append_token(2) + seq.append_token(3) + assert block_manager.can_append(seq, num_new_tokens=3) + + def test_can_append_multi_token_crossing_boundary(self, seq_factory): + """block_size=4, seq_len=14 (3.5 blocks=4 blocks allocated), + appending 5 tokens crosses into block 5 — needs 1 new 
block.""" + cfg = MockConfig(num_kvcache_blocks=6, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory(list(range(14))) + bm.allocate(seq) + # seq_len=14, 4 blocks. Appending 5 tokens: positions 14..18 → need block 5 + for t in range(14, 19): + seq.append_token(t) + assert bm.can_append(seq, num_new_tokens=5) + + def test_cannot_append_multi_token_no_free(self, seq_factory): + """block_size=4, 4 blocks total, a 14-token seq occupies all 4 blocks, + appending 5 tokens needs 1 new block but only 0 free.""" + cfg = MockConfig(num_kvcache_blocks=4, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory(list(range(14))) + bm.allocate(seq) + for t in range(14, 19): + seq.append_token(t) + assert not bm.can_append(seq, num_new_tokens=5) + + +# ── Prefix caching + preemption ────────────────────────────────────────── + + +class TestPrefixCachingPreemption: + def test_preempt_and_reschedule_reuses_cache(self, seq_factory): + """Preempted sequence re-discovers cache hits on re-allocation.""" + cfg = MockConfig( + num_kvcache_blocks=10, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + # Simulate preemption + bm.deallocate(s1) + assert s1.num_cached_tokens == 0 + assert s1.block_table == [] + + # Re-allocate — should get cache hits on both blocks + s1_retry = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1_retry) + assert s1_retry.num_cached_tokens == 8 # both blocks cached + + +# ── Edge cases ─────────────────────────────────────────────────────────── + + +class TestPrefixCachingEdgeCases: + def test_single_token_no_cache(self, seq_factory): + """Single token seq (shorter than block_size) — hash is -1, no caching.""" + cfg = MockConfig( + num_kvcache_blocks=4, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([42]) + bm.allocate(s1) + bm.deallocate(s1) + s2 = seq_factory([42]) + bm.allocate(s2) + #
Partial block → hash is -1 → no caching + assert s2.num_cached_tokens == 0 + + def test_exact_block_size_fully_cached(self, seq_factory): + """Sequence with exactly block_size tokens — fully cached on reuse.""" + cfg = MockConfig( + num_kvcache_blocks=4, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4]) + bm.allocate(s1) + bm.deallocate(s1) + s2 = seq_factory([1, 2, 3, 4]) + bm.allocate(s2) + assert s2.num_cached_tokens == 4 + + def test_free_block_ids_set_consistent(self, block_manager, seq_factory): + """free_block_ids_set stays consistent through allocate/deallocate.""" + s1 = seq_factory([1, 2, 3, 4]) + block_manager.allocate(s1) + initial_free = len(block_manager.free_block_ids_set) + block_manager.deallocate(s1) + assert len(block_manager.free_block_ids_set) == initial_free + 1 diff --git a/tests/test_prefix_cache_accuracy.py b/tests/test_prefix_cache_accuracy.py new file mode 100644 index 00000000..0901b93c --- /dev/null +++ b/tests/test_prefix_cache_accuracy.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +""" +Test prefix caching accuracy with high cache-hit workloads. + +Sends batches of requests that share long common prefixes, then verifies: +1. Responses are correct (math problems with known answers) +2. Cache hit rate is high (visible in server logs) +3. Repeated identical requests produce consistent results +""" + +import argparse +import concurrent.futures +import re +import sys +import time + +import requests + +BASE_URL = "http://localhost:8000" + +# Long shared prefix: 5-shot math examples (~2000 tokens) +MATH_PREFIX = """You are a precise math assistant. Solve each problem step by step and give the final numerical answer after ####. + +Question: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today? +Answer: There are 15 trees originally. 
Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. #### 6 + +Question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? +Answer: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. #### 5 + +Question: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? +Answer: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. #### 39 + +Question: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny? +Answer: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. #### 8 + +Question: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now? +Answer: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 2 + 2 = 4 more toys. 5 + 4 = 9. #### 9 + +""" + +# Test questions with known answers +TEST_QUESTIONS = [ + ( + "A farmer has 17 sheep. He buys 5 more and then sells 8. How many sheep does he have?", + 14, + ), + ( + "A store had 45 apples. They sold 12 in the morning and 18 in the afternoon. How many apples are left?", + 15, + ), + ( + "Tom has 8 marbles. Jerry has 3 times as many. How many marbles do they have together?", + 32, + ), + ( + "A classroom has 6 rows of desks with 5 desks in each row. If 7 desks are removed, how many remain?", + 23, + ), + ( + "Sarah baked 24 cookies. She gave 1/3 to her neighbor and ate 4 herself. How many cookies does she have left?", + 12, + ), + ( + "A train travels 60 miles per hour for 3 hours, then 40 miles per hour for 2 hours. What is the total distance?", + 260, + ), + ( + "Mike has 50 dollars. He spends 15 dollars on lunch and 20 dollars on a book. 
How much money does he have left?", + 15, + ), + ( + "A garden has 9 rose bushes. Each bush has 12 roses. If 25 roses are picked, how many roses remain?", + 83, + ), + ( + "Lisa read 35 pages on Monday and twice as many on Tuesday. How many pages did she read in total?", + 105, + ), + ( + "A box contains 100 balls. 40 are red, 35 are blue, and the rest are green. How many green balls are there?", + 25, + ), +] + + +def extract_answer(text: str): + """Extract numerical answer after the FIRST #### marker.""" + match = re.search(r"####\s*(-?\d+(?:\.\d+)?)", text) + if match: + return float(match.group(1)) + return None + + +def get_model_name(base_url: str) -> str: + """Get the model name from the server.""" + resp = requests.get(f"{base_url}/v1/models", timeout=5) + resp.raise_for_status() + return resp.json()["data"][0]["id"] + + +def send_completion( + prompt: str, max_tokens: int = 256, base_url: str = BASE_URL, model: str = "" +) -> str: + """Send a completion request to the server.""" + resp = requests.post( + f"{base_url}/v1/completions", + json={ + "model": model, + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": 0.0, + }, + timeout=120, + ) + resp.raise_for_status() + return resp.json()["choices"][0]["text"] + + +def run_batch(questions, prefix, base_url=BASE_URL, model="", label=""): + """Run a batch of questions and return (correct, total, results).""" + results = [] + + def ask(q_and_a): + question, expected = q_and_a + prompt = prefix + f"Question: {question}\nAnswer:" + try: + response = send_completion(prompt, base_url=base_url, model=model) + answer = extract_answer(response) + correct = answer is not None and abs(answer - expected) < 0.01 + return (question, expected, answer, correct, response.strip()) + except Exception as e: + return (question, expected, None, False, f"ERROR: {e}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool: + results = list(pool.map(ask, questions)) + + num_correct = sum(1 for r in results 
if r[3]) + return num_correct, len(results), results + + +def main(): + parser = argparse.ArgumentParser(description="Test prefix cache accuracy") + parser.add_argument( + "--rounds", type=int, default=3, help="Number of rounds to repeat" + ) + parser.add_argument("--base-url", type=str, default=BASE_URL) + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + base_url = args.base_url + + # Health check + try: + r = requests.get(f"{base_url}/health", timeout=5) + r.raise_for_status() + except Exception as e: + print(f"Server not reachable at {base_url}: {e}") + sys.exit(1) + + model = get_model_name(base_url) + print("=== Prefix Cache Accuracy Test ===") + print(f"Server: {base_url}") + print(f"Model: {model}") + print(f"Questions per round: {len(TEST_QUESTIONS)}") + print(f"Rounds: {args.rounds}") + print(f"Shared prefix length: ~{len(MATH_PREFIX)} chars") + print() + + all_round_results = [] + + for round_num in range(1, args.rounds + 1): + t0 = time.time() + correct, total, results = run_batch( + TEST_QUESTIONS, + MATH_PREFIX, + base_url=base_url, + model=model, + label=f"Round {round_num}", + ) + elapsed = time.time() - t0 + accuracy = 100.0 * correct / total + all_round_results.append((correct, total, accuracy, elapsed)) + + print( + f"Round {round_num}: {correct}/{total} correct ({accuracy:.1f}%) in {elapsed:.1f}s" + ) + + if args.verbose: + for q, expected, got, ok, resp in results: + status = "OK" if ok else "WRONG" + print(f" [{status}] {q[:60]}... 
expected={expected} got={got}") + if not ok: + # Show first 200 chars of response for debugging + print(f" response: {resp[:200]}") + print() + + print() + print("=== Summary ===") + total_correct = sum(r[0] for r in all_round_results) + total_questions = sum(r[1] for r in all_round_results) + overall_accuracy = 100.0 * total_correct / total_questions + + # Per-round breakdown: the same question set runs each round; rounds 2+ run against a warm cache + print(f"Overall: {total_correct}/{total_questions} ({overall_accuracy:.1f}%)") + for i, (c, t, a, e) in enumerate(all_round_results, 1): + cache_note = "(cold)" if i == 1 else "(cache warm)" + print(f" Round {i}: {c}/{t} ({a:.1f}%) {e:.1f}s {cache_note}") + + # Report round-2 speedup over round 1 (expected faster due to cache hits) + if args.rounds >= 2: + r1_time = all_round_results[0][3] + r2_time = all_round_results[1][3] + speedup = r1_time / r2_time if r2_time > 0 else 0 + print(f"\n Speedup round 2 vs round 1: {speedup:.2f}x") + + # Pass/fail + if overall_accuracy >= 80.0: + print(f"\nPASS: accuracy {overall_accuracy:.1f}% >= 80%") + return 0 + else: + print(f"\nFAIL: accuracy {overall_accuracy:.1f}% < 80%") + return 1 + + +if __name__ == "__main__": + sys.exit(main())