
Ds prefix cache2 #286

Open
jiayyu wants to merge 12 commits into main from ds_prefix_cache2

Conversation

Contributor

@jiayyu jiayyu commented Mar 9, 2026

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@jiayyu jiayyu force-pushed the ds_prefix_cache2 branch 2 times, most recently from 80465c7 to 0e76108 (March 12, 2026 03:01)
valarLip and others added 11 commits March 13, 2026 15:54
Fix GPU memory access fault caused by double conversion of block_tables
in cached prefill path. kv_indices_generate_triton applies block_ratio
internally, but was receiving already-converted block_tables (via
block_tables_converted), causing indices to be multiplied by block_ratio
twice (e.g. block_id*256 instead of block_id*16), exceeding KV cache
bounds.

Key changes:
- Use raw block_tables for kv_indices generation in aiter_mla prefill
- Route cached prefill through paged MLA attention (supports Q≠K)
  instead of flash_attn_varlen_func (requires Q==K)
- Track has_cached flag through AttentionMetaData for path selection
- Fix block_manager: hash table leak, can_allocate cache-hit accounting,
  can_append for multi-token decode, O(1) free block tracking
- Add CacheStats to scheduler for prefix cache hit rate monitoring
- Add comprehensive block_manager tests (119 passing)

Verified: gsm8k 1319 samples, 95.83% accuracy, 0 GPU faults.
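The double-conversion failure mode described in this commit can be illustrated with a minimal sketch. The function body below is a simplified stand-in for the real Triton kernel (kv_indices_generate_triton), kept only to show how pre-scaling the block table makes block_ratio get applied twice:

```python
# Simplified sketch of the double-conversion bug described above.
# block_ratio maps one logical cache block to block_ratio physical
# KV slots; the real kernel applies it internally.

def kv_indices_generate(block_table, block_ratio):
    """Expand logical block ids into physical KV-slot indices."""
    return [b * block_ratio + off
            for b in block_table
            for off in range(block_ratio)]

block_ratio = 16
raw_block_table = [0, 1, 2]  # logical block ids

# Correct: pass raw block ids; the generator applies block_ratio once.
good = kv_indices_generate(raw_block_table, block_ratio)
assert max(good) == 2 * 16 + 15  # stays within the 3 allocated blocks

# Buggy: an already-converted table gets scaled a second time, so
# indices land at block_id * 16 * 16 -- far past the KV cache bounds.
converted = [b * block_ratio for b in raw_block_table]
bad = kv_indices_generate(converted, block_ratio)
assert max(bad) == 2 * 16 * 16 + 15  # out-of-bounds GPU access
```

With block_ratio=16 the correct maximum index is 47, while the double-converted table produces 527, matching the "block_id*256 instead of block_id*16" symptom in the commit message.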
@jiayyu jiayyu force-pushed the ds_prefix_cache2 branch from bf02a43 to e17b614 (March 13, 2026 07:55)
@jiayyu jiayyu marked this pull request as ready for review March 13, 2026 07:56
Copilot AI review requested due to automatic review settings March 13, 2026 07:56
Contributor

Copilot AI left a comment


Pull request overview

This PR extends prefix-caching support through the scheduler, metadata builders, and attention implementations (including DeepSeek-v2/MLA paths), and adds unit tests for cache-aware block management behavior.

Changes:

  • Add prefix-cache metadata plumbing (has_cached, num_cached_tokens, etc.) and use it to gather cached+new KV for attention/indexing.
  • Update BlockManager/scheduler logic to account for cache hits and multi-token decode allocation.
  • Add/extend tests covering prefix-cache allocation behavior and hash-table cleanup.
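The metadata plumbing in the first bullet can be sketched as a small dataclass. This is a hypothetical reduced view: the real AttentionMetaData in atom/utils/forward_context.py carries many more fields, and only has_cached and num_cached_tokens are named in this review.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class PrefixCacheMeta:
    # Hypothetical reduced view of the prefix-cache fields
    # added to AttentionMetaData in this PR.
    has_cached: bool = False  # True if any sequence in the batch hit the cache
    num_cached_tokens: List[int] = field(default_factory=list)  # per-sequence hits

def build_meta(cached_lens):
    """Derive batch-level prefix-cache metadata from per-sequence hit counts."""
    return PrefixCacheMeta(
        has_cached=any(n > 0 for n in cached_lens),
        num_cached_tokens=list(cached_lens),
    )

meta = build_meta([0, 128, 64])
assert meta.has_cached
```

Downstream attention code can then branch on has_cached to pick the paged (cached) prefill path versus the plain varlen path.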

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 8 comments.

  • tests/test_prefix_cache_accuracy.py: adds an integration-style prefix-cache accuracy driver script under tests/.
  • tests/test_block_manager.py: adds new unit tests for cache-aware allocation, hash cleanup, and multi-token append scenarios.
  • atom/utils/forward_context.py: extends AttentionMetaData with prefix-cache-related fields.
  • atom/models/deepseek_v2.py: adjusts the sparse indexer to consider the full KV length when the prefix cache is present.
  • atom/model_ops/base_attention.py: extends cp_mha_gather_cache to support multiple KV cache layouts.
  • atom/model_ops/attentions/backends.py: builds prefill metadata accounting for cached tokens; adds a token-to-batch mapping for the cache gather.
  • atom/model_ops/attentions/aiter_mla.py: updates MLA prefill metadata generation to use full-context lengths when the prefix cache is present.
  • atom/model_ops/attention_mla.py: adds a prefix-cache path to gather full KV and run varlen flash-attn prefill.
  • atom/model_ops/attention_mha.py: adds a prefix-cache KV gather+concat path for MHA via cp_mha_gather_cache.
  • atom/model_engine/scheduler.py: adds cache hit-rate stats logging; updates decode scheduling to reserve multi-token space.
  • atom/model_engine/block_manager.py: adds a free-block tracking set and cache-aware can_allocate; updates can_append to support multi-token appends.


Comment on lines +259 to +260
seq.append_token(2)
seq.append_token(3)
Comment on lines +263 to +273
def test_can_append_multi_token_crossing_boundary(self, seq_factory):
    """block_size=4, seq_len=14 (3.5 blocks = 4 blocks allocated),
    appending 5 tokens crosses into block 5 — needs 1 new block."""
    cfg = MockConfig(num_kvcache_blocks=6, kv_cache_block_size=4)
    bm = BlockManager(cfg)
    seq = seq_factory(list(range(14)))
    bm.allocate(seq)
    # seq_len=14, 4 blocks. Appending 5 tokens: positions 14..18 → need block 5
    for t in range(14, 19):
        seq.append_token(t)
    assert bm.can_append(seq, num_new_tokens=5)
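The boundary arithmetic this test exercises reduces to a ceiling division. A hedged sketch of the check (simplified from the real can_append, which also consults the free-block pool and per-block hashes):

```python
def blocks_needed_for_append(seq_len, num_new_tokens, block_size):
    """How many new cache blocks appending num_new_tokens requires."""
    # Blocks currently allocated for seq_len tokens (ceiling division).
    current = (seq_len + block_size - 1) // block_size
    # Blocks required once the new tokens are appended.
    after = (seq_len + num_new_tokens + block_size - 1) // block_size
    return after - current

# The scenario from the test above: 14 tokens in blocks of 4 occupy
# 4 blocks; appending 5 tokens (positions 14..18) needs 1 more block.
assert blocks_needed_for_append(14, 5, 4) == 1

# Appending 2 tokens stays inside block 4: no new block needed.
assert blocks_needed_for_append(14, 2, 4) == 0
```

can_append then only has to compare this count against the number of free blocks, which is O(1) with the free-block set added in this PR.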
Comment on lines +17 to +18
import requests

cache_miss = True
if cache_miss:
    needed_free += 1
return len(self.free_block_ids_set) >= needed_free
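The excerpt above is the tail of the cache-aware can_allocate accounting. A self-contained sketch of the whole idea, under the assumption (not confirmed by the excerpt) that prefix hits must be contiguous from the start of the sequence:

```python
def can_allocate(free_blocks, block_hashes, hash_to_block_id):
    """Hypothetical simplification of cache-aware can_allocate:
    a block whose hash is already cached costs no free block;
    every miss needs one."""
    needed_free = 0
    cache_hit_ended = False
    for h in block_hashes:
        cache_miss = cache_hit_ended or h not in hash_to_block_id
        if cache_miss:
            cache_hit_ended = True  # once we miss, later blocks cannot hit
            needed_free += 1
    return free_blocks >= needed_free

# One cached block ("a"), one new block ("b"): needs just 1 free block.
assert can_allocate(1, ["a", "b"], {"a": 0})

# Empty cache: both blocks miss, 1 free block is not enough.
assert not can_allocate(1, ["a", "b"], {})
```

This is the accounting the commit message calls "can_allocate cache-hit accounting": without it, fully-cached requests were charged as if every block were a miss.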
Comment on lines 115 to 118
 if cache_miss:
-    block_id = self.free_block_ids[0]
+    block_id = self._pop_free_block()
     block = self._allocate_block(block_id)
 else:
Comment on lines 171 to +179
if 0 < seq_len % self.block_size <= num_new_tokens or self.block_size == 1:
    needed_blocks = (seq_len + self.block_size - 1) // self.block_size
    while len(block_table) < needed_blocks:
        # For block_size == 1, we need to update hash for each new block
        # For block_size > 1, the previous block should have hash != -1 (unless it's the first block)
        if self.block_size == 1:
            # Allocate new block and update hash immediately (like allocate does for full blocks)
            block_id = self.free_block_ids[0]
            block = self._allocate_block(block_id)
            block_table.append(block_id)
            token_ids = [seq[-1]]
            prefix = (
                self.blocks[block_table[-2]].hash
                if len(block_table) > 1
                else -1
            )
            h = self.compute_hash(token_ids, prefix)
            block.update(h, token_ids)
            self.hash_to_block_id[h] = block_id
        else:
            # For block_size > 1, we only allocate new block when needed
            # The hash will be updated when the block becomes full
            block_id = self.free_block_ids[0]
            block = self._allocate_block(block_id)
            block_table.append(block_id)
            last_block = block
elif seq_len % self.block_size == 0:
    # Last block is now full, update its hash (similar to allocate)
    # TODO: fix hash
    token_ids = seq.block(seq.num_blocks - 1)
    if len(token_ids) == self.block_size:
        prefix = (
            self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
        )
        h = self.compute_hash(token_ids, prefix)
        last_block.update(h, token_ids)
        self.hash_to_block_id[h] = last_block.block_id
else:
    # Last block is not full and not at the boundary
    # Hash remains -1 until block is full (consistent with allocate logic)
    # assert last_block.hash == -1, last_block.block_id
    # Decode-generated blocks: token not finalized yet (depends on
    # sampling / speculative verification), so we cannot compute a
    # correct hash here. Just allocate the block without hashing.
    block_id = self._pop_free_block()
    self._allocate_block(block_id)
    block_table.append(block_id)
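The compute_hash(token_ids, prefix) calls in the snippet above chain each block's hash to its predecessor's, so that identical token blocks at different prefix positions do not collide in hash_to_block_id. A sketch of that chaining (the repo's actual hash function may differ; sha256 and the string encoding here are illustrative):

```python
import hashlib

def compute_hash(token_ids, prefix=-1):
    """Chain the previous block's hash into this block's digest so
    identical blocks under different prefixes hash differently."""
    h = hashlib.sha256()
    h.update(str(prefix).encode())
    h.update(",".join(map(str, token_ids)).encode())
    return h.hexdigest()

h0 = compute_hash([1, 2, 3, 4])             # first block: prefix sentinel -1
h1 = compute_hash([5, 6, 7, 8], prefix=h0)  # second block chains on h0

# Same tokens under a different prefix yield a different hash,
# which is what makes the prefix cache lookup position-aware.
assert compute_hash([5, 6, 7, 8], prefix="other") != h1
```

This also explains why the snippet only hashes full blocks: a partial block's tokens may still change (e.g. during speculative verification), so its hash stays -1 until the block fills.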
Comment on lines +239 to +245
total_tokens = cu_seqlens_k[-1]
token_to_batch = torch.zeros(
    total_tokens, dtype=torch.int32, device=self.device
)
for i in range(bs):
    start = cu_seqlens_k[i]
    end = cu_seqlens_k[i + 1]
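The truncated loop above fills a flattened token-to-batch map from cumulative KV sequence lengths, so the cache gather knows which request each flattened token belongs to. A complete CPU sketch of the same mapping (pure Python here; the real code builds a torch tensor on device):

```python
def token_to_batch_map(cu_seqlens_k):
    """Map each flattened KV token position to its batch index.

    cu_seqlens_k[i]..cu_seqlens_k[i+1] are the flattened token
    positions belonging to batch element i.
    """
    total = cu_seqlens_k[-1]
    mapping = [0] * total
    for i in range(len(cu_seqlens_k) - 1):
        for pos in range(cu_seqlens_k[i], cu_seqlens_k[i + 1]):
            mapping[pos] = i
    return mapping

# Two sequences of lengths 3 and 2: tokens 0..2 map to batch 0, 3..4 to batch 1.
assert token_to_batch_map([0, 3, 5]) == [0, 0, 0, 1, 1]
```

In torch the same result could be obtained without the Python loop via repeat_interleave over per-sequence lengths, which is worth considering if this map is built every step.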
Comment on lines +233 to +234
logger.info(f"{has_cached=}")
logger.info(

3 participants