Fix GPU memory access fault caused by double conversion of block_tables in the cached prefill path. `kv_indices_generate_triton` applies `block_ratio` internally, but was receiving already-converted block tables (via `block_tables_converted`), so indices were multiplied by `block_ratio` twice (e.g. `block_id * 256` instead of `block_id * 16`), exceeding KV cache bounds.

Key changes:
- Use raw `block_tables` for kv_indices generation in the aiter_mla prefill
- Route cached prefill through paged MLA attention (supports Q≠K) instead of `flash_attn_varlen_func` (requires Q==K)
- Track a `has_cached` flag through `AttentionMetaData` for path selection
- Fix `block_manager`: hash-table leak, `can_allocate` cache-hit accounting, `can_append` for multi-token decode, O(1) free-block tracking
- Add `CacheStats` to the scheduler for prefix-cache hit-rate monitoring
- Add comprehensive `block_manager` tests (119 passing)

Verified: gsm8k, 1319 samples, 95.83% accuracy, 0 GPU faults.
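The double-conversion failure mode described above can be sketched in a few lines. All names and values here are illustrative, not the PR's actual Triton kernel signature:

```python
# Illustrative-only sketch of the double-conversion bug (hypothetical names).
block_ratio = 16  # per-block ratio applied inside the kernel

def kv_indices_generate(block_tables, block_ratio):
    # The real kv_indices_generate_triton multiplies block ids by
    # block_ratio internally, so it expects RAW block ids.
    return [b * block_ratio for b in block_tables]

raw_block_tables = [0, 1, 2]
# Buggy path: ids were already scaled before being passed in.
converted = [b * block_ratio for b in raw_block_tables]

correct = kv_indices_generate(raw_block_tables, block_ratio)  # [0, 16, 32]
buggy = kv_indices_generate(converted, block_ratio)  # [0, 256, 512]: out of bounds
```

The fix is therefore to feed the raw table to the kernel and let it apply the ratio exactly once.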
Pull request overview
This PR extends prefix-caching support through the scheduler, metadata builders, and attention implementations (including DeepSeek-v2/MLA paths), and adds unit tests for cache-aware block management behavior.
Changes:
- Add prefix-cache metadata plumbing (`has_cached`, `num_cached_tokens`, etc.) and use it to gather cached+new KV for attention/indexing.
- Update BlockManager/scheduler logic to account for cache hits and multi-token decode allocation.
- Add/extend tests covering prefix-cache allocation behavior and hash-table cleanup.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| tests/test_prefix_cache_accuracy.py | Adds an integration-style prefix-cache accuracy driver script under tests/. |
| tests/test_block_manager.py | Adds new unit tests for cache-aware allocation, hash cleanup, and multi-token append scenarios. |
| atom/utils/forward_context.py | Extends AttentionMetaData with prefix-cache-related fields. |
| atom/models/deepseek_v2.py | Adjusts sparse indexer to consider full KV length when prefix cache is present. |
| atom/model_ops/base_attention.py | Extends cp_mha_gather_cache to support multiple KV cache layouts. |
| atom/model_ops/attentions/backends.py | Builds prefill metadata accounting for cached tokens; adds token-to-batch mapping for cache gather. |
| atom/model_ops/attentions/aiter_mla.py | Updates MLA prefill metadata generation to use full-context lengths when prefix cache is present. |
| atom/model_ops/attention_mla.py | Adds prefix-cache path to gather full KV and run varlen flash-attn prefill. |
| atom/model_ops/attention_mha.py | Adds prefix-cache KV gather+concat path for MHA via cp_mha_gather_cache. |
| atom/model_engine/scheduler.py | Adds cache hit-rate stats logging; updates decode scheduling to reserve multi-token space. |
| atom/model_engine/block_manager.py | Adds free-block tracking set + cache-aware can_allocate; updates can_append to support multi-token appends. |
Comment on lines +259 to +260:

```python
seq.append_token(2)
seq.append_token(3)
```
Comment on lines +263 to +273:

```python
def test_can_append_multi_token_crossing_boundary(self, seq_factory):
    """block_size=4, seq_len=14 (3.5 blocks = 4 blocks allocated),
    appending 5 tokens crosses into block 5 — needs 1 new block."""
    cfg = MockConfig(num_kvcache_blocks=6, kv_cache_block_size=4)
    bm = BlockManager(cfg)
    seq = seq_factory(list(range(14)))
    bm.allocate(seq)
    # seq_len=14, 4 blocks. Appending 5 tokens: positions 14..18 → need block 5
    for t in range(14, 19):
        seq.append_token(t)
    assert bm.can_append(seq, num_new_tokens=5)
```
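The boundary arithmetic this test exercises reduces to a ceiling division. A minimal sketch (the helper name is hypothetical, not BlockManager's actual API):

```python
def blocks_needed(seq_len, block_size):
    # Ceiling division: number of blocks required to hold seq_len tokens.
    return (seq_len + block_size - 1) // block_size

# block_size=4: 14 tokens occupy 4 blocks; growing to 19 tokens needs 5,
# so appending 5 tokens requires exactly 1 fresh block.
extra = blocks_needed(14 + 5, 4) - blocks_needed(14, 4)  # 1
```

`can_append` must therefore reserve blocks based on the final length after all `num_new_tokens` appends, not just the next single token.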
Comment on lines +17 to +18:

```python
import requests
```
```python
cache_miss = True
if cache_miss:
    needed_free += 1
return len(self.free_block_ids_set) >= needed_free
```
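For context, cache-aware admission amounts to counting only cache-miss blocks against the free pool. A simplified standalone model (names and shapes are assumptions, not the PR's actual `can_allocate` signature):

```python
# Simplified model: only cache-miss blocks consume free blocks.
def can_allocate(num_free_blocks, block_hashes, hash_to_block_id):
    needed_free = 0
    for h in block_hashes:
        if h not in hash_to_block_id:  # cache miss: needs a fresh block
            needed_free += 1
    return num_free_blocks >= needed_free

# A two-block sequence with one cached block needs only one free block.
ok = can_allocate(1, ["h0", "h1"], {"h0": 7})    # True
full = can_allocate(0, ["h0", "h1"], {"h0": 7})  # False
```

Without the cache-hit discount, a scheduler would reject sequences that could in fact be admitted, wasting the prefix cache's capacity benefit.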
Comment on lines 115 to 118:

```diff
 if cache_miss:
-    block_id = self.free_block_ids[0]
+    block_id = self._pop_free_block()
     block = self._allocate_block(block_id)
 else:
```
Comment on lines 171 to +179:

```python
if 0 < seq_len % self.block_size <= num_new_tokens or self.block_size == 1:
    needed_blocks = (seq_len + self.block_size - 1) // self.block_size
    while len(block_table) < needed_blocks:
        # For block_size == 1, we need to update the hash for each new block.
        # For block_size > 1, the previous block should have hash != -1
        # (unless it's the first block).
        if self.block_size == 1:
            # Allocate a new block and update its hash immediately
            # (as allocate does for full blocks).
            block_id = self.free_block_ids[0]
            block = self._allocate_block(block_id)
            block_table.append(block_id)
            token_ids = [seq[-1]]
            prefix = (
                self.blocks[block_table[-2]].hash
                if len(block_table) > 1
                else -1
            )
            h = self.compute_hash(token_ids, prefix)
            block.update(h, token_ids)
            self.hash_to_block_id[h] = block_id
        else:
            # For block_size > 1, only allocate a new block when needed;
            # the hash is updated once the block becomes full.
            block_id = self.free_block_ids[0]
            block = self._allocate_block(block_id)
            block_table.append(block_id)
            last_block = block
elif seq_len % self.block_size == 0:
    # The last block is now full; update its hash (similar to allocate).
    # TODO: fix hash
    token_ids = seq.block(seq.num_blocks - 1)
    if len(token_ids) == self.block_size:
        prefix = (
            self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
        )
        h = self.compute_hash(token_ids, prefix)
        last_block.update(h, token_ids)
        self.hash_to_block_id[h] = last_block.block_id
    else:
        pass
else:
    # Last block is not full and not at the boundary.
    # Hash remains -1 until the block is full (consistent with allocate logic).
    # assert last_block.hash == -1, last_block.block_id
    # Decode-generated blocks: the token is not finalized yet (it depends on
    # sampling / speculative verification), so we cannot compute a correct
    # hash here. Just allocate the block without hashing.
    block_id = self._pop_free_block()
    self._allocate_block(block_id)
    block_table.append(block_id)
```
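The hash chaining used above, `compute_hash(token_ids, prefix)`, can be modeled standalone. This sketch assumes a SHA-256-based scheme, which may well differ from the project's actual hash function:

```python
import hashlib

def compute_hash(token_ids, prefix):
    # Mixing in the previous block's hash means two blocks with identical
    # tokens but different prefixes get different hashes (assumed scheme).
    h = hashlib.sha256()
    h.update(str(prefix).encode())
    h.update(",".join(map(str, token_ids)).encode())
    return int.from_bytes(h.digest()[:8], "little")

h0 = compute_hash([1, 2, 3, 4], -1)  # first block: no prefix, use -1
h1 = compute_hash([5, 6, 7, 8], h0)  # second block chains off h0
```

Because each entry in `hash_to_block_id` pins a block, any path that rehashes or drops a block without removing the old entry leaks table entries, which is exactly the hash-table leak this PR fixes.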
Comment on lines +239 to +245:

```python
total_tokens = cu_seqlens_k[-1]
token_to_batch = torch.zeros(
    total_tokens, dtype=torch.int32, device=self.device
)
for i in range(bs):
    start = cu_seqlens_k[i]
    end = cu_seqlens_k[i + 1]
```
Comment on lines +233 to +234:

```python
logger.info(f"{has_cached=}")
logger.info(
```