Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 49 additions & 46 deletions atom/model_engine/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, config: Config):
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
self.hash_to_block_id: dict[int, int] = dict()
self.free_block_ids: deque[int] = deque(range(num_blocks))
self.free_block_ids_set: set[int] = set(range(num_blocks))
self.used_block_ids: set[int] = set()
self.enable_prefix_caching = config.enable_prefix_caching

Expand All @@ -48,22 +49,52 @@ def compute_hash(cls, token_ids: list[int], prefix: int = -1):
h.update(np.array(token_ids).tobytes())
return h.intdigest()

def _pop_free_block(self) -> int:
    """Return the id of the next genuinely free block.

    The FIFO deque may contain stale ids (lazy cleanup): only ids still
    present in ``free_block_ids_set`` are actually free, so stale
    entries are dropped while scanning.

    Raises:
        AssertionError: if no free block remains.
    """
    live = self.free_block_ids_set
    queue = self.free_block_ids
    while queue:
        candidate = queue.popleft()
        if candidate not in live:
            # Stale entry left behind by lazy cleanup — skip it.
            continue
        live.discard(candidate)
        return candidate
    raise AssertionError("No free blocks available")

def _allocate_block(self, block_id: int) -> "Block":
    """Mark ``block_id`` as used and return its (reset) block.

    The id may come either from ``_pop_free_block()`` (already removed
    from the deque) or from a prefix-cache hit on a block sitting in the
    free list.  In both cases only ``free_block_ids_set`` is updated;
    any deque entry is left behind as a stale record that
    ``_pop_free_block()`` discards lazily.  Calling
    ``free_block_ids.remove()`` here would be O(n) and would raise
    ``ValueError`` on the ``_pop_free_block()`` path, where the id is
    already gone from the deque.
    """
    block = self.blocks[block_id]
    assert block.ref_count == 0
    # Evict stale hash entry before resetting
    if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
        del self.hash_to_block_id[block.hash]
    block.reset()
    # Lazy cleanup: drop from the authoritative set only.
    self.free_block_ids_set.discard(block_id)
    self.used_block_ids.add(block_id)
    return block

def _deallocate_block(self, block_id: int):
    """Return a no-longer-referenced block to the free pool."""
    assert self.blocks[block_id].ref_count == 0
    self.used_block_ids.remove(block_id)
    # Enqueue at the tail so blocks are reused in FIFO order, and mark
    # the id as live in the O(1) membership set.
    self.free_block_ids.append(block_id)
    self.free_block_ids_set.add(block_id)

def can_allocate(self, seq: Sequence) -> bool:
    """Return True if enough free blocks exist to admit ``seq``.

    Without prefix caching every block (KV + mamba) must come from the
    free pool.  With prefix caching, a dry run over the sequence's
    blocks counts only the blocks that would miss the cache; once one
    block misses, every later block is a miss too because the block
    hashes are chained through ``h``.
    """
    if not self.enable_prefix_caching:
        return len(self.free_block_ids_set) >= seq.num_blocks + seq.num_mamba_blocks
    # Dry-run: count how many blocks would be cache hits.
    # NOTE(review): this path ignores seq.num_mamba_blocks — confirm
    # mamba blocks never need fresh allocation when prefix caching is on.
    h = -1
    cache_miss = False
    needed_free = 0
    for i in range(seq.num_blocks):
        token_ids = seq.block(i)
        # Only full blocks get a real chained hash; partial blocks use -1.
        h = (
            self.compute_hash(token_ids, h)
            if len(token_ids) == self.block_size
            else -1
        )
        block_id = self.hash_to_block_id.get(h, -1)
        # Verify the tokens too: a hash hit may be stale after eviction.
        if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
            cache_miss = True
        if cache_miss:
            needed_free += 1
    return len(self.free_block_ids_set) >= needed_free

def allocate(self, seq: Sequence):
assert not seq.block_table
Expand All @@ -82,7 +113,7 @@ def allocate(self, seq: Sequence):
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
cache_miss = True
if cache_miss:
block_id = self.free_block_ids[0]
block_id = self._pop_free_block()
block = self._allocate_block(block_id)
else:
Comment on lines 115 to 118
seq.num_cached_tokens += self.block_size
Expand Down Expand Up @@ -122,55 +153,27 @@ def deallocate(self, seq: Sequence):
self._deallocate_block(block_id)
seq.mamba_block_table.clear()

def can_append(self, seq: Sequence, num_new_tokens: int = 1) -> bool:
    """Return True if the free pool can cover appending ``num_new_tokens``.

    Computes how many blocks the sequence needs after the append
    (ceiling division by block_size) and compares against the blocks it
    already owns; only the shortfall must come from the free pool.
    """
    seq_len = len(seq)
    current_blocks = len(seq.block_table)
    needed_blocks = (
        seq_len + num_new_tokens + self.block_size - 1
    ) // self.block_size
    # A long-owned table can exceed the requirement; never need < 0 blocks.
    new_blocks_needed = max(0, needed_blocks - current_blocks)
    return len(self.free_block_ids_set) >= new_blocks_needed

def may_append(self, seq: Sequence, num_new_tokens: int = 1):
block_table = seq.block_table
last_block = self.blocks[block_table[-1]]
seq_len = len(seq)
# Check if we need to allocate a new block
# When len(seq) % block_size == 1, we need a new block for the next token
# When block_size == 1, every token needs a new block
if 0 < seq_len % self.block_size <= num_new_tokens or self.block_size == 1:
needed_blocks = (seq_len + self.block_size - 1) // self.block_size
while len(block_table) < needed_blocks:
# For block_size == 1, we need to update hash for each new block
# For block_size > 1, the previous block should have hash != -1 (unless it's the first block)
if self.block_size == 1:
# Allocate new block and update hash immediately (like allocate does for full blocks)
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
block_table.append(block_id)
token_ids = [seq[-1]]
prefix = (
self.blocks[block_table[-2]].hash
if len(block_table) > 1
else -1
)
h = self.compute_hash(token_ids, prefix)
block.update(h, token_ids)
self.hash_to_block_id[h] = block_id
else:
# For block_size > 1, we only allocate new block when needed
# The hash will be updated when the block becomes full
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
block_table.append(block_id)
last_block = block
elif seq_len % self.block_size == 0:
# Last block is now full, update its hash (similar to allocate)
# TODO: fix hash
token_ids = seq.block(seq.num_blocks - 1)
if len(token_ids) == self.block_size:
prefix = (
self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
)
h = self.compute_hash(token_ids, prefix)
last_block.update(h, token_ids)
self.hash_to_block_id[h] = last_block.block_id
else:
pass
# Last block is not full and not at the boundary
# Hash remains -1 until block is full (consistent with allocate logic)
# assert last_block.hash == -1, last_block.block_id
# Decode-generated blocks: token not finalized yet (depends on
# sampling / speculative verification), so we cannot compute a
# correct hash here. Just allocate the block without hashing.
block_id = self._pop_free_block()
self._allocate_block(block_id)
block_table.append(block_id)
Comment on lines 171 to +179
71 changes: 70 additions & 1 deletion atom/model_engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,70 @@ def _log(self) -> None:
)


class CacheStats:
    """Accumulates prefix-cache hit statistics and logs them periodically."""

    __slots__ = (
        "_log_interval",
        "total_requests",
        "total_cached_tokens",
        "total_full_tokens",
        "_interval_requests",
        "_interval_cached_tokens",
        "_interval_full_tokens",
    )

    def __init__(self, log_interval: int = 100):
        # Emit a log line every `log_interval` recorded requests.
        self._log_interval = log_interval
        # Lifetime totals.
        self.total_requests: int = 0
        self.total_cached_tokens: int = 0
        self.total_full_tokens: int = 0
        # Counters for the current logging window.
        self._interval_requests: int = 0
        self._interval_cached_tokens: int = 0
        self._interval_full_tokens: int = 0

    def update(self, num_cached_tokens: int, num_full_tokens: int) -> None:
        """Record cache stats for one prefill sequence."""
        self.total_requests += 1
        self._interval_requests += 1
        self.total_cached_tokens += num_cached_tokens
        self._interval_cached_tokens += num_cached_tokens
        self.total_full_tokens += num_full_tokens
        self._interval_full_tokens += num_full_tokens

        # Periodic reporting: log, then start a fresh window.
        if self.total_requests % self._log_interval == 0:
            self._log()
            self._reset_interval()

    @property
    def hit_rate(self) -> float:
        """Lifetime cached/full token ratio (0.0 when nothing recorded)."""
        total = self.total_full_tokens
        return self.total_cached_tokens / total if total else 0.0

    def _reset_interval(self) -> None:
        # Clear only the per-window counters; lifetime totals persist.
        self._interval_requests = 0
        self._interval_cached_tokens = 0
        self._interval_full_tokens = 0

    def _log(self) -> None:
        if self._interval_full_tokens > 0:
            iv_rate = self._interval_cached_tokens / self._interval_full_tokens
        else:
            iv_rate = 0.0
        logger.info(
            f"[Cache Stats Interval] Reqs: {self._interval_requests}, "
            f"Cached/Total tokens: {self._interval_cached_tokens}/{self._interval_full_tokens}, "
            f"Hit rate: {iv_rate:.2%}"
        )
        logger.info(
            f"[Cache Stats ] Reqs: {self.total_requests}, "
            f"Cached/Total tokens: {self.total_cached_tokens}/{self.total_full_tokens}, "
            f"Hit rate: {self.hit_rate:.2%}"
        )


class ScheduledBatch:
def __init__(
self,
Expand Down Expand Up @@ -246,6 +310,9 @@ def __init__(self, config: Config):
self.spec_stats: Optional[SpecStats] = (
SpecStats(mtp_k=self.mtp_k) if self.use_spec else None
)
self.cache_stats: Optional[CacheStats] = (
CacheStats() if config.enable_prefix_caching else None
)

def is_finished(self):
    """True once no sequences remain in either the waiting or running queues."""
    return not (self.waiting or self.running)
Expand Down Expand Up @@ -286,6 +353,8 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]:
# Recalculate after allocation: prefix caching may have updated
# seq.num_cached_tokens, reducing the actual number of new tokens.
num_new_tokens = seq.num_tokens - seq.num_cached_tokens
if self.cache_stats:
self.cache_stats.update(seq.num_cached_tokens, seq.num_tokens)
num_batched_tokens += num_new_tokens
seq.status = SequenceStatus.RUNNING
seq.type = SequenceType.PREFILL
Expand Down Expand Up @@ -319,7 +388,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]:
num_seqs_decode = 0
while self.running and num_seqs_decode < self.max_num_seqs:
seq = self.running.popleft()
while not self.block_manager.can_append(seq):
while not self.block_manager.can_append(seq, self.mtp_k + 1):
if self.running:
self.preempt(self.running.pop())
else:
Expand Down
102 changes: 101 additions & 1 deletion atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@

from .attention_mla import MLAModules

import logging

from atom.plugin.prepare import is_plugin_mode, is_vllm
from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode
from atom.model_ops.base_attention import cp_mha_gather_cache

logger = logging.getLogger("atom")


@PagedAttentionImplDecoratorForPluginMode
Expand Down Expand Up @@ -127,6 +132,14 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
and self.q_norm is not None
and self.k_norm is not None
):
# fused_qk_norm_rope_cache_quant_shuffle expects V cache layout
# [num_blocks, num_kv_heads, block_size//x, head_size, x], not [n, nh, hd, bs]
x = 16 // k_cache.element_size()
if k_cache.dim() == 5 and v_cache.dim() == 4:
n, nh, hd, bs = v_cache.shape
v_cache_shuffle = v_cache.view(n, nh, bs // x, hd, x)
else:
v_cache_shuffle = v_cache
fused_qk_norm_rope_cache_quant_shuffle(
qkv,
num_heads_q=self.num_heads,
Expand All @@ -140,7 +153,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
is_neox_style=self.rotary_emb.is_neox_style,
pos_ids=position,
k_cache=k_cache,
v_cache=v_cache,
v_cache=v_cache_shuffle,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype=(
"auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype
Expand Down Expand Up @@ -212,8 +225,95 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
asm_layout=asm_layout,
)

# Prefix cache hit: gather cached KV from paged cache and concat with new tokens
if attn_metadata.has_cached:
q, k, v, k_cache, v_cache, k_scale, v_scale = (
self._gather_prefix_and_concat_kv(
q, k, v, k_cache, v_cache, k_scale, v_scale, attn_metadata
)
)

return q, k, v, k_cache, v_cache, k_scale, v_scale

def _gather_prefix_and_concat_kv(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    attn_metadata,
):
    """
    When prefix cache hits, gather full KV (cached + new) from paged cache in
    one pass. New tokens are already written by fused_qk_rope_reshape_and_cache.
    Same flow as gather_kv_b_proj: write new first, then read cached+new together.
    token_to_batch, seq_starts are built in prepare_prefill.

    Returns (q, k_full, v_full, k_cache, v_cache, k_scale, v_scale): q and the
    caches/scales pass through unchanged; k/v are replaced by dense
    [total_tokens, num_kv_heads, head_dim] tensors gathered from the cache.
    """
    cu_seqlens_k = attn_metadata.cu_seqlens_k
    # NOTE: .item() forces a device->host sync here — presumably acceptable
    # on the prefill path; confirm if this shows up in profiles.
    total_tokens = cu_seqlens_k[-1].item()
    token_to_batch = attn_metadata.token_to_batch
    seq_starts = attn_metadata.seq_starts

    # Output shape/dtype follow the *new* (uncached) K tensor.
    num_kv_heads = k.shape[1]
    head_dim = k.shape[2]
    device = k.device
    dtype = k.dtype

    # Dense destination buffers for the gathered cached+new KV.
    k_full = torch.empty(
        (total_tokens, num_kv_heads, head_dim), dtype=dtype, device=device
    )
    v_full = torch.empty(
        (total_tokens, num_kv_heads, head_dim), dtype=dtype, device=device
    )

    # Convert cache for cp_mha_gather_cache
    # fused_qk_norm_rope_cache_quant_shuffle: K [n, nh, hd//x, bs, x], V [n, nh, bs//x, hd, x] (SHUFFLE)
    # fused_qk_rope_reshape_and_cache: K [n, nh, hd//x, bs, x], V [n, nh, hd, bs] -> NHD
    if k_cache.dim() == 5:
        # x = packing factor: number of elements per 16-byte vector.
        x = 16 // k_cache.element_size()
        n, nh, _, block_size, _ = k_cache.shape
        if v_cache.dim() == 4:
            # fused_qk_norm_rope_cache_quant_shuffle: V data in [n, nh, bs//x, hd, x] layout
            use_shuffle = True
            k_cache_gather = k_cache
            # Reinterpret (no copy): raw V data already has SHUFFLE layout.
            v_cache_gather = v_cache.view(n, nh, block_size // x, head_dim, x)
        else:
            # fused_qk_rope_reshape_and_cache: V [n, nh, hd, bs] -> NHD
            # These permutes materialize copies to reach NHD layout.
            use_shuffle = False
            k_cache_gather = (
                k_cache.permute(0, 3, 1, 2, 4)
                .contiguous()
                .view(n, block_size, nh, head_dim)
            )
            v_cache_gather = v_cache.permute(0, 3, 1, 2).contiguous()
    else:
        # Cache already stored as [n, bs, nh, hd] (NHD) — use as-is.
        use_shuffle = False
        k_cache_gather = k_cache
        v_cache_gather = v_cache
        block_size = k_cache.shape[1]

    block_tables = attn_metadata.block_tables
    # Kernel gathers cached+new tokens per sequence into k_full/v_full,
    # dequantizing on the fly when the cache holds fp8.
    cp_mha_gather_cache(
        key_cache=k_cache_gather,
        value_cache=v_cache_gather,
        key=k_full,
        value=v_full,
        block_tables=block_tables,
        k_scales=k_scale,
        v_scales=v_scale,
        cu_seqlens_kv=cu_seqlens_k,
        token_to_batch=token_to_batch,
        seq_starts=seq_starts,
        dequant=self.kv_cache_dtype.startswith("fp8"),
        kv_cache_layout="SHUFFLE" if use_shuffle else "NHD",
        total_tokens=total_tokens,
    )

    return q, k_full, v_full, k_cache, v_cache, k_scale, v_scale

def paged_attention_triton(
self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext
):
Expand Down
Loading
Loading