From 6a23db8a215aa7d3b216f3b6ae98b2767fd26f1c Mon Sep 17 00:00:00 2001 From: valarLip <340077269@qq.com> Date: Tue, 24 Feb 2026 15:00:02 +0000 Subject: [PATCH 01/12] fix: resolve prefix caching crashes with MTP speculative decoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix GPU memory access fault caused by double conversion of block_tables in cached prefill path. kv_indices_generate_triton applies block_ratio internally, but was receiving already-converted block_tables (via block_tables_converted), causing indices to be multiplied by block_ratio twice (e.g. block_id*256 instead of block_id*16), exceeding KV cache bounds. Key changes: - Use raw block_tables for kv_indices generation in aiter_mla prefill - Route cached prefill through paged MLA attention (supports Q≠K) instead of flash_attn_varlen_func (requires Q==K) - Track has_cached flag through AttentionMetaData for path selection - Fix block_manager: hash table leak, can_allocate cache-hit accounting, can_append for multi-token decode, O(1) free block tracking - Add CacheStats to scheduler for prefix cache hit rate monitoring - Add comprehensive block_manager tests (119 passing) Verified: gsm8k 1319 samples, 95.83% accuracy, 0 GPU faults. --- atom/model_engine/block_manager.py | 95 +++++++------ atom/model_engine/scheduler.py | 71 +++++++++- atom/model_ops/attention_mla.py | 2 +- atom/model_ops/attentions/aiter_mla.py | 5 +- atom/model_ops/attentions/backends.py | 31 ++++- atom/utils/forward_context.py | 3 + tests/test_block_manager.py | 172 +++++++++++++++++++++++ tests/test_prefix_cache_accuracy.py | 185 +++++++++++++++++++++++++ 8 files changed, 511 insertions(+), 53 deletions(-) create mode 100644 tests/test_prefix_cache_accuracy.py diff --git a/atom/model_engine/block_manager.py b/atom/model_engine/block_manager.py index df6ebd2c3..da97c10df 100644 --- a/atom/model_engine/block_manager.py +++ b/atom/model_engine/block_manager.py @@ -37,6 +37,7 @@ def __init__(self, config: Config): self.blocks: list[Block] = [Block(i) for i in range(num_blocks)] self.hash_to_block_id: dict[int, int] = dict() self.free_block_ids: deque[int] = deque(range(num_blocks)) + self.free_block_ids_set: set[int] = set(range(num_blocks)) self.used_block_ids: set[int] = set() self.enable_prefix_caching = config.enable_prefix_caching @@ -48,11 +49,23 @@ def compute_hash(cls, token_ids: list[int], prefix: int = -1): h.update(np.array(token_ids).tobytes()) return h.intdigest() + def _pop_free_block(self) -> int: + """Pop the next available free block id from the FIFO queue (lazy cleanup).""" + while self.free_block_ids: + block_id = self.free_block_ids.popleft() + if block_id in self.free_block_ids_set: + self.free_block_ids_set.discard(block_id) + return block_id + raise AssertionError("No free blocks available") + def _allocate_block(self, block_id: int) -> Block: block = self.blocks[block_id] assert block.ref_count == 0 + # Evict stale hash entry before resetting + if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id: + del self.hash_to_block_id[block.hash] block.reset() - self.free_block_ids.remove(block_id) + self.free_block_ids_set.discard(block_id) self.used_block_ids.add(block_id) return self.blocks[block_id] @@ -60,10 +73,28 @@ def _deallocate_block(self, block_id: int): assert self.blocks[block_id].ref_count == 0 self.used_block_ids.remove(block_id) self.free_block_ids.append(block_id) - # self.free_block_ids.appendleft(block_id) + self.free_block_ids_set.add(block_id) def can_allocate(self, seq: Sequence) -> bool: - return len(self.free_block_ids) >= seq.num_blocks + seq.num_mamba_blocks + if not self.enable_prefix_caching: + return len(self.free_block_ids_set) >= seq.num_blocks + seq.num_mamba_blocks + # Dry-run: count how many blocks would be cache hits + h = -1 + cache_miss = False + needed_free = 0 + for i in range(seq.num_blocks): + token_ids = seq.block(i) + h = ( + self.compute_hash(token_ids, h) + if len(token_ids) == self.block_size + else -1 + ) + block_id = self.hash_to_block_id.get(h, -1) + if block_id == -1 or self.blocks[block_id].token_ids != token_ids: + cache_miss = True + if cache_miss: + needed_free += 1 + return len(self.free_block_ids_set) >= needed_free def allocate(self, seq: Sequence): assert not seq.block_table @@ -82,7 +113,7 @@ def allocate(self, seq: Sequence): if block_id == -1 or self.blocks[block_id].token_ids != token_ids: cache_miss = True if cache_miss: - block_id = self.free_block_ids[0] + block_id = self._pop_free_block() block = self._allocate_block(block_id) else: seq.num_cached_tokens += self.block_size @@ -122,12 +153,17 @@ def deallocate(self, seq: Sequence): self._deallocate_block(block_id) seq.mamba_block_table.clear() - def can_append(self, seq: Sequence) -> bool: - return len(self.free_block_ids) >= (len(seq) % self.block_size == 1) + def can_append(self, seq: Sequence, num_new_tokens: int = 1) -> bool: + seq_len = len(seq) + current_blocks = len(seq.block_table) + needed_blocks = ( + seq_len + num_new_tokens + self.block_size - 1 + ) // self.block_size + new_blocks_needed = max(0, needed_blocks - current_blocks) + return len(self.free_block_ids_set) >= new_blocks_needed def may_append(self, seq: Sequence, num_new_tokens: int = 1): block_table = seq.block_table - last_block = self.blocks[block_table[-1]] seq_len = len(seq) # Check if we need to allocate a new block # When len(seq) % block_size == 1, we need a new block for the next token @@ -135,42 +171,9 @@ def may_append(self, seq: Sequence, num_new_tokens: int = 1): if 0 < seq_len % self.block_size <= num_new_tokens or self.block_size == 1: needed_blocks = (seq_len + self.block_size - 1) // self.block_size while len(block_table) < needed_blocks: - # For block_size == 1, we need to update hash for each new block - # For block_size > 1, the previous block should have hash != -1 (unless it's the first block) - if self.block_size == 1: - # Allocate new block and update hash immediately (like allocate does for full blocks) - block_id = self.free_block_ids[0] - block = self._allocate_block(block_id) - block_table.append(block_id) - token_ids = [seq[-1]] - prefix = ( - self.blocks[block_table[-2]].hash - if len(block_table) > 1 - else -1 - ) - h = self.compute_hash(token_ids, prefix) - block.update(h, token_ids) - self.hash_to_block_id[h] = block_id - else: - # For block_size > 1, we only allocate new block when needed - # The hash will be updated when the block becomes full - block_id = self.free_block_ids[0] - block = self._allocate_block(block_id) - block_table.append(block_id) - last_block = block - elif seq_len % self.block_size == 0: - # Last block is now full, update its hash (similar to allocate) - # TODO: fix hash - token_ids = seq.block(seq.num_blocks - 1) - if len(token_ids) == self.block_size: - prefix = ( - self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1 - ) - h = self.compute_hash(token_ids, prefix) - last_block.update(h, token_ids) - self.hash_to_block_id[h] = last_block.block_id - else: - pass - # Last block is not full and not at the boundary - # Hash remains -1 until block is full (consistent with allocate logic) - # assert last_block.hash == -1, last_block.block_id + # Decode-generated blocks: token not finalized yet (depends on + # sampling / speculative verification), so we cannot compute a + # correct hash here. Just allocate the block without hashing. + block_id = self._pop_free_block() + self._allocate_block(block_id) + block_table.append(block_id) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 4c3f5a12e..6f30de71b 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -107,6 +107,70 @@ def _log(self) -> None: ) +class CacheStats: + """Tracks prefix caching hit statistics.""" + + __slots__ = ( + "_log_interval", + "total_requests", + "total_cached_tokens", + "total_full_tokens", + "_interval_requests", + "_interval_cached_tokens", + "_interval_full_tokens", + ) + + def __init__(self, log_interval: int = 100): + self._log_interval = log_interval + self.total_requests: int = 0 + self.total_cached_tokens: int = 0 + self.total_full_tokens: int = 0 + self._interval_requests: int = 0 + self._interval_cached_tokens: int = 0 + self._interval_full_tokens: int = 0 + + def update(self, num_cached_tokens: int, num_full_tokens: int) -> None: + """Record cache stats for one prefill sequence.""" + self.total_requests += 1 + self.total_cached_tokens += num_cached_tokens + self.total_full_tokens += num_full_tokens + self._interval_requests += 1 + self._interval_cached_tokens += num_cached_tokens + self._interval_full_tokens += num_full_tokens + + if self.total_requests % self._log_interval == 0: + self._log() + self._reset_interval() + + @property + def hit_rate(self) -> float: + if self.total_full_tokens == 0: + return 0.0 + return self.total_cached_tokens / self.total_full_tokens + + def _reset_interval(self) -> None: + self._interval_requests = 0 + self._interval_cached_tokens = 0 + self._interval_full_tokens = 0 + + def _log(self) -> None: + iv_rate = ( + self._interval_cached_tokens / self._interval_full_tokens + if self._interval_full_tokens > 0 + else 0.0 + ) + logger.info( + f"[Cache Stats Interval] Reqs: {self._interval_requests}, " + f"Cached/Total tokens: {self._interval_cached_tokens}/{self._interval_full_tokens}, " + f"Hit rate: {iv_rate:.2%}" + ) + logger.info( + f"[Cache Stats ] Reqs: {self.total_requests}, " + f"Cached/Total tokens: {self.total_cached_tokens}/{self.total_full_tokens}, " + f"Hit rate: {self.hit_rate:.2%}" + ) + + class ScheduledBatch: def __init__( self, @@ -246,6 +310,9 @@ def __init__(self, config: Config): self.spec_stats: Optional[SpecStats] = ( SpecStats(mtp_k=self.mtp_k) if self.use_spec else None ) + self.cache_stats: Optional[CacheStats] = ( + CacheStats() if config.enable_prefix_caching else None + ) def is_finished(self): return not self.waiting and not self.running @@ -286,6 +353,8 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: # Recalculate after allocation: prefix caching may have updated # seq.num_cached_tokens, reducing the actual number of new tokens. num_new_tokens = seq.num_tokens - seq.num_cached_tokens + if self.cache_stats: + self.cache_stats.update(seq.num_cached_tokens, seq.num_tokens) num_batched_tokens += num_new_tokens seq.status = SequenceStatus.RUNNING seq.type = SequenceType.PREFILL @@ -319,7 +388,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: num_seqs_decode = 0 while self.running and num_seqs_decode < self.max_num_seqs: seq = self.running.popleft() - while not self.block_manager.can_append(seq): + while not self.block_manager.can_append(seq, self.mtp_k + 1): if self.running: self.preempt(self.running.pop()) else: diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 9d6fe94d1..435b88ad6 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -622,7 +622,7 @@ def forward_impl_server_mode( kv_cache_data = forward_context.kv_cache_data kv_cache = kv_cache_data[f"layer_{self.layer_num}"].k_cache - if context.is_prefill and not use_prefill_mla: + if context.is_prefill and not use_prefill_mla and not attn_metadata.has_cached: prefill_q = self.q_proj(q, x_scale=q_scale).view( -1, self.num_heads, self.qk_head_dim ) diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index b51239fca..0643b25d4 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -300,12 +300,15 @@ def prepare_prefill(self, batch: ScheduledBatch): sum_scheduled_tokens + 1 ) - if hasattr(self.model_runner, "drafter"): + if hasattr(self.model_runner, "drafter") or attn_metadata.has_cached: attn_metadata.kv_indices = var["kv_indices"].gpu attn_metadata.kv_indptr = var["kv_indptr"].gpu[: bs + 1] + attn_metadata.kv_indptr[0] = 0 attn_metadata.kv_indptr[1 : bs + 1] = torch.cumsum( attn_metadata.context_lens, 0 ) + attn_metadata.kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs] + if attn_metadata.block_tables is None: self.prepare_block_tables(batch) attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs) diff --git a/atom/model_ops/attentions/backends.py b/atom/model_ops/attentions/backends.py index a2226f612..91e6f0af3 100644 --- a/atom/model_ops/attentions/backends.py +++ b/atom/model_ops/attentions/backends.py @@ -1,11 +1,14 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import logging from abc import ABC, abstractmethod from typing import Any, Dict, Generic, Optional, Type, TypeVar import torch from atom.model_engine.scheduler import ScheduledBatch + +logger = logging.getLogger("atom") from atom.model_ops.attention_mla import MLAModules from atom.utils import CpuGpuBuffer from atom.utils.block_convert import block_table_convert_triton @@ -141,18 +144,23 @@ def prepare_prefill(self, batch: ScheduledBatch): sum_scheduled_tokens = batch.total_tokens_num_prefill var = self.model_runner.forward_vars positions = [] + cu_seqlens_q = [0] cu_seqlens_k = [0] max_seqlen_q = 0 max_seqlen_k = 0 slot_mapping = [] + has_cached = False # seqs = list(batch.seqs.values()) # seqs = seqs[:bs] for i in range(bs): seqlen = batch.context_lens[i] cached_seqlen = batch.num_cached_tokens[i] + if cached_seqlen > 0: + has_cached = True positions.extend(list(range(cached_seqlen, seqlen))) seqlen_q = seqlen - cached_seqlen seqlen_k = seqlen + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) max_seqlen_q = max(seqlen_q, max_seqlen_q) max_seqlen_k = max(seqlen_k, max_seqlen_k) @@ -166,18 +174,30 @@ def prepare_prefill(self, batch: ScheduledBatch): ) // self.model_runner.block_size last_block_tokens = batch.last_block_num_tokens[i] block_table = batch.block_tables[i] - for i in range(num_cached_blocks, num_blocks): - start = block_table[i] * self.model_runner.block_size - if i != num_blocks - 1: + for blk_idx in range(num_cached_blocks, num_blocks): + start = block_table[blk_idx] * self.model_runner.block_size + if blk_idx != num_blocks - 1: end = start + self.model_runner.block_size else: end = start + last_block_tokens slot_mapping.extend(list(range(start, end))) - if cu_seqlens_k[-1] > batch.total_tokens_num: # prefix cache + if has_cached: self.prepare_block_tables(batch) + # Validate metadata consistency + assert ( + len(positions) == sum_scheduled_tokens + ), f"positions length {len(positions)} != sum_scheduled_tokens {sum_scheduled_tokens}" + if batch.block_tables: + assert ( + len(slot_mapping) == sum_scheduled_tokens + ), f"slot_mapping length {len(slot_mapping)} != sum_scheduled_tokens {sum_scheduled_tokens}" + assert ( + cu_seqlens_q[-1] == sum_scheduled_tokens + ), f"cu_seqlens_q[-1]={cu_seqlens_q[-1]} != sum_scheduled_tokens={sum_scheduled_tokens}" var["positions"].np[:sum_scheduled_tokens] = positions var["slot_mapping"].np[:sum_scheduled_tokens] = -1 var["slot_mapping"].np[: len(slot_mapping)] = slot_mapping + var["cu_seqlens_q"].np[: bs + 1] = cu_seqlens_q cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True) var["context_lens"].np[:bs] = batch.context_lens[:bs] min_seqlen_q = 0 @@ -187,6 +207,8 @@ def prepare_prefill(self, batch: ScheduledBatch): ("slot_mapping", sum_scheduled_tokens), ("context_lens", bs), ] + if has_cached: + vars_used.append(("block_tables", bs)) ctx = {el: var[el].copy_to_gpu(num) for el, num in vars_used} if self.block_ratio > 1 and "block_tables" in ctx: @@ -203,6 +225,7 @@ def prepare_prefill(self, batch: ScheduledBatch): max_seqlen_k=max_seqlen_k, min_seqlen_q=min_seqlen_q, dropout_p=dropout_p, + has_cached=has_cached, **ctx, ) positions = var["positions"].copy_to_gpu(sum_scheduled_tokens) diff --git a/atom/utils/forward_context.py b/atom/utils/forward_context.py index 9b374620f..8f46d6689 100644 --- a/atom/utils/forward_context.py +++ b/atom/utils/forward_context.py @@ -188,6 +188,7 @@ class AttentionMetaData: reduce_partial_map: Optional[torch.Tensor] = None block_tables_converted: Optional[torch.Tensor] = None + has_cached: bool = False # only used for plugin mode to store the metadata for attn plugin_metadata: Optional["MetadataForPluginMode"] = None @@ -219,7 +220,9 @@ def __init__( sparse_cu_seqlens_q: Optional[torch.Tensor] = None, token_to_seq_idxs: Optional[torch.Tensor] = None, plugin_metadata: Optional["MetadataForPluginMode"] = None, + has_cached: bool = False, ): + self.has_cached = has_cached self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k self.max_seqlen_q = max_seqlen_q diff --git a/tests/test_block_manager.py b/tests/test_block_manager.py index 9dfb5d484..60a6c1061 100644 --- a/tests/test_block_manager.py +++ b/tests/test_block_manager.py @@ -173,3 +173,175 @@ def test_block_size_1(self, seq_factory): seq.append_token(3) bm.may_append(seq) assert len(seq.block_table) == 3 + + +# ── Prefix caching: can_allocate with cache hits ───────────────────────── + + +class TestCanAllocateWithPrefixCaching: + def test_can_allocate_accounts_for_cache_hits(self, seq_factory): + """With 3 blocks total, allocate 2-block seq, deallocate, then a new + 2-block seq sharing block 1 should need only 1 free block.""" + cfg = MockConfig( + num_kvcache_blocks=3, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + bm.deallocate(s1) # blocks freed, hashes retained + + # Use up 2 of the 3 free blocks + filler = seq_factory([50, 51, 52, 53, 60, 61, 62, 63]) + bm.allocate(filler) + # Only 1 free block left; s2 needs 2 blocks but first is cached + s2 = seq_factory([1, 2, 3, 4, 9, 10, 11, 12]) + assert bm.can_allocate(s2) + + def test_can_allocate_no_false_positive(self, seq_factory): + """can_allocate should return False when even with cache hits + there aren't enough free blocks.""" + cfg = MockConfig( + num_kvcache_blocks=2, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + # 0 free blocks; new seq shares prefix but needs 1 new block + s2 = seq_factory([1, 2, 3, 4, 9, 10, 11, 12]) + assert not bm.can_allocate(s2) + + +# ── Hash table cleanup ─────────────────────────────────────────────────── + + +class TestHashTableCleanup: + def test_stale_hash_entries_evicted_on_reuse(self, seq_factory): + """When a cached block is reused for a different hash, the old + hash_to_block_id entry should be cleaned up.""" + cfg = MockConfig( + num_kvcache_blocks=2, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + h1 = bm.blocks[s1.block_table[0]].hash + bm.deallocate(s1) + + # Allocate with completely different tokens — should overwrite blocks + s2 = seq_factory([90, 91, 92, 93, 94, 95, 96, 97]) + bm.allocate(s2) + # Old hash should no longer point to a valid block + assert bm.hash_to_block_id.get(h1) != s2.block_table[0] + + def test_hash_table_bounded_growth(self, seq_factory): + """hash_to_block_id should not grow beyond num_kvcache_blocks.""" + cfg = MockConfig( + num_kvcache_blocks=4, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + for i in range(20): + tokens = list(range(i * 4, i * 4 + 4)) + seq = seq_factory(tokens) + if bm.can_allocate(seq): + bm.allocate(seq) + bm.deallocate(seq) + assert len(bm.hash_to_block_id) <= cfg.num_kvcache_blocks + + +# ── can_append with multi-token decode (speculative decoding) ──────────── + + +class TestCanAppendMultiToken: + def test_can_append_multi_token_within_block(self, block_manager, seq_factory): + """Appending 3 tokens that stay within the current block.""" + seq = seq_factory([1]) + block_manager.allocate(seq) + seq.append_token(2) + seq.append_token(3) + assert block_manager.can_append(seq, num_new_tokens=3) + + def test_can_append_multi_token_crossing_boundary(self, seq_factory): + """block_size=4, seq_len=14 (3.5 blocks=4 blocks allocated), + appending 5 tokens crosses into block 5 — needs 1 new block.""" + cfg = MockConfig(num_kvcache_blocks=6, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory(list(range(14))) + bm.allocate(seq) + # seq_len=14, 4 blocks. Appending 5 tokens: positions 14..18 → need block 5 + for t in range(14, 19): + seq.append_token(t) + assert bm.can_append(seq, num_new_tokens=5) + + def test_cannot_append_multi_token_no_free(self, seq_factory): + """block_size=4, 4 blocks total, seq fills 4 blocks (16 tokens), + appending 5 tokens needs 2 new blocks but only 0 free.""" + cfg = MockConfig(num_kvcache_blocks=4, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory(list(range(14))) + bm.allocate(seq) + for t in range(14, 19): + seq.append_token(t) + assert not bm.can_append(seq, num_new_tokens=5) + + +# ── Prefix caching + preemption ────────────────────────────────────────── + + +class TestPrefixCachingPreemption: + def test_preempt_and_reschedule_reuses_cache(self, seq_factory): + """Preempted sequence re-discovers cache hits on re-allocation.""" + cfg = MockConfig( + num_kvcache_blocks=10, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + # Simulate preemption + bm.deallocate(s1) + assert s1.num_cached_tokens == 0 + assert s1.block_table == [] + + # Re-allocate — should get cache hits on both blocks + s1_retry = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1_retry) + assert s1_retry.num_cached_tokens == 8 # both blocks cached + + +# ── Edge cases ─────────────────────────────────────────────────────────── + + +class TestPrefixCachingEdgeCases: + def test_single_token_no_cache(self, seq_factory): + """Single token seq (shorter than block_size) — hash is -1, no caching.""" + cfg = MockConfig( + num_kvcache_blocks=4, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([42]) + bm.allocate(s1) + bm.deallocate(s1) + s2 = seq_factory([42]) + bm.allocate(s2) + # Partial block → hash is -1 → no caching + assert s2.num_cached_tokens == 0 + + def test_exact_block_size_fully_cached(self, seq_factory): + """Sequence with exactly block_size tokens — fully cached on reuse.""" + cfg = MockConfig( + num_kvcache_blocks=4, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4]) + bm.allocate(s1) + bm.deallocate(s1) + s2 = seq_factory([1, 2, 3, 4]) + bm.allocate(s2) + assert s2.num_cached_tokens == 4 + + def test_free_block_ids_set_consistent(self, block_manager, seq_factory): + """free_block_ids_set stays consistent through allocate/deallocate.""" + s1 = seq_factory([1, 2, 3, 4]) + block_manager.allocate(s1) + initial_free = len(block_manager.free_block_ids_set) + block_manager.deallocate(s1) + assert len(block_manager.free_block_ids_set) == initial_free + 1 diff --git a/tests/test_prefix_cache_accuracy.py b/tests/test_prefix_cache_accuracy.py new file mode 100644 index 000000000..e10d97cf2 --- /dev/null +++ b/tests/test_prefix_cache_accuracy.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +""" +Test prefix caching accuracy with high cache-hit workloads. + +Sends batches of requests that share long common prefixes, then verifies: +1. Responses are correct (math problems with known answers) +2. Cache hit rate is high (visible in server logs) +3. Repeated identical requests produce consistent results +""" + +import argparse +import concurrent.futures +import json +import re +import sys +import time + +import requests + +BASE_URL = "http://localhost:8000" + +# Long shared prefix: 5-shot math examples (~2000 tokens) +MATH_PREFIX = """You are a precise math assistant. Solve each problem step by step and give the final numerical answer after ####. + +Question: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today? +Answer: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. #### 6 + +Question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? +Answer: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. #### 5 + +Question: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? +Answer: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. #### 39 + +Question: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny? +Answer: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. #### 8 + +Question: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now? +Answer: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 2 + 2 = 4 more toys. 5 + 4 = 9. #### 9 + +""" + +# Test questions with known answers +TEST_QUESTIONS = [ + ("A farmer has 17 sheep. He buys 5 more and then sells 8. How many sheep does he have?", 14), + ("A store had 45 apples. They sold 12 in the morning and 18 in the afternoon. How many apples are left?", 15), + ("Tom has 8 marbles. Jerry has 3 times as many. How many marbles do they have together?", 32), + ("A classroom has 6 rows of desks with 5 desks in each row. If 7 desks are removed, how many remain?", 23), + ("Sarah baked 24 cookies. She gave 1/3 to her neighbor and ate 4 herself. How many cookies does she have left?", 12), + ("A train travels 60 miles per hour for 3 hours, then 40 miles per hour for 2 hours. What is the total distance?", 260), + ("Mike has 50 dollars. He spends 15 dollars on lunch and 20 dollars on a book. How much money does he have left?", 15), + ("A garden has 9 rose bushes. Each bush has 12 roses. If 25 roses are picked, how many roses remain?", 83), + ("Lisa read 35 pages on Monday and twice as many on Tuesday. How many pages did she read in total?", 105), + ("A box contains 100 balls. 40 are red, 35 are blue, and the rest are green. How many green balls are there?", 25), +] + + +def extract_answer(text: str): + """Extract numerical answer after the FIRST #### marker.""" + match = re.search(r"####\s*(-?\d+(?:\.\d+)?)", text) + if match: + return float(match.group(1)) + return None + + +def get_model_name(base_url: str) -> str: + """Get the model name from the server.""" + resp = requests.get(f"{base_url}/v1/models", timeout=5) + resp.raise_for_status() + return resp.json()["data"][0]["id"] + + +def send_completion(prompt: str, max_tokens: int = 256, base_url: str = BASE_URL, model: str = "") -> str: + """Send a completion request to the server.""" + resp = requests.post( + f"{base_url}/v1/completions", + json={ + "model": model, + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": 0.0, + }, + timeout=120, + ) + resp.raise_for_status() + return resp.json()["choices"][0]["text"] + + +def run_batch(questions, prefix, base_url=BASE_URL, model="", label=""): + """Run a batch of questions and return (correct, total, results).""" + results = [] + + def ask(q_and_a): + question, expected = q_and_a + prompt = prefix + f"Question: {question}\nAnswer:" + try: + response = send_completion(prompt, base_url=base_url, model=model) + answer = extract_answer(response) + correct = answer is not None and abs(answer - expected) < 0.01 + return (question, expected, answer, correct, response.strip()) + except Exception as e: + return (question, expected, None, False, f"ERROR: {e}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool: + results = list(pool.map(ask, questions)) + + num_correct = sum(1 for r in results if r[3]) + return num_correct, len(results), results + + +def main(): + parser = argparse.ArgumentParser(description="Test prefix cache accuracy") + parser.add_argument("--rounds", type=int, default=3, help="Number of rounds to repeat") + parser.add_argument("--base-url", type=str, default=BASE_URL) + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + base_url = args.base_url + + # Health check + try: + r = requests.get(f"{base_url}/health", timeout=5) + r.raise_for_status() + except Exception as e: + print(f"Server not reachable at {base_url}: {e}") + sys.exit(1) + + model = get_model_name(base_url) + print(f"=== Prefix Cache Accuracy Test ===") + print(f"Server: {base_url}") + print(f"Model: {model}") + print(f"Questions per round: {len(TEST_QUESTIONS)}") + print(f"Rounds: {args.rounds}") + print(f"Shared prefix length: ~{len(MATH_PREFIX)} chars") + print() + + all_round_results = [] + + for round_num in range(1, args.rounds + 1): + t0 = time.time() + correct, total, results = run_batch(TEST_QUESTIONS, MATH_PREFIX, base_url=base_url, model=model, label=f"Round {round_num}") + elapsed = time.time() - t0 + accuracy = 100.0 * correct / total + all_round_results.append((correct, total, accuracy, elapsed)) + + print(f"Round {round_num}: {correct}/{total} correct ({accuracy:.1f}%) in {elapsed:.1f}s") + + if args.verbose: + for q, expected, got, ok, resp in results: + status = "OK" if ok else "WRONG" + print(f" [{status}] {q[:60]}... expected={expected} got={got}") + if not ok: + # Show first 200 chars of response for debugging + print(f" response: {resp[:200]}") + print() + + print() + print("=== Summary ===") + total_correct = sum(r[0] for r in all_round_results) + total_questions = sum(r[1] for r in all_round_results) + overall_accuracy = 100.0 * total_correct / total_questions + + # Check consistency: same questions should give same answers across rounds + print(f"Overall: {total_correct}/{total_questions} ({overall_accuracy:.1f}%)") + for i, (c, t, a, e) in enumerate(all_round_results, 1): + cache_note = "(cold)" if i == 1 else "(cache warm)" + print(f" Round {i}: {c}/{t} ({a:.1f}%) {e:.1f}s {cache_note}") + + # Verify rounds 2+ should be faster (cache hits) + if args.rounds >= 2: + r1_time = all_round_results[0][3] + r2_time = all_round_results[1][3] + speedup = r1_time / r2_time if r2_time > 0 else 0 + print(f"\n Speedup round 2 vs round 1: {speedup:.2f}x") + + # Pass/fail + if overall_accuracy >= 80.0: + print(f"\nPASS: accuracy {overall_accuracy:.1f}% >= 80%") + return 0 + else: + print(f"\nFAIL: accuracy {overall_accuracy:.1f}% < 80%") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) From f5ef449567bf91d21678f59087d314e9a7ece1a1 Mon Sep 17 00:00:00 2001 From: jiayyu Date: Thu, 5 Mar 2026 17:11:49 +0800 Subject: [PATCH 02/12] wip --- atom/model_ops/attention_mla.py | 259 +++++++++++++++++++++++++++++++- 1 file changed, 258 insertions(+), 1 deletion(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 435b88ad6..c24f74170 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -69,6 +69,258 @@ def is_rocm_aiter_fp4bmm_enabled() -> bool: from atom.model_ops.utils import quark_post_load_weights +import triton +import triton.language as tl + + +@triton.jit +def _triton_gather_kv_b_proj( + batch_size, + k_buffer, # [num_block, block_size, kv_c_dim + kv_pe_dim] + k_scale, # [1] or None + kv_indptr, # [batch_size + 1] + kv_indices, # [total_kv] + kv_prefix_sum_context_lens, # [batch_size + 1] + kv_proj_weight, # [tp_k_head_num * 2 * qk_nope_head_dim, kv_c_dim] + kv_proj_scale, # [tp_k_head_num * 2 * qk_nope_head_dim // 128, kv_c_dim // 128] + k_prefix, # [total_kv, tp_k_head_num * qk_nope_head_dim + kv_pe_dim] + v_prefix, # [total_kv, tp_k_head_num * qk_nope_head_dim] + KBlockSize: tl.constexpr, + TpNumHeads: tl.constexpr, + QkNopeHeadDim: tl.constexpr, + KV_CDim: tl.constexpr, + KV_PeDim: tl.constexpr, + ChunkK: tl.constexpr, +): + stride_k_buffer: tl.constexpr = KBlockSize * (KV_CDim + KV_PeDim) + stride_k_prefix: tl.constexpr = TpNumHeads * (QkNopeHeadDim + KV_PeDim) + stride_v_prefix: tl.constexpr = TpNumHeads * QkNopeHeadDim + + ScaleKGranularity: tl.constexpr = 128 + ScaleNGranularity: tl.constexpr = 128 + KBlocksPerChunkK: tl.constexpr = ChunkK // KBlockSize + assert KV_CDim == 4 * ScaleKGranularity + + # ===--------------------------------------------------- + # Workload Partition + # ===--------------------------------------------------- + pid = tl.program_id(0) + pid_batch = pid // TpNumHeads + pid_head = pid % TpNumHeads + + kv_block_start = tl.load(kv_indptr + pid_batch) + kv_block_end = tl.load(kv_indptr + pid_batch + 1) + + context_start = tl.load(kv_prefix_sum_context_lens + pid_batch) + context_end = tl.load(kv_prefix_sum_context_lens + pid_batch + 1) + + total_kv_block = kv_block_end - kv_block_start + total_kv_chunk = (total_kv_block + KBlocksPerChunkK - 1) // KBlocksPerChunkK + + # ===--------------------------------------------------- + # Pipeline Start + # ===--------------------------------------------------- + k_type = k_buffer.dtype.element_ty + if k_type == tl.bfloat16: + k_scalar_scale = 1.0 + else: + k_scalar_scale = tl.load(k_scale) + + k_nope_weight_base_offset = ( + kv_proj_weight + + pid_head * 2 * QkNopeHeadDim * KV_CDim + + tl.arange(0, QkNopeHeadDim)[:, None] * KV_CDim + + tl.arange(0, ScaleKGranularity)[None, :] + ) + k_nope_scale_base_offset = ( + kv_proj_scale + + pid_head + * 2 + * QkNopeHeadDim + * KV_CDim + // ScaleKGranularity + // ScaleNGranularity + + tl.arange(0, QkNopeHeadDim // ScaleNGranularity) + * (KV_CDim // ScaleKGranularity) + ) + + k_nope_weight_0 = tl.load(k_nope_weight_base_offset + 0 * ScaleKGranularity).to( + k_type + ) + k_nope_weight_1 = tl.load(k_nope_weight_base_offset + 1 * ScaleKGranularity).to( + k_type + ) + k_nope_weight_2 = tl.load(k_nope_weight_base_offset + 2 * ScaleKGranularity).to( + k_type + ) + k_nope_weight_3 = tl.load(k_nope_weight_base_offset + 3 * ScaleKGranularity).to( + k_type + ) + + v_nope_weight_0 = tl.load( + k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 0 * ScaleKGranularity + ).to(k_type) + v_nope_weight_1 = tl.load( + k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 1 * ScaleKGranularity + ).to(k_type) + v_nope_weight_2 = tl.load( + k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 2 * ScaleKGranularity + ).to(k_type) + v_nope_weight_3 = tl.load( + k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 3 * ScaleKGranularity + ).to(k_type) + + k_nope_scale_0 = tl.load(k_nope_scale_base_offset + 0) + k_nope_scale_1 = tl.load(k_nope_scale_base_offset + 1) + k_nope_scale_2 = tl.load(k_nope_scale_base_offset + 2) + k_nope_scale_3 = tl.load(k_nope_scale_base_offset + 3) + + v_nope_scale_0 = tl.load( + k_nope_scale_base_offset + + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity + + 0 + ) + v_nope_scale_1 = tl.load( + k_nope_scale_base_offset + + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity + + 1 + ) + v_nope_scale_2 = tl.load( + k_nope_scale_base_offset + + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity + + 2 + ) + v_nope_scale_3 = tl.load( + k_nope_scale_base_offset + + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity + + 3 + ) + + for chunk_id in range(total_kv_chunk): + kv_block_idx = tl.load( + kv_indices + + kv_block_start + + chunk_id * KBlocksPerChunkK + + tl.arange(0, ChunkK) // KBlockSize, + mask=chunk_id * KBlocksPerChunkK + tl.arange(0, ChunkK) // KBlockSize + < total_kv_block, + ) + kv_c_data_base_offset = ( + kv_block_idx[:, None] * stride_k_buffer + + tl.arange(0, ChunkK)[:, None] % KBlockSize * (KV_CDim + KV_PeDim) + + tl.arange(0, ScaleKGranularity)[None, :] + ) # [ChunkK, kv_c_dim] + + accum_k = tl.zeros((ChunkK, QkNopeHeadDim), dtype=tl.float32) + accum_v = tl.zeros((ChunkK, QkNopeHeadDim), dtype=tl.float32) + + kv_c_data_0 = tl.load(k_buffer + kv_c_data_base_offset + 0 * ScaleKGranularity) + kv_c_data_1 = tl.load(k_buffer + kv_c_data_base_offset + 1 * ScaleKGranularity) + kv_c_data_2 = tl.load(k_buffer + kv_c_data_base_offset + 2 * ScaleKGranularity) + kv_c_data_3 = tl.load(k_buffer + kv_c_data_base_offset + 3 * ScaleKGranularity) + kv_pe_data = tl.load( + k_buffer + + kv_block_idx[:, None] * stride_k_buffer + + tl.arange(0, ChunkK)[:, None] % KBlockSize * (KV_CDim + KV_PeDim) + + KV_CDim + + tl.arange(0, KV_PeDim)[None, :], + ) + + accum_k += tl.dot(kv_c_data_0, k_nope_weight_0.T) * k_nope_scale_0 + accum_v += tl.dot(kv_c_data_0, v_nope_weight_0.T) * v_nope_scale_0 + accum_k += tl.dot(kv_c_data_1, k_nope_weight_1.T) * k_nope_scale_1 + accum_v += tl.dot(kv_c_data_1, v_nope_weight_1.T) * v_nope_scale_1 + accum_k += tl.dot(kv_c_data_2, k_nope_weight_2.T) * k_nope_scale_2 + accum_v += tl.dot(kv_c_data_2, v_nope_weight_2.T) * v_nope_scale_2 + accum_k += tl.dot(kv_c_data_3, k_nope_weight_3.T) * k_nope_scale_3 + accum_v += tl.dot(kv_c_data_3, v_nope_weight_3.T) * v_nope_scale_3 + + accum_k *= k_scalar_scale + accum_v *= k_scalar_scale + kv_pe_data *= k_scalar_scale + + context_mask = ( + context_start + chunk_id * ChunkK + tl.arange(0, ChunkK) < context_end + ) + tl.store( + k_prefix + + (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None] + * stride_k_prefix + + pid_head * (QkNopeHeadDim + KV_PeDim) + + QkNopeHeadDim + + tl.arange(0, KV_PeDim)[None, :], + kv_pe_data, + mask=context_mask[:, None], + ) + tl.store( + k_prefix + + (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None] + * stride_k_prefix + + pid_head * (QkNopeHeadDim + KV_PeDim) + + tl.arange(0, QkNopeHeadDim)[None, :], + accum_k, + mask=context_mask[:, None], + ) + tl.store( + v_prefix + + (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None] + * stride_v_prefix + + pid_head * QkNopeHeadDim + + tl.arange(0, QkNopeHeadDim)[None, :], + accum_v, + mask=context_mask[:, None], + ) + +def gather_kv_b_proj( + k_buffer: torch.Tensor, # [num_block, block_size, hidden_dim] + k_scale: torch.Tensor, # [1] + kv_indptr: torch.Tensor, # [batch_size + 1] + kv_indices: torch.Tensor, # len(kv_indices) = kv_indptr[-1] + kv_prefix_sum_context_lens: torch.Tensor, # [batch_size + 1] + kv_proj_weight: torch.Tensor, # [2 * 128 // TP * 128, 512] + kv_proj_scale: torch.Tensor, # [2 * 128 // TP, 4], blockscale=128 x 128 + k_prefix: torch.Tensor, # [total_kv, tp_k_head_num, qk_nope_head_dim + kv_pe_dim] + v_prefix: torch.Tensor, # [total_kv, tp_k_head_num, qk_nope_head_dim] +): + num_block, block_size, hidden_dim = k_buffer.shape + batch_size = kv_indptr.shape[0] - 1 + weight_n, weight_k = kv_proj_weight.shape + scale_n, scale_k = kv_proj_scale.shape + total_kv_k, tp_k_head_num_k, qk_nope_pe_dim = k_prefix.shape + total_kv_v, tp_k_head_num_v, qk_nope_dim = v_prefix.shape + + scale_k_granularity = weight_k // scale_k + scale_n_granularity = weight_n // scale_n + + ChunkK = 16 if k_buffer.dtype in [torch.float16, torch.bfloat16] else 32 + + assert total_kv_k == total_kv_v + assert tp_k_head_num_k == tp_k_head_num_v + assert scale_k_granularity == 128 + assert scale_n_granularity == 128 + assert ChunkK % block_size == 0 + + grid = (batch_size * tp_k_head_num_k,) + kernel = _triton_gather_kv_b_proj[grid]( + batch_size, + k_buffer, + k_scale, + kv_indptr, + kv_indices, + kv_prefix_sum_context_lens, + kv_proj_weight, + kv_proj_scale, + k_prefix, + v_prefix, + KBlockSize=block_size, + TpNumHeads=tp_k_head_num_k, + QkNopeHeadDim=qk_nope_dim, + KV_CDim=weight_k, + KV_PeDim=qk_nope_pe_dim - qk_nope_dim, + ChunkK=ChunkK, + num_stages=3, + ) + # MLA Specific Arguments @dataclass class MLAModules: @@ -622,7 +874,12 @@ def forward_impl_server_mode( kv_cache_data = forward_context.kv_cache_data kv_cache = kv_cache_data[f"layer_{self.layer_num}"].k_cache - if context.is_prefill and not use_prefill_mla and not attn_metadata.has_cached: + if context.is_prefill and not use_prefill_mla: + use_prefix_cache = ( + attn_metadata.has_cached + and not is_rocm_aiter_fp4bmm_enabled() + and self.qk_nope_head_dim == self.v_head_dim + ) prefill_q = self.q_proj(q, x_scale=q_scale).view( -1, self.num_heads, self.qk_head_dim ) From b11d9c54c260bcb10712044d9f63484b7bc4c11a Mon Sep 17 00:00:00 2001 From: jiayyu Date: Thu, 5 Mar 2026 19:35:49 +0800 Subject: [PATCH 03/12] support mla prefix cache --- atom/model_ops/attention_mla.py | 224 +++++++++++++++++++++++++++++++- 1 file changed, 221 insertions(+), 3 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index c24f74170..efa867744 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -686,6 +686,133 @@ def _forward_prefill_mha( return self.o_proj(output.flatten(start_dim=-2)) + def _forward_prefill_mha_prefix_cache( + self, + q: torch.Tensor, + k_prefix: torch.Tensor, + v_prefix: torch.Tensor, + kv_c_normed: torch.Tensor, + k_rope: torch.Tensor, + attn_metadata: AttentionMetaData, + ) -> torch.Tensor: + """Prefill MHA with prefix cache (doc §4.3, §4.4): project new tokens, + concat [cached | new] per batch, flash attention.""" + assert attn_metadata is not None and attn_metadata.has_cached + + if k_rope.dim() == 2: + k_rope = k_rope.unsqueeze(1) + + # Step 4.3: Project new tokens' latent to k_new, v_new + if use_triton_gemm(): + weight = self.kv_b_proj.weight + weight_scale = self.kv_b_proj.weight_scale + if ( + fused_gemm_a8w8_blockscale_preshuffle_split_cat is not None + and weight.dtype == dtypes.fp8 + ): + weight_shuffled = weight.reshape( + weight.shape[0] // 16, weight.shape[1] * 16 + ) + output_dtype = kv_c_normed.dtype + quant_func = functools_partial( + get_hip_quant(QuantType.per_1x128), transpose_scale=True + ) + q_input, x_scale = quant_func( + kv_c_normed, + quant_dtype=dtypes.fp8, + scale=getattr(self.kv_b_proj, "input_scale", None), + ) + k_nope_new, v_new = fused_gemm_a8w8_blockscale_preshuffle_split_cat( + q_input, + weight_shuffled, + k_rope.expand((-1, self.num_heads, -1)), + x_scale, + weight_scale, + self.qk_nope_head_dim, + self.v_head_dim, + output_dtype, + ) + else: + kv_nope = self.kv_b_proj(kv_c_normed).view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope_new, v_new = kv_nope.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + k_new = torch.cat( + (k_nope_new, k_rope.expand((*k_nope_new.shape[:-1], -1))), dim=-1 + ) + else: + kv_nope = self.kv_b_proj(kv_c_normed).view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope_new, v_new = kv_nope.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + k_new = torch.cat( + (k_nope_new, k_rope.expand((*k_nope_new.shape[:-1], -1))), dim=-1 + ) + + # Step 4.4: Concat [cached | new] per batch, flash attention + bs = attn_metadata.cu_seqlens_q.shape[0] - 1 + cu_seqlens_q = attn_metadata.cu_seqlens_q + cu_seqlens_k = attn_metadata.cu_seqlens_k + num_cached_tokens = attn_metadata.num_cached_tokens + + total_tokens = cu_seqlens_k[-1].item() + output_dtype = q.dtype + k_full = torch.empty( + (total_tokens, self.num_heads, self.qk_head_dim), + dtype=output_dtype, + device=q.device, + ) + v_full = torch.empty( + (total_tokens, self.num_heads, self.v_head_dim), + dtype=output_dtype, + device=q.device, + ) + + cached_offset = 0 + for i in range(bs): + start = cu_seqlens_k[i].item() + end = cu_seqlens_k[i + 1].item() + cached_i = num_cached_tokens[i].item() + new_i = end - start - cached_i + + # k_prefix/v_prefix are 2D [total_cached, num_heads*head_dim], reshape to 3D + k_full[start : start + cached_i] = ( + k_prefix[cached_offset : cached_offset + cached_i] + .view(cached_i, self.num_heads, self.qk_head_dim) + .to(output_dtype) + ) + v_full[start : start + cached_i] = ( + v_prefix[cached_offset : cached_offset + cached_i] + .view(cached_i, self.num_heads, self.v_head_dim) + .to(output_dtype) + ) + + new_start = cu_seqlens_q[i].item() + k_full[start + cached_i : end] = k_new[new_start : new_start + new_i] + v_full[start + cached_i : end] = v_new[new_start : new_start + new_i] + + cached_offset += cached_i + + output = flash_attn_varlen_func( + q=q, + k=k_full, + v=v_full, + cu_seqlens_q=attn_metadata.cu_seqlens_q, + cu_seqlens_k=attn_metadata.cu_seqlens_k, + max_seqlen_q=attn_metadata.max_seqlen_q, + max_seqlen_k=attn_metadata.max_seqlen_k, + min_seqlen_q=attn_metadata.min_seqlen_q, + dropout_p=attn_metadata.dropout_p, + softmax_scale=self.scale, + causal=True, + ) + + return self.o_proj(output.flatten(start_dim=-2)) + def _forward_prefill_mla( self, q: torch.Tensor, @@ -880,6 +1007,86 @@ def forward_impl_server_mode( and not is_rocm_aiter_fp4bmm_enabled() and self.qk_nope_head_dim == self.v_head_dim ) + + # Step 4.1: Build cached-only metadata (doc §4.1) + if use_prefix_cache: + num_cached_tokens = attn_metadata.num_cached_tokens + bs = num_cached_tokens.shape[0] + total_cached = num_cached_tokens.sum().item() + block_size = kv_cache.shape[1] + num_cached_blocks = ( + num_cached_tokens.to(q.device) + block_size - 1 + ) // block_size + + kv_indptr_cached = torch.zeros( + bs + 1, dtype=torch.int32, device=q.device + ) + kv_indptr_cached[1:] = torch.cumsum(num_cached_blocks, dim=0) + kv_prefix_sum_context_lens = torch.zeros( + bs + 1, dtype=torch.int32, device=q.device + ) + kv_prefix_sum_context_lens[1:] = torch.cumsum( + num_cached_tokens.to(q.device), dim=0 + ) + + kv_indices_cached = torch.empty( + kv_indptr_cached[-1].item(), + dtype=torch.int32, + device=q.device, + ) + block_tables = attn_metadata.block_tables + for i in range(bs): + n = num_cached_blocks[i].item() + if n > 0: + kv_indices_cached[ + kv_indptr_cached[i].item() : kv_indptr_cached[ + i + 1 + ].item() + ] = block_tables[i, :n] + + # Step 4.2: gather_kv_b_proj - Gather + Project (doc §4.2) + k_prefix = torch.zeros( + ( + total_cached, + self.num_heads + * (self.qk_nope_head_dim + self.qk_rope_head_dim), + ), + device=q.device, + dtype=( + dtypes.fp8 + if self.kv_cache_dtype.startswith("fp8") + else self.dtype + ), + ) + v_prefix = torch.zeros( + ( + total_cached, + self.num_heads * self.qk_nope_head_dim, + ), + device=q.device, + dtype=( + dtypes.fp8 + if self.kv_cache_dtype.startswith("fp8") + else self.dtype + ), + ) + + gather_kv_b_proj( + kv_cache, + self._k_scale, + kv_indptr_cached, + kv_indices_cached, + kv_prefix_sum_context_lens, + self.kv_b_proj.weight, + self.kv_b_proj.weight_scale, + k_prefix.view( + -1, + self.num_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + ), + v_prefix.view(-1, self.num_heads, self.qk_nope_head_dim), + ) + prefill_q = self.q_proj(q, x_scale=q_scale).view( -1, self.num_heads, self.qk_head_dim ) @@ -896,9 +1103,20 @@ def forward_impl_server_mode( scale=self._k_scale, ) - output = self._forward_prefill_mha( - prefill_q, k_nope, k_rope, kv_cache, attn_metadata - ) + # Step 4.3-4.4: Concat cached + new, flash attention (doc §4.3, §4.4) + if use_prefix_cache: + output = self._forward_prefill_mha_prefix_cache( + prefill_q, + k_prefix, + v_prefix, + k_nope, + k_rope, + attn_metadata, + ) + else: + output = self._forward_prefill_mha( + prefill_q, k_nope, k_rope, kv_cache, attn_metadata + ) else: q_nope, q_rope = self._q_proj_and_k_up_proj(q, x_scale=q_scale) From 0bee8dfa4e888cb327e656b40372d21f01fc494f Mon Sep 17 00:00:00 2001 From: jiayyu Date: Fri, 6 Mar 2026 17:26:53 +0800 Subject: [PATCH 04/12] mha --- atom/model_ops/attention_mha.py | 142 +++++++++++++++++++++++++- atom/model_ops/attentions/backends.py | 11 ++ atom/model_ops/base_attention.py | 17 +-- atom/utils/forward_context.py | 3 + 4 files changed, 165 insertions(+), 8 deletions(-) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index b38ca0ee3..c2a5d75db 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -15,8 +15,13 @@ from .attention_mla import MLAModules +import logging + from atom.plugin.prepare import is_plugin_mode, is_vllm from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode +from atom.model_ops.base_attention import cp_mha_gather_cache + +logger = logging.getLogger("atom") @PagedAttentionImplDecoratorForPluginMode @@ -127,6 +132,16 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): and self.q_norm is not None and self.k_norm is not None ): + # fused_qk_norm_rope_cache_quant_shuffle expects V cache layout + # [num_blocks, num_kv_heads, block_size//x, head_size, x], not [n, nh, hd, bs] + x = 16 // k_cache.element_size() + if k_cache.dim() == 5 and v_cache.dim() == 4: + n, nh, hd, bs = v_cache.shape + v_cache_shuffle = v_cache.view( + n, nh, bs // x, hd, x + ) + else: + v_cache_shuffle = v_cache fused_qk_norm_rope_cache_quant_shuffle( qkv, num_heads_q=self.num_heads, @@ -140,7 +155,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): is_neox_style=self.rotary_emb.is_neox_style, pos_ids=position, k_cache=k_cache, - v_cache=v_cache, + v_cache=v_cache_shuffle, slot_mapping=attn_metadata.slot_mapping, kv_cache_dtype=( "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype @@ -212,8 +227,133 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): asm_layout=asm_layout, ) + # Prefix cache hit: gather cached KV from paged cache and concat with new tokens + if attn_metadata.has_cached: + q, k, v, k_cache, v_cache, k_scale, v_scale = self._gather_prefix_and_concat_kv( + q, k, v, k_cache, v_cache, k_scale, v_scale, attn_metadata + ) + return q, k, v, k_cache, v_cache, k_scale, v_scale + def _gather_prefix_and_concat_kv( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + attn_metadata, + ): + """ + When prefix cache hits, gather cached KV from paged cache and concat with + new tokens' k, v for full-sequence attention (e.g. flash_attn). + """ + bs = attn_metadata.cu_seqlens_q.shape[0] - 1 + cu_seqlens_q = attn_metadata.cu_seqlens_q + cu_seqlens_k = attn_metadata.cu_seqlens_k + cached_seqlen = attn_metadata.num_cached_tokens # [bs], from prepare_prefill + + total_cached = cached_seqlen.sum().item() + + num_kv_heads = k.shape[1] + head_dim = k.shape[2] + device = k.device + dtype = k.dtype + + # Build metadata for cp_mha_gather_cache (cached tokens only) + # token_to_batch: [0]*cached[0] + [1]*cached[1] + ... + # cu_seqlens_kv: [0, cached[0], cached[0]+cached[1], ...] + # seq_starts: [0]*bs (prefix starts at position 0) + token_to_batch = torch.zeros( + total_cached, dtype=torch.int32, device=device + ) + cu_seqlens_kv = [0] + for i in range(bs): + c = cached_seqlen[i].item() + token_to_batch[cu_seqlens_kv[-1] : cu_seqlens_kv[-1] + c] = i + cu_seqlens_kv.append(cu_seqlens_kv[-1] + c) + cu_seqlens_kv = torch.tensor(cu_seqlens_kv, dtype=torch.int32, device=device) + seq_starts = torch.zeros(bs, dtype=torch.int32, device=device) + + k_prefix = torch.empty( + (total_cached, num_kv_heads, head_dim), dtype=dtype, device=device + ) + v_prefix = torch.empty( + (total_cached, num_kv_heads, head_dim), dtype=dtype, device=device + ) + + # Convert cache for cp_mha_gather_cache + # fused_qk_norm_rope_cache_quant_shuffle: K [n, nh, hd//x, bs, x], V [n, nh, bs//x, hd, x] (SHUFFLE) + # fused_qk_rope_reshape_and_cache: K [n, nh, hd//x, bs, x], V [n, nh, hd, bs] -> NHD + if k_cache.dim() == 5: + x = 16 // k_cache.element_size() + n, nh, _, block_size, _ = k_cache.shape + if v_cache.dim() == 4: + # fused_qk_norm_rope_cache_quant_shuffle: V data in [n, nh, bs//x, hd, x] layout + use_shuffle = True + k_cache_gather = k_cache + v_cache_gather = v_cache.view(n, nh, block_size // x, head_dim, x) + else: + # fused_qk_rope_reshape_and_cache: V [n, nh, hd, bs] -> NHD + use_shuffle = False + k_cache_gather = ( + k_cache.permute(0, 3, 1, 2, 4) + .contiguous() + .view(n, block_size, nh, head_dim) + ) + v_cache_gather = v_cache.permute(0, 3, 1, 2).contiguous() + else: + use_shuffle = False + k_cache_gather = k_cache + v_cache_gather = v_cache + block_size = k_cache.shape[1] + + block_tables = attn_metadata.block_tables + cp_mha_gather_cache( + key_cache=k_cache_gather, + value_cache=v_cache_gather, + key=k_prefix, + value=v_prefix, + block_tables=block_tables, + k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=cu_seqlens_kv, + token_to_batch=token_to_batch, + seq_starts=seq_starts, + dequant=self.kv_cache_dtype.startswith("fp8"), + kv_cache_layout="SHUFFLE" if use_shuffle else "NHD", + total_tokens=total_cached, + ) + + # Build full k, v: for each batch i, [cached_i | new_i] + total_tokens = cu_seqlens_k[-1].item() + k_full = torch.empty( + (total_tokens, num_kv_heads, head_dim), dtype=dtype, device=device + ) + v_full = torch.empty( + (total_tokens, num_kv_heads, head_dim), dtype=dtype, device=device + ) + + cached_offset = 0 + for i in range(bs): + start = cu_seqlens_k[i].item() + end = cu_seqlens_k[i + 1].item() + cached_i = cached_seqlen[i].item() + new_i = end - start - cached_i + + k_full[start : start + cached_i] = k_prefix[cached_offset : cached_offset + cached_i] + v_full[start : start + cached_i] = v_prefix[cached_offset : cached_offset + cached_i] + + new_start = cu_seqlens_q[i].item() + k_full[start + cached_i : end] = k[new_start : new_start + new_i] + v_full[start + cached_i : end] = v[new_start : new_start + new_i] + + cached_offset += cached_i + + return q, k_full, v_full, k_cache, v_cache, k_scale, v_scale + def paged_attention_triton( self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext ): diff --git a/atom/model_ops/attentions/backends.py b/atom/model_ops/attentions/backends.py index 91e6f0af3..7bd1b675c 100644 --- a/atom/model_ops/attentions/backends.py +++ b/atom/model_ops/attentions/backends.py @@ -219,6 +219,16 @@ def prepare_prefill(self, batch: ScheduledBatch): self.block_ratio, ) ctx["block_tables_converted"] = var["block_tables_converted"].gpu[:bs] + num_cached_tokens = None + if has_cached: + num_cached_tokens = torch.tensor( + batch.num_cached_tokens[:bs], dtype=torch.int32, pin_memory=True + ).cuda(non_blocking=True) + if self.model_runner.rank == 0: + logger.info(f"{has_cached=}") + logger.info( + f"Prefill batch has {num_cached_tokens.sum().item()} cached tokens and of {sum_scheduled_tokens} total tokens" + ) attn_metadata = AttentionMetaData( cu_seqlens_k=cu_seqlens_k.cuda(non_blocking=True), max_seqlen_q=max_seqlen_q, @@ -226,6 +236,7 @@ def prepare_prefill(self, batch: ScheduledBatch): min_seqlen_q=min_seqlen_q, dropout_p=dropout_p, has_cached=has_cached, + num_cached_tokens=num_cached_tokens, **ctx, ) positions = var["positions"].copy_to_gpu(sum_scheduled_tokens) diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index 47caa6a36..90054852d 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -152,13 +152,16 @@ def cp_mha_gather_cache( assert k_scales is not None and v_scales is not None head_dim = key.shape[2] x = 16 // key_cache.element_size() - # For k cache layout: [num_blocks, num_heads, page_size, head_dim] - assert head_dim == key_cache.shape[3], ( - "We assume your kv cache layout is [num_blocks, " - "page_size, num_heads, head_dim], but got otherwise" - ) - page_size = key_cache.shape[1] - num_heads = key_cache.shape[2] + if kv_cache_layout == "NHD": + # K: [num_blocks, page_size, num_heads, head_dim] + assert head_dim == key_cache.shape[3] + page_size = key_cache.shape[1] + num_heads = key_cache.shape[2] + else: + # SHUFFLE: K [num_blocks, num_heads, head_dim//x, page_size, x] + assert key_cache.dim() == 5 and head_dim == key_cache.shape[2] * key_cache.shape[4] + page_size = key_cache.shape[3] + num_heads = key_cache.shape[1] grid = lambda meta: (total_tokens, num_heads) # noqa: E731 cp_mha_gather_cache_kernel[grid]( diff --git a/atom/utils/forward_context.py b/atom/utils/forward_context.py index 8f46d6689..5f96a89df 100644 --- a/atom/utils/forward_context.py +++ b/atom/utils/forward_context.py @@ -189,6 +189,7 @@ class AttentionMetaData: block_tables_converted: Optional[torch.Tensor] = None has_cached: bool = False + num_cached_tokens: Optional[torch.Tensor] = None # [bs] when has_cached # only used for plugin mode to store the metadata for attn plugin_metadata: Optional["MetadataForPluginMode"] = None @@ -221,8 +222,10 @@ def __init__( token_to_seq_idxs: Optional[torch.Tensor] = None, plugin_metadata: Optional["MetadataForPluginMode"] = None, has_cached: bool = False, + num_cached_tokens: Optional[torch.Tensor] = None, ): self.has_cached = has_cached + self.num_cached_tokens = num_cached_tokens self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k self.max_seqlen_q = max_seqlen_q From 7f5149f28486278b6de93313c60955da5f14f199 Mon Sep 17 00:00:00 2001 From: jiayyu Date: Wed, 11 Mar 2026 16:42:51 +0800 Subject: [PATCH 05/12] mla shuffled weight to be supported --- atom/model_ops/attention_mla.py | 514 +++---------------------- atom/model_ops/attentions/aiter_mla.py | 11 +- 2 files changed, 56 insertions(+), 469 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index efa867744..110c379fa 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -68,258 +68,7 @@ def is_rocm_aiter_fp4bmm_enabled() -> bool: from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 from atom.model_ops.utils import quark_post_load_weights - -import triton -import triton.language as tl - - -@triton.jit -def _triton_gather_kv_b_proj( - batch_size, - k_buffer, # [num_block, block_size, kv_c_dim + kv_pe_dim] - k_scale, # [1] or None - kv_indptr, # [batch_size + 1] - kv_indices, # [total_kv] - kv_prefix_sum_context_lens, # [batch_size + 1] - kv_proj_weight, # [tp_k_head_num * 2 * qk_nope_head_dim, kv_c_dim] - kv_proj_scale, # [tp_k_head_num * 2 * qk_nope_head_dim // 128, kv_c_dim // 128] - k_prefix, # [total_kv, tp_k_head_num * qk_nope_head_dim + kv_pe_dim] - v_prefix, # [total_kv, tp_k_head_num * qk_nope_head_dim] - KBlockSize: tl.constexpr, - TpNumHeads: tl.constexpr, - QkNopeHeadDim: tl.constexpr, - KV_CDim: tl.constexpr, - KV_PeDim: tl.constexpr, - ChunkK: tl.constexpr, -): - stride_k_buffer: tl.constexpr = KBlockSize * (KV_CDim + KV_PeDim) - stride_k_prefix: tl.constexpr = TpNumHeads * (QkNopeHeadDim + KV_PeDim) - stride_v_prefix: tl.constexpr = TpNumHeads * QkNopeHeadDim - - ScaleKGranularity: tl.constexpr = 128 - ScaleNGranularity: tl.constexpr = 128 - KBlocksPerChunkK: tl.constexpr = ChunkK // KBlockSize - assert KV_CDim == 4 * ScaleKGranularity - - # ===--------------------------------------------------- - # Workload Partition - # ===--------------------------------------------------- - pid = tl.program_id(0) - pid_batch = pid // TpNumHeads - pid_head = pid % TpNumHeads - - kv_block_start = tl.load(kv_indptr + pid_batch) - kv_block_end = tl.load(kv_indptr + pid_batch + 1) - - context_start = tl.load(kv_prefix_sum_context_lens + pid_batch) - context_end = tl.load(kv_prefix_sum_context_lens + pid_batch + 1) - - total_kv_block = kv_block_end - kv_block_start - total_kv_chunk = (total_kv_block + KBlocksPerChunkK - 1) // KBlocksPerChunkK - - # ===--------------------------------------------------- - # Pipeline Start - # ===--------------------------------------------------- - k_type = k_buffer.dtype.element_ty - if k_type == tl.bfloat16: - k_scalar_scale = 1.0 - else: - k_scalar_scale = tl.load(k_scale) - - k_nope_weight_base_offset = ( - kv_proj_weight - + pid_head * 2 * QkNopeHeadDim * KV_CDim - + tl.arange(0, QkNopeHeadDim)[:, None] * KV_CDim - + tl.arange(0, ScaleKGranularity)[None, :] - ) - k_nope_scale_base_offset = ( - kv_proj_scale - + pid_head - * 2 - * QkNopeHeadDim - * KV_CDim - // ScaleKGranularity - // ScaleNGranularity - + tl.arange(0, QkNopeHeadDim // ScaleNGranularity) - * (KV_CDim // ScaleKGranularity) - ) - - k_nope_weight_0 = tl.load(k_nope_weight_base_offset + 0 * ScaleKGranularity).to( - k_type - ) - k_nope_weight_1 = tl.load(k_nope_weight_base_offset + 1 * ScaleKGranularity).to( - k_type - ) - k_nope_weight_2 = tl.load(k_nope_weight_base_offset + 2 * ScaleKGranularity).to( - k_type - ) - k_nope_weight_3 = tl.load(k_nope_weight_base_offset + 3 * ScaleKGranularity).to( - k_type - ) - - v_nope_weight_0 = tl.load( - k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 0 * ScaleKGranularity - ).to(k_type) - v_nope_weight_1 = tl.load( - k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 1 * ScaleKGranularity - ).to(k_type) - v_nope_weight_2 = tl.load( - k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 2 * ScaleKGranularity - ).to(k_type) - v_nope_weight_3 = tl.load( - k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 3 * ScaleKGranularity - ).to(k_type) - - k_nope_scale_0 = tl.load(k_nope_scale_base_offset + 0) - k_nope_scale_1 = tl.load(k_nope_scale_base_offset + 1) - k_nope_scale_2 = tl.load(k_nope_scale_base_offset + 2) - k_nope_scale_3 = tl.load(k_nope_scale_base_offset + 3) - - v_nope_scale_0 = tl.load( - k_nope_scale_base_offset - + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity - + 0 - ) - v_nope_scale_1 = tl.load( - k_nope_scale_base_offset - + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity - + 1 - ) - v_nope_scale_2 = tl.load( - k_nope_scale_base_offset - + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity - + 2 - ) - v_nope_scale_3 = tl.load( - k_nope_scale_base_offset - + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity - + 3 - ) - - for chunk_id in range(total_kv_chunk): - kv_block_idx = tl.load( - kv_indices - + kv_block_start - + chunk_id * KBlocksPerChunkK - + tl.arange(0, ChunkK) // KBlockSize, - mask=chunk_id * KBlocksPerChunkK + tl.arange(0, ChunkK) // KBlockSize - < total_kv_block, - ) - kv_c_data_base_offset = ( - kv_block_idx[:, None] * stride_k_buffer - + tl.arange(0, ChunkK)[:, None] % KBlockSize * (KV_CDim + KV_PeDim) - + tl.arange(0, ScaleKGranularity)[None, :] - ) # [ChunkK, kv_c_dim] - - accum_k = tl.zeros((ChunkK, QkNopeHeadDim), dtype=tl.float32) - accum_v = tl.zeros((ChunkK, QkNopeHeadDim), dtype=tl.float32) - - kv_c_data_0 = tl.load(k_buffer + kv_c_data_base_offset + 0 * ScaleKGranularity) - kv_c_data_1 = tl.load(k_buffer + kv_c_data_base_offset + 1 * ScaleKGranularity) - kv_c_data_2 = tl.load(k_buffer + kv_c_data_base_offset + 2 * ScaleKGranularity) - kv_c_data_3 = tl.load(k_buffer + kv_c_data_base_offset + 3 * ScaleKGranularity) - kv_pe_data = tl.load( - k_buffer - + kv_block_idx[:, None] * stride_k_buffer - + tl.arange(0, ChunkK)[:, None] % KBlockSize * (KV_CDim + KV_PeDim) - + KV_CDim - + tl.arange(0, KV_PeDim)[None, :], - ) - - accum_k += tl.dot(kv_c_data_0, k_nope_weight_0.T) * k_nope_scale_0 - accum_v += tl.dot(kv_c_data_0, v_nope_weight_0.T) * v_nope_scale_0 - accum_k += tl.dot(kv_c_data_1, k_nope_weight_1.T) * k_nope_scale_1 - accum_v += tl.dot(kv_c_data_1, v_nope_weight_1.T) * v_nope_scale_1 - accum_k += tl.dot(kv_c_data_2, k_nope_weight_2.T) * k_nope_scale_2 - accum_v += tl.dot(kv_c_data_2, v_nope_weight_2.T) * v_nope_scale_2 - accum_k += tl.dot(kv_c_data_3, k_nope_weight_3.T) * k_nope_scale_3 - accum_v += tl.dot(kv_c_data_3, v_nope_weight_3.T) * v_nope_scale_3 - - accum_k *= k_scalar_scale - accum_v *= k_scalar_scale - kv_pe_data *= k_scalar_scale - - context_mask = ( - context_start + chunk_id * ChunkK + tl.arange(0, ChunkK) < context_end - ) - tl.store( - k_prefix - + (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None] - * stride_k_prefix - + pid_head * (QkNopeHeadDim + KV_PeDim) - + QkNopeHeadDim - + tl.arange(0, KV_PeDim)[None, :], - kv_pe_data, - mask=context_mask[:, None], - ) - tl.store( - k_prefix - + (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None] - * stride_k_prefix - + pid_head * (QkNopeHeadDim + KV_PeDim) - + tl.arange(0, QkNopeHeadDim)[None, :], - accum_k, - mask=context_mask[:, None], - ) - tl.store( - v_prefix - + (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None] - * stride_v_prefix - + pid_head * QkNopeHeadDim - + tl.arange(0, QkNopeHeadDim)[None, :], - accum_v, - mask=context_mask[:, None], - ) - -def gather_kv_b_proj( - k_buffer: torch.Tensor, # [num_block, block_size, hidden_dim] - k_scale: torch.Tensor, # [1] - kv_indptr: torch.Tensor, # [batch_size + 1] - kv_indices: torch.Tensor, # len(kv_indices) = kv_indptr[-1] - kv_prefix_sum_context_lens: torch.Tensor, # [batch_size + 1] - kv_proj_weight: torch.Tensor, # [2 * 128 // TP * 128, 512] - kv_proj_scale: torch.Tensor, # [2 * 128 // TP, 4], blockscale=128 x 128 - k_prefix: torch.Tensor, # [total_kv, tp_k_head_num, qk_nope_head_dim + kv_pe_dim] - v_prefix: torch.Tensor, # [total_kv, tp_k_head_num, qk_nope_head_dim] -): - num_block, block_size, hidden_dim = k_buffer.shape - batch_size = kv_indptr.shape[0] - 1 - weight_n, weight_k = kv_proj_weight.shape - scale_n, scale_k = kv_proj_scale.shape - total_kv_k, tp_k_head_num_k, qk_nope_pe_dim = k_prefix.shape - total_kv_v, tp_k_head_num_v, qk_nope_dim = v_prefix.shape - - scale_k_granularity = weight_k // scale_k - scale_n_granularity = weight_n // scale_n - - ChunkK = 16 if k_buffer.dtype in [torch.float16, torch.bfloat16] else 32 - - assert total_kv_k == total_kv_v - assert tp_k_head_num_k == tp_k_head_num_v - assert scale_k_granularity == 128 - assert scale_n_granularity == 128 - assert ChunkK % block_size == 0 - - grid = (batch_size * tp_k_head_num_k,) - kernel = _triton_gather_kv_b_proj[grid]( - batch_size, - k_buffer, - k_scale, - kv_indptr, - kv_indices, - kv_prefix_sum_context_lens, - kv_proj_weight, - kv_proj_scale, - k_prefix, - v_prefix, - KBlockSize=block_size, - TpNumHeads=tp_k_head_num_k, - QkNopeHeadDim=qk_nope_dim, - KV_CDim=weight_k, - KV_PeDim=qk_nope_pe_dim - qk_nope_dim, - ChunkK=ChunkK, - num_stages=3, - ) +from aiter.ops.triton.gather_kv_b_proj import gather_kv_b_proj # MLA Specific Arguments @dataclass @@ -686,133 +435,6 @@ def _forward_prefill_mha( return self.o_proj(output.flatten(start_dim=-2)) - def _forward_prefill_mha_prefix_cache( - self, - q: torch.Tensor, - k_prefix: torch.Tensor, - v_prefix: torch.Tensor, - kv_c_normed: torch.Tensor, - k_rope: torch.Tensor, - attn_metadata: AttentionMetaData, - ) -> torch.Tensor: - """Prefill MHA with prefix cache (doc §4.3, §4.4): project new tokens, - concat [cached | new] per batch, flash attention.""" - assert attn_metadata is not None and attn_metadata.has_cached - - if k_rope.dim() == 2: - k_rope = k_rope.unsqueeze(1) - - # Step 4.3: Project new tokens' latent to k_new, v_new - if use_triton_gemm(): - weight = self.kv_b_proj.weight - weight_scale = self.kv_b_proj.weight_scale - if ( - fused_gemm_a8w8_blockscale_preshuffle_split_cat is not None - and weight.dtype == dtypes.fp8 - ): - weight_shuffled = weight.reshape( - weight.shape[0] // 16, weight.shape[1] * 16 - ) - output_dtype = kv_c_normed.dtype - quant_func = functools_partial( - get_hip_quant(QuantType.per_1x128), transpose_scale=True - ) - q_input, x_scale = quant_func( - kv_c_normed, - quant_dtype=dtypes.fp8, - scale=getattr(self.kv_b_proj, "input_scale", None), - ) - k_nope_new, v_new = fused_gemm_a8w8_blockscale_preshuffle_split_cat( - q_input, - weight_shuffled, - k_rope.expand((-1, self.num_heads, -1)), - x_scale, - weight_scale, - self.qk_nope_head_dim, - self.v_head_dim, - output_dtype, - ) - else: - kv_nope = self.kv_b_proj(kv_c_normed).view( - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim - ) - k_nope_new, v_new = kv_nope.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - k_new = torch.cat( - (k_nope_new, k_rope.expand((*k_nope_new.shape[:-1], -1))), dim=-1 - ) - else: - kv_nope = self.kv_b_proj(kv_c_normed).view( - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim - ) - k_nope_new, v_new = kv_nope.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - k_new = torch.cat( - (k_nope_new, k_rope.expand((*k_nope_new.shape[:-1], -1))), dim=-1 - ) - - # Step 4.4: Concat [cached | new] per batch, flash attention - bs = attn_metadata.cu_seqlens_q.shape[0] - 1 - cu_seqlens_q = attn_metadata.cu_seqlens_q - cu_seqlens_k = attn_metadata.cu_seqlens_k - num_cached_tokens = attn_metadata.num_cached_tokens - - total_tokens = cu_seqlens_k[-1].item() - output_dtype = q.dtype - k_full = torch.empty( - (total_tokens, self.num_heads, self.qk_head_dim), - dtype=output_dtype, - device=q.device, - ) - v_full = torch.empty( - (total_tokens, self.num_heads, self.v_head_dim), - dtype=output_dtype, - device=q.device, - ) - - cached_offset = 0 - for i in range(bs): - start = cu_seqlens_k[i].item() - end = cu_seqlens_k[i + 1].item() - cached_i = num_cached_tokens[i].item() - new_i = end - start - cached_i - - # k_prefix/v_prefix are 2D [total_cached, num_heads*head_dim], reshape to 3D - k_full[start : start + cached_i] = ( - k_prefix[cached_offset : cached_offset + cached_i] - .view(cached_i, self.num_heads, self.qk_head_dim) - .to(output_dtype) - ) - v_full[start : start + cached_i] = ( - v_prefix[cached_offset : cached_offset + cached_i] - .view(cached_i, self.num_heads, self.v_head_dim) - .to(output_dtype) - ) - - new_start = cu_seqlens_q[i].item() - k_full[start + cached_i : end] = k_new[new_start : new_start + new_i] - v_full[start + cached_i : end] = v_new[new_start : new_start + new_i] - - cached_offset += cached_i - - output = flash_attn_varlen_func( - q=q, - k=k_full, - v=v_full, - cu_seqlens_q=attn_metadata.cu_seqlens_q, - cu_seqlens_k=attn_metadata.cu_seqlens_k, - max_seqlen_q=attn_metadata.max_seqlen_q, - max_seqlen_k=attn_metadata.max_seqlen_k, - min_seqlen_q=attn_metadata.min_seqlen_q, - dropout_p=attn_metadata.dropout_p, - softmax_scale=self.scale, - causal=True, - ) - - return self.o_proj(output.flatten(start_dim=-2)) - def _forward_prefill_mla( self, q: torch.Tensor, @@ -1008,111 +630,73 @@ def forward_impl_server_mode( and self.qk_nope_head_dim == self.v_head_dim ) - # Step 4.1: Build cached-only metadata (doc §4.1) - if use_prefix_cache: - num_cached_tokens = attn_metadata.num_cached_tokens - bs = num_cached_tokens.shape[0] - total_cached = num_cached_tokens.sum().item() - block_size = kv_cache.shape[1] - num_cached_blocks = ( - num_cached_tokens.to(q.device) + block_size - 1 - ) // block_size - - kv_indptr_cached = torch.zeros( - bs + 1, dtype=torch.int32, device=q.device - ) - kv_indptr_cached[1:] = torch.cumsum(num_cached_blocks, dim=0) - kv_prefix_sum_context_lens = torch.zeros( - bs + 1, dtype=torch.int32, device=q.device - ) - kv_prefix_sum_context_lens[1:] = torch.cumsum( - num_cached_tokens.to(q.device), dim=0 + prefill_q = self.q_proj(q, x_scale=q_scale).view( + -1, self.num_heads, self.qk_head_dim + ) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :] + self.rotary_emb(positions, prefill_q_pe, k_rope) + + if kv_cache.numel() > 0: + concat_and_cache_mla( + k_nope, + k_rope.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=self._k_scale, ) - kv_indices_cached = torch.empty( - kv_indptr_cached[-1].item(), - dtype=torch.int32, - device=q.device, + if use_prefix_cache: + total_tokens = attn_metadata.cu_seqlens_k[-1].item() + output_dtype = ( + dtypes.fp8 + if self.kv_cache_dtype.startswith("fp8") + else self.dtype ) - block_tables = attn_metadata.block_tables - for i in range(bs): - n = num_cached_blocks[i].item() - if n > 0: - kv_indices_cached[ - kv_indptr_cached[i].item() : kv_indptr_cached[ - i + 1 - ].item() - ] = block_tables[i, :n] - - # Step 4.2: gather_kv_b_proj - Gather + Project (doc §4.2) - k_prefix = torch.zeros( + k_full = torch.empty( ( - total_cached, - self.num_heads - * (self.qk_nope_head_dim + self.qk_rope_head_dim), + total_tokens, + self.num_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, ), device=q.device, - dtype=( - dtypes.fp8 - if self.kv_cache_dtype.startswith("fp8") - else self.dtype - ), + dtype=output_dtype, ) - v_prefix = torch.zeros( + v_full = torch.empty( ( - total_cached, - self.num_heads * self.qk_nope_head_dim, + total_tokens, + self.num_heads, + self.qk_nope_head_dim, ), device=q.device, - dtype=( - dtypes.fp8 - if self.kv_cache_dtype.startswith("fp8") - else self.dtype - ), + dtype=output_dtype, ) gather_kv_b_proj( kv_cache, self._k_scale, - kv_indptr_cached, - kv_indices_cached, - kv_prefix_sum_context_lens, + attn_metadata.kv_indptr, + attn_metadata.kv_indices, + attn_metadata.cu_seqlens_k, self.kv_b_proj.weight, self.kv_b_proj.weight_scale, - k_prefix.view( - -1, - self.num_heads, - self.qk_nope_head_dim + self.qk_rope_head_dim, - ), - v_prefix.view(-1, self.num_heads, self.qk_nope_head_dim), + k_full, + v_full, ) - - prefill_q = self.q_proj(q, x_scale=q_scale).view( - -1, self.num_heads, self.qk_head_dim - ) - prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :] - self.rotary_emb(positions, prefill_q_pe, k_rope) - - if kv_cache.numel() > 0: - concat_and_cache_mla( - k_nope, - k_rope.squeeze(1), - kv_cache, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype=self.kv_cache_dtype, - scale=self._k_scale, - ) - - # Step 4.3-4.4: Concat cached + new, flash attention (doc §4.3, §4.4) - if use_prefix_cache: - output = self._forward_prefill_mha_prefix_cache( - prefill_q, - k_prefix, - v_prefix, - k_nope, - k_rope, - attn_metadata, + output = flash_attn_varlen_func( + q=prefill_q, + k=k_full, + v=v_full, + cu_seqlens_q=attn_metadata.cu_seqlens_q, + cu_seqlens_k=attn_metadata.cu_seqlens_k, + max_seqlen_q=attn_metadata.max_seqlen_q, + max_seqlen_k=attn_metadata.max_seqlen_k, + min_seqlen_q=attn_metadata.min_seqlen_q, + dropout_p=attn_metadata.dropout_p, + softmax_scale=self.scale, + causal=True, ) + output = self.o_proj(output.flatten(start_dim=-2)) else: output = self._forward_prefill_mha( prefill_q, k_nope, k_rope, kv_cache, attn_metadata diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 0643b25d4..8e1330d67 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -309,11 +309,14 @@ def prepare_prefill(self, batch: ScheduledBatch): ) attn_metadata.kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs] - if attn_metadata.block_tables is None: - self.prepare_block_tables(batch) - attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs) + # kv_indices_generate_triton expects RAW block_tables (physical block ids, + # one per block_ratio tokens). When is_sparse, attn_metadata.block_tables + # may have been overwritten with block_tables_converted (slot per token). + # Always use raw block_tables for kv_indices. + self.prepare_block_tables(batch) + block_tables_for_kv = var["block_tables"].copy_to_gpu(bs) kv_indices_generate_triton( - attn_metadata.block_tables, + block_tables_for_kv, attn_metadata.kv_indices, attn_metadata.kv_indptr, self.block_ratio, From d39ad62b13e7f65c70d876bf094368535a4f9fd3 Mon Sep 17 00:00:00 2001 From: jiayyu Date: Wed, 11 Mar 2026 17:07:33 +0800 Subject: [PATCH 06/12] fix format --- atom/model_ops/attention_mha.py | 22 ++++++++++++---------- atom/model_ops/attention_mla.py | 5 ++--- atom/model_ops/attentions/backends.py | 3 +-- atom/model_ops/base_attention.py | 4 +++- tests/test_prefix_cache_accuracy.py | 3 +-- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index c2a5d75db..48f295cde 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -137,9 +137,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): x = 16 // k_cache.element_size() if k_cache.dim() == 5 and v_cache.dim() == 4: n, nh, hd, bs = v_cache.shape - v_cache_shuffle = v_cache.view( - n, nh, bs // x, hd, x - ) + v_cache_shuffle = v_cache.view(n, nh, bs // x, hd, x) else: v_cache_shuffle = v_cache fused_qk_norm_rope_cache_quant_shuffle( @@ -229,8 +227,10 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): # Prefix cache hit: gather cached KV from paged cache and concat with new tokens if attn_metadata.has_cached: - q, k, v, k_cache, v_cache, k_scale, v_scale = self._gather_prefix_and_concat_kv( - q, k, v, k_cache, v_cache, k_scale, v_scale, attn_metadata + q, k, v, k_cache, v_cache, k_scale, v_scale = ( + self._gather_prefix_and_concat_kv( + q, k, v, k_cache, v_cache, k_scale, v_scale, attn_metadata + ) ) return q, k, v, k_cache, v_cache, k_scale, v_scale @@ -266,9 +266,7 @@ def _gather_prefix_and_concat_kv( # token_to_batch: [0]*cached[0] + [1]*cached[1] + ... # cu_seqlens_kv: [0, cached[0], cached[0]+cached[1], ...] # seq_starts: [0]*bs (prefix starts at position 0) - token_to_batch = torch.zeros( - total_cached, dtype=torch.int32, device=device - ) + token_to_batch = torch.zeros(total_cached, dtype=torch.int32, device=device) cu_seqlens_kv = [0] for i in range(bs): c = cached_seqlen[i].item() @@ -343,8 +341,12 @@ def _gather_prefix_and_concat_kv( cached_i = cached_seqlen[i].item() new_i = end - start - cached_i - k_full[start : start + cached_i] = k_prefix[cached_offset : cached_offset + cached_i] - v_full[start : start + cached_i] = v_prefix[cached_offset : cached_offset + cached_i] + k_full[start : start + cached_i] = k_prefix[ + cached_offset : cached_offset + cached_i + ] + v_full[start : start + cached_i] = v_prefix[ + cached_offset : cached_offset + cached_i + ] new_start = cu_seqlens_q[i].item() k_full[start + cached_i : end] = k[new_start : new_start + new_i] diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 110c379fa..e0d3e73b8 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -70,6 +70,7 @@ def is_rocm_aiter_fp4bmm_enabled() -> bool: from aiter.ops.triton.gather_kv_b_proj import gather_kv_b_proj + # MLA Specific Arguments @dataclass class MLAModules: @@ -649,9 +650,7 @@ def forward_impl_server_mode( if use_prefix_cache: total_tokens = attn_metadata.cu_seqlens_k[-1].item() output_dtype = ( - dtypes.fp8 - if self.kv_cache_dtype.startswith("fp8") - else self.dtype + dtypes.fp8 if self.kv_cache_dtype.startswith("fp8") else self.dtype ) k_full = torch.empty( ( diff --git a/atom/model_ops/attentions/backends.py b/atom/model_ops/attentions/backends.py index 7bd1b675c..cb04f0a43 100644 --- a/atom/model_ops/attentions/backends.py +++ b/atom/model_ops/attentions/backends.py @@ -7,14 +7,13 @@ import torch from atom.model_engine.scheduler import ScheduledBatch - -logger = logging.getLogger("atom") from atom.model_ops.attention_mla import MLAModules from atom.utils import CpuGpuBuffer from atom.utils.block_convert import block_table_convert_triton from atom.utils.forward_context import AttentionMetaData from torch import nn +logger = logging.getLogger("atom") T = TypeVar("T", bound="BroadcastableModelInput") diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index 90054852d..f3a41b3c8 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -159,7 +159,9 @@ def cp_mha_gather_cache( num_heads = key_cache.shape[2] else: # SHUFFLE: K [num_blocks, num_heads, head_dim//x, page_size, x] - assert key_cache.dim() == 5 and head_dim == key_cache.shape[2] * key_cache.shape[4] + assert ( + key_cache.dim() == 5 and head_dim == key_cache.shape[2] * key_cache.shape[4] + ) page_size = key_cache.shape[3] num_heads = key_cache.shape[1] diff --git a/tests/test_prefix_cache_accuracy.py b/tests/test_prefix_cache_accuracy.py index e10d97cf2..d68c2b67e 100644 --- a/tests/test_prefix_cache_accuracy.py +++ b/tests/test_prefix_cache_accuracy.py @@ -10,7 +10,6 @@ import argparse import concurrent.futures -import json import re import sys import time @@ -125,7 +124,7 @@ def main(): sys.exit(1) model = get_model_name(base_url) - print(f"=== Prefix Cache Accuracy Test ===") + print("=== Prefix Cache Accuracy Test ===") print(f"Server: {base_url}") print(f"Model: {model}") print(f"Questions per round: {len(TEST_QUESTIONS)}") From 661b631dc02994698da294c40b5cc4db6c3d5ff8 Mon Sep 17 00:00:00 2001 From: jiayyu Date: Wed, 11 Mar 2026 17:09:16 +0800 Subject: [PATCH 07/12] fix format --- tests/test_prefix_cache_accuracy.py | 70 +++++++++++++++++++++++------ 1 file changed, 56 insertions(+), 14 deletions(-) diff --git a/tests/test_prefix_cache_accuracy.py b/tests/test_prefix_cache_accuracy.py index d68c2b67e..0901b93ce 100644 --- a/tests/test_prefix_cache_accuracy.py +++ b/tests/test_prefix_cache_accuracy.py @@ -40,16 +40,46 @@ # Test questions with known answers TEST_QUESTIONS = [ - ("A farmer has 17 sheep. He buys 5 more and then sells 8. How many sheep does he have?", 14), - ("A store had 45 apples. They sold 12 in the morning and 18 in the afternoon. How many apples are left?", 15), - ("Tom has 8 marbles. Jerry has 3 times as many. How many marbles do they have together?", 32), - ("A classroom has 6 rows of desks with 5 desks in each row. If 7 desks are removed, how many remain?", 23), - ("Sarah baked 24 cookies. She gave 1/3 to her neighbor and ate 4 herself. How many cookies does she have left?", 12), - ("A train travels 60 miles per hour for 3 hours, then 40 miles per hour for 2 hours. What is the total distance?", 260), - ("Mike has 50 dollars. He spends 15 dollars on lunch and 20 dollars on a book. How much money does he have left?", 15), - ("A garden has 9 rose bushes. Each bush has 12 roses. If 25 roses are picked, how many roses remain?", 83), - ("Lisa read 35 pages on Monday and twice as many on Tuesday. How many pages did she read in total?", 105), - ("A box contains 100 balls. 40 are red, 35 are blue, and the rest are green. How many green balls are there?", 25), + ( + "A farmer has 17 sheep. He buys 5 more and then sells 8. How many sheep does he have?", + 14, + ), + ( + "A store had 45 apples. They sold 12 in the morning and 18 in the afternoon. How many apples are left?", + 15, + ), + ( + "Tom has 8 marbles. Jerry has 3 times as many. How many marbles do they have together?", + 32, + ), + ( + "A classroom has 6 rows of desks with 5 desks in each row. If 7 desks are removed, how many remain?", + 23, + ), + ( + "Sarah baked 24 cookies. She gave 1/3 to her neighbor and ate 4 herself. How many cookies does she have left?", + 12, + ), + ( + "A train travels 60 miles per hour for 3 hours, then 40 miles per hour for 2 hours. What is the total distance?", + 260, + ), + ( + "Mike has 50 dollars. He spends 15 dollars on lunch and 20 dollars on a book. How much money does he have left?", + 15, + ), + ( + "A garden has 9 rose bushes. Each bush has 12 roses. If 25 roses are picked, how many roses remain?", + 83, + ), + ( + "Lisa read 35 pages on Monday and twice as many on Tuesday. How many pages did she read in total?", + 105, + ), + ( + "A box contains 100 balls. 40 are red, 35 are blue, and the rest are green. How many green balls are there?", + 25, + ), ] @@ -68,7 +98,9 @@ def get_model_name(base_url: str) -> str: return resp.json()["data"][0]["id"] -def send_completion(prompt: str, max_tokens: int = 256, base_url: str = BASE_URL, model: str = "") -> str: +def send_completion( + prompt: str, max_tokens: int = 256, base_url: str = BASE_URL, model: str = "" +) -> str: """Send a completion request to the server.""" resp = requests.post( f"{base_url}/v1/completions", @@ -108,7 +140,9 @@ def ask(q_and_a): def main(): parser = argparse.ArgumentParser(description="Test prefix cache accuracy") - parser.add_argument("--rounds", type=int, default=3, help="Number of rounds to repeat") + parser.add_argument( + "--rounds", type=int, default=3, help="Number of rounds to repeat" + ) parser.add_argument("--base-url", type=str, default=BASE_URL) parser.add_argument("--verbose", action="store_true") args = parser.parse_args() @@ -136,12 +170,20 @@ def main(): for round_num in range(1, args.rounds + 1): t0 = time.time() - correct, total, results = run_batch(TEST_QUESTIONS, MATH_PREFIX, base_url=base_url, model=model, label=f"Round {round_num}") + correct, total, results = run_batch( + TEST_QUESTIONS, + MATH_PREFIX, + base_url=base_url, + model=model, + label=f"Round {round_num}", + ) elapsed = time.time() - t0 accuracy = 100.0 * correct / total all_round_results.append((correct, total, accuracy, elapsed)) - print(f"Round {round_num}: {correct}/{total} correct ({accuracy:.1f}%) in {elapsed:.1f}s") + print( + f"Round {round_num}: {correct}/{total} correct ({accuracy:.1f}%) in {elapsed:.1f}s" + ) if args.verbose: for q, expected, got, ok, resp in results: From 8a674a8c5124a84532f6b778d4b3946dfd994d3b Mon Sep 17 00:00:00 2001 From: jiayyu Date: Thu, 12 Mar 2026 10:31:12 +0800 Subject: [PATCH 08/12] refine mha --- atom/model_ops/attention_mha.py | 72 ++++++--------------------- atom/model_ops/attention_mla.py | 1 + atom/model_ops/attentions/backends.py | 17 +++++++ atom/utils/forward_context.py | 10 +++- 4 files changed, 42 insertions(+), 58 deletions(-) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 48f295cde..c42a8c235 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -247,39 +247,26 @@ def _gather_prefix_and_concat_kv( attn_metadata, ): """ - When prefix cache hits, gather cached KV from paged cache and concat with - new tokens' k, v for full-sequence attention (e.g. flash_attn). + When prefix cache hits, gather full KV (cached + new) from paged cache in + one pass. New tokens are already written by fused_qk_rope_reshape_and_cache. + Same flow as gather_kv_b_proj: write new first, then read cached+new together. + token_to_batch, seq_starts are built in prepare_prefill. """ - bs = attn_metadata.cu_seqlens_q.shape[0] - 1 - cu_seqlens_q = attn_metadata.cu_seqlens_q cu_seqlens_k = attn_metadata.cu_seqlens_k - cached_seqlen = attn_metadata.num_cached_tokens # [bs], from prepare_prefill - - total_cached = cached_seqlen.sum().item() + total_tokens = cu_seqlens_k[-1].item() + token_to_batch = attn_metadata.token_to_batch + seq_starts = attn_metadata.seq_starts num_kv_heads = k.shape[1] head_dim = k.shape[2] device = k.device dtype = k.dtype - # Build metadata for cp_mha_gather_cache (cached tokens only) - # token_to_batch: [0]*cached[0] + [1]*cached[1] + ... - # cu_seqlens_kv: [0, cached[0], cached[0]+cached[1], ...] - # seq_starts: [0]*bs (prefix starts at position 0) - token_to_batch = torch.zeros(total_cached, dtype=torch.int32, device=device) - cu_seqlens_kv = [0] - for i in range(bs): - c = cached_seqlen[i].item() - token_to_batch[cu_seqlens_kv[-1] : cu_seqlens_kv[-1] + c] = i - cu_seqlens_kv.append(cu_seqlens_kv[-1] + c) - cu_seqlens_kv = torch.tensor(cu_seqlens_kv, dtype=torch.int32, device=device) - seq_starts = torch.zeros(bs, dtype=torch.int32, device=device) - - k_prefix = torch.empty( - (total_cached, num_kv_heads, head_dim), dtype=dtype, device=device + k_full = torch.empty( + (total_tokens, num_kv_heads, head_dim), dtype=dtype, device=device ) - v_prefix = torch.empty( - (total_cached, num_kv_heads, head_dim), dtype=dtype, device=device + v_full = torch.empty( + (total_tokens, num_kv_heads, head_dim), dtype=dtype, device=device ) # Convert cache for cp_mha_gather_cache @@ -312,48 +299,19 @@ def _gather_prefix_and_concat_kv( cp_mha_gather_cache( key_cache=k_cache_gather, value_cache=v_cache_gather, - key=k_prefix, - value=v_prefix, + key=k_full, + value=v_full, block_tables=block_tables, k_scales=k_scale, v_scales=v_scale, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_kv=cu_seqlens_k, token_to_batch=token_to_batch, seq_starts=seq_starts, dequant=self.kv_cache_dtype.startswith("fp8"), kv_cache_layout="SHUFFLE" if use_shuffle else "NHD", - total_tokens=total_cached, + total_tokens=total_tokens, ) - # Build full k, v: for each batch i, [cached_i | new_i] - total_tokens = cu_seqlens_k[-1].item() - k_full = torch.empty( - (total_tokens, num_kv_heads, head_dim), dtype=dtype, device=device - ) - v_full = torch.empty( - (total_tokens, num_kv_heads, head_dim), dtype=dtype, device=device - ) - - cached_offset = 0 - for i in range(bs): - start = cu_seqlens_k[i].item() - end = cu_seqlens_k[i + 1].item() - cached_i = cached_seqlen[i].item() - new_i = end - start - cached_i - - k_full[start : start + cached_i] = k_prefix[ - cached_offset : cached_offset + cached_i - ] - v_full[start : start + cached_i] = v_prefix[ - cached_offset : cached_offset + cached_i - ] - - new_start = cu_seqlens_q[i].item() - k_full[start + cached_i : end] = k[new_start : new_start + new_i] - v_full[start + cached_i : end] = v[new_start : new_start + new_i] - - cached_offset += cached_i - return q, k_full, v_full, k_cache, v_cache, k_scale, v_scale def paged_attention_triton( diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index e0d3e73b8..ebb84a656 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -681,6 +681,7 @@ def forward_impl_server_mode( self.kv_b_proj.weight_scale, k_full, v_full, + weight_preshuffle=True, ) output = flash_attn_varlen_func( q=prefill_q, diff --git a/atom/model_ops/attentions/backends.py b/atom/model_ops/attentions/backends.py index cb04f0a43..247a512f5 100644 --- a/atom/model_ops/attentions/backends.py +++ b/atom/model_ops/attentions/backends.py @@ -116,6 +116,8 @@ def __init__(self, model_runner): ), "cu_seqlens_q": CpuGpuBuffer(self.max_bs + 1, **i32_kwargs), "cu_seqlens_k": CpuGpuBuffer(self.max_bs + 1, **i32_kwargs), + # seq_starts for cp_mha_gather_cache: always zeros (prefix at position 0) + "seq_starts": CpuGpuBuffer(self.max_bs, **i32_kwargs), } if self.block_ratio > 1: attn_metadata["block_tables_converted"] = CpuGpuBuffer( @@ -128,6 +130,8 @@ def __init__(self, model_runner): torch.arange(0, self.max_bs + 1, step=1, dtype=torch.int32) ) attn_metadata["cu_seqlens_q"].copy_to_gpu() + attn_metadata["seq_starts"].cpu.zero_() + attn_metadata["seq_starts"].copy_to_gpu() self.model_runner.forward_vars.update(attn_metadata) self.has_sliding_window = hasattr(hf_config, "sliding_window") @@ -208,6 +212,7 @@ def prepare_prefill(self, batch: ScheduledBatch): ] if has_cached: vars_used.append(("block_tables", bs)) + vars_used.append(("seq_starts", bs)) ctx = {el: var[el].copy_to_gpu(num) for el, num in vars_used} if self.block_ratio > 1 and "block_tables" in ctx: @@ -219,6 +224,7 @@ def prepare_prefill(self, batch: ScheduledBatch): ) ctx["block_tables_converted"] = var["block_tables_converted"].gpu[:bs] num_cached_tokens = None + token_to_batch = None if has_cached: num_cached_tokens = torch.tensor( batch.num_cached_tokens[:bs], dtype=torch.int32, pin_memory=True @@ -228,6 +234,16 @@ def prepare_prefill(self, batch: ScheduledBatch): logger.info( f"Prefill batch has {num_cached_tokens.sum().item()} cached tokens and of {sum_scheduled_tokens} total tokens" ) + # Build metadata for cp_mha_gather_cache (full sequence: cached + new) + # token_to_batch: [0]*len0 + [1]*len1 + ... + total_tokens = cu_seqlens_k[-1] + token_to_batch = torch.zeros( + total_tokens, dtype=torch.int32, device=self.device + ) + for i in range(bs): + start = cu_seqlens_k[i] + end = cu_seqlens_k[i + 1] + token_to_batch[start:end] = i attn_metadata = AttentionMetaData( cu_seqlens_k=cu_seqlens_k.cuda(non_blocking=True), max_seqlen_q=max_seqlen_q, @@ -236,6 +252,7 @@ def prepare_prefill(self, batch: ScheduledBatch): dropout_p=dropout_p, has_cached=has_cached, num_cached_tokens=num_cached_tokens, + token_to_batch=token_to_batch, **ctx, ) positions = var["positions"].copy_to_gpu(sum_scheduled_tokens) diff --git a/atom/utils/forward_context.py b/atom/utils/forward_context.py index 5f96a89df..2f1a637bd 100644 --- a/atom/utils/forward_context.py +++ b/atom/utils/forward_context.py @@ -188,8 +188,12 @@ class AttentionMetaData: reduce_partial_map: Optional[torch.Tensor] = None block_tables_converted: Optional[torch.Tensor] = None + + # for prefix cache has_cached: bool = False - num_cached_tokens: Optional[torch.Tensor] = None # [bs] when has_cached + num_cached_tokens: Optional[torch.Tensor] = None + token_to_batch: Optional[torch.Tensor] = None + seq_starts: Optional[torch.Tensor] = None # only used for plugin mode to store the metadata for attn plugin_metadata: Optional["MetadataForPluginMode"] = None @@ -223,9 +227,13 @@ def __init__( plugin_metadata: Optional["MetadataForPluginMode"] = None, has_cached: bool = False, num_cached_tokens: Optional[torch.Tensor] = None, + token_to_batch: Optional[torch.Tensor] = None, + seq_starts: Optional[torch.Tensor] = None, ): self.has_cached = has_cached self.num_cached_tokens = num_cached_tokens + self.token_to_batch = token_to_batch + self.seq_starts = seq_starts self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k self.max_seqlen_q = max_seqlen_q From 66c7753e72401e225810e37b784bbd18a588ebe5 Mon Sep 17 00:00:00 2001 From: jiayyu Date: Thu, 12 Mar 2026 12:13:45 +0800 Subject: [PATCH 09/12] fix format --- atom/model_ops/attention_mla.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index ebb84a656..5c0978a06 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -33,7 +33,7 @@ from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as _aiter_triton_fp8_bmm, ) - +from aiter.ops.triton.gather_kv_b_proj import gather_kv_b_proj from atom.plugin import is_plugin_mode @@ -68,8 +68,6 @@ def is_rocm_aiter_fp4bmm_enabled() -> bool: from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 from atom.model_ops.utils import quark_post_load_weights -from aiter.ops.triton.gather_kv_b_proj import gather_kv_b_proj - # MLA Specific Arguments @dataclass From 15f54d23cb527459bb9f8622ef46b2122dc42a1f Mon Sep 17 00:00:00 2001 From: jiayyu Date: Thu, 12 Mar 2026 17:38:13 +0800 Subject: [PATCH 10/12] prefill mla wip --- atom/model_ops/attention_mla.py | 2 +- atom/model_ops/attentions/aiter_mla.py | 36 +++++++++++++++++++++----- atom/models/deepseek_v2.py | 18 ++++++++----- 3 files changed, 42 insertions(+), 14 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 5c0978a06..5d62b1589 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -466,7 +466,7 @@ def _forward_prefill_mla( attn_metadata.token_to_seq_idxs, self.topk_indices_buffer[:B], attn_metadata.block_tables, - attn_metadata.cu_seqlens_q, + attn_metadata.cu_seqlens_k, NUM_TOPK_TOKENS=self.topk_indices_buffer.shape[1], ) paged_cu_seqlens_q = attn_metadata.sparse_cu_seqlens_q diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 8e1330d67..ab4eb3bac 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -264,13 +264,23 @@ def prepare_prefill(self, batch: ScheduledBatch): self.block_ratio, ) attn_metadata.block_tables = var["block_tables_converted"].gpu[:bs] - var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = ( - np.arange(sum_scheduled_tokens, dtype=np.int32) + 1 - ) counts = var["cu_seqlens_q"].np[1 : bs + 1] - var["cu_seqlens_q"].np[:bs] - var["cu_seqlen_ks"].np[:sum_scheduled_tokens] = np.repeat( - var["cu_seqlens_q"].np[:bs], counts - ) + if attn_metadata.has_cached: + # Full context (cached + new): use cu_seqlens_k for indexer + cu_seqlens_k_np = attn_metadata.cu_seqlens_k.cpu().numpy() + var["cu_seqlen_ks"].np[:sum_scheduled_tokens] = np.repeat( + cu_seqlens_k_np[:-1], counts + ) + var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = np.repeat( + cu_seqlens_k_np[1:], counts + ) + else: + var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = ( + np.arange(sum_scheduled_tokens, dtype=np.int32) + 1 + ) + var["cu_seqlen_ks"].np[:sum_scheduled_tokens] = np.repeat( + var["cu_seqlens_q"].np[:bs], counts + ) attn_metadata.cu_seqlen_ks = var["cu_seqlen_ks"].copy_to_gpu( sum_scheduled_tokens ) @@ -284,9 +294,11 @@ def prepare_prefill(self, batch: ScheduledBatch): :sum_scheduled_tokens ] + # Per-query req_id: token_id 0..sum_scheduled_tokens-1 maps to batch id. + # Use counts (new tokens per batch), not context_lens (full seq len). attn_metadata.token_to_seq_idxs = torch.repeat_interleave( torch.arange(bs, dtype=torch.int32, device=self.device), - attn_metadata.context_lens, + torch.tensor(counts, dtype=torch.int64, device=self.device), ) var["sparse_kv_indptr"].np[0] = 0 var["sparse_kv_indptr"].np[1 : sum_scheduled_tokens + 1] = np.cumsum( @@ -301,6 +313,16 @@ def prepare_prefill(self, batch: ScheduledBatch): ) if hasattr(self.model_runner, "drafter") or attn_metadata.has_cached: + # Populate kv_last_page_lens for full sequence (needed for MLA prefill with + # prefix cache; decode does the same) + if self.model_runner.block_size != 1: + var["kv_last_page_lens"].np[:bs] = np.asarray( + batch.last_block_num_tokens[:bs], dtype=np.int32 + ) + else: + var["kv_last_page_lens"].np[:bs] = 1 + var["kv_last_page_lens"].copy_to_gpu() + attn_metadata.kv_indices = var["kv_indices"].gpu attn_metadata.kv_indptr = var["kv_indptr"].gpu[: bs + 1] attn_metadata.kv_indptr[0] = 0 diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index d1da9f052..2fc334bbc 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -967,11 +967,17 @@ def sparse_attn_indexer( return weights prefill_metadata = attn_metadata num_prefills = context.batch_size - total_seq_lens = hidden_states.shape[0] + num_tokens = hidden_states.shape[0] + # When has_cached, gather full KV (cached + new) for indexer top-k + total_kv = ( + prefill_metadata.cu_seqlens_k[-1].item() + if prefill_metadata.has_cached + else num_tokens + ) k_fp8 = torch.empty( - [total_seq_lens, head_dim], device=k.device, dtype=dtypes.fp8 + [total_kv, head_dim], device=k.device, dtype=dtypes.fp8 ) - k_scale = torch.empty([total_seq_lens, 1], device=k.device, dtype=torch.float32) + k_scale = torch.empty([total_kv, 1], device=k.device, dtype=torch.float32) if prefill_metadata.block_tables.shape[0] < num_prefills: new_shape = (num_prefills, prefill_metadata.block_tables.shape[1]) prefill_metadata.block_tables = torch.full( @@ -985,12 +991,12 @@ def sparse_attn_indexer( k_fp8, k_scale.view(dtypes.fp8), prefill_metadata.block_tables, - prefill_metadata.cu_seqlens_q, - # num_prefills, + prefill_metadata.cu_seqlens_k + if prefill_metadata.has_cached + else prefill_metadata.cu_seqlens_q, ) cu_seqlen_ks = prefill_metadata.cu_seqlen_ks cu_seqlen_ke = prefill_metadata.cu_seqlen_ke - num_tokens = hidden_states.shape[0] logits = fp8_mqa_logits( Q=q_fp8[num_decode_tokens:num_tokens], KV=k_fp8, From e17b6141eeb3f5e6b68ef8dfcf22379f2a072c01 Mon Sep 17 00:00:00 2001 From: jiayyu Date: Fri, 13 Mar 2026 15:54:40 +0800 Subject: [PATCH 11/12] fix format --- atom/models/deepseek_v2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 2fc334bbc..f98480beb 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -974,9 +974,7 @@ def sparse_attn_indexer( if prefill_metadata.has_cached else num_tokens ) - k_fp8 = torch.empty( - [total_kv, head_dim], device=k.device, dtype=dtypes.fp8 - ) + k_fp8 = torch.empty([total_kv, head_dim], device=k.device, dtype=dtypes.fp8) k_scale = torch.empty([total_kv, 1], device=k.device, dtype=torch.float32) if prefill_metadata.block_tables.shape[0] < num_prefills: new_shape = (num_prefills, prefill_metadata.block_tables.shape[1]) @@ -991,9 +989,11 @@ def sparse_attn_indexer( k_fp8, k_scale.view(dtypes.fp8), prefill_metadata.block_tables, - prefill_metadata.cu_seqlens_k - if prefill_metadata.has_cached - else prefill_metadata.cu_seqlens_q, + ( + prefill_metadata.cu_seqlens_k + if prefill_metadata.has_cached + else prefill_metadata.cu_seqlens_q + ), ) cu_seqlen_ks = prefill_metadata.cu_seqlen_ks cu_seqlen_ke = prefill_metadata.cu_seqlen_ke From e20c998f11fd9afe9ba62d4b5282fa5d0cf1609c Mon Sep 17 00:00:00 2001 From: jiayyu Date: Fri, 13 Mar 2026 16:02:25 +0800 Subject: [PATCH 12/12] trigger ci