Hybrid kv cache for LLaMA4 #6563
Conversation
Commits: "hybrid cache", "hybrid cache", "hybrid cache end with evict rules and reformat"
Force-pushed from 5e66e89 to d053027
Force-pushed from c9fa7ea to d1203cb
@@ -624,6 +626,9 @@ def forward_extend(
    q_rope: Optional[torch.Tensor] = None,
    k_rope: Optional[torch.Tensor] = None,
):
    use_hybrid_loc = self.is_hybrid is not None and (
didn't find the usage of use_hybrid_loc
@@ -887,6 +892,9 @@ def forward_decode(
    q_rope: Optional[torch.Tensor] = None,
    k_rope: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    use_hybrid_loc = self.is_hybrid is not None and (
similar here
@@ -523,6 +526,8 @@ def __init__(
        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices: torch.Tensor = []
        # The indices to local kv cache for the shared prefix.
        self.prefix_indices_local: torch.Tensor = []
didn't find usage of prefix_indices_local
@@ -55,6 +57,11 @@ def __init__(
    def debug_print(self) -> str:
        return ""

    def log_usage(self, evictable_size: int = 0):
        num_used = self.size - (self.available_size() + evictable_size)
        msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
should we show both swa and full token usage?
I define log_usage for the SWA case around line 216 in allocator.py.
Overall looks great! Left some comments, all small changes, thanks!
            available_token_size = self.token_to_kv_pool_allocator.full_available_size()
        else:
            available_token_size = self.token_to_kv_pool_allocator.available_size()
        available_size = available_token_size + self.tree_cache.evictable_size()
We should use `self.full_max_total_num_tokens` and `self.swa_max_total_num_tokens` here; I think you already have them. Each determines the max total tokens for the full-attention / SWA layers, respectively. Then compare `full_available_size + 0 == max_total_full_num_tokens` and `swa_available_size + 0 == max_total_swa_num_tokens`.
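A minimal sketch of the suggested check, assuming the attribute names mentioned in this thread (`full_max_total_num_tokens`, `swa_max_total_num_tokens`, `full_available_size`, `swa_available_size`); the actual check_memory code in the PR may differ:

```python
def check_hybrid_memory(scheduler) -> None:
    """Hedged sketch of the reviewer's suggestion: at idle time, every token
    should be accounted for on both the full-attention and SWA sides."""
    alloc = scheduler.token_to_kv_pool_allocator
    evictable = scheduler.tree_cache.evictable_size()  # 0 for SWAChunkCache
    assert (
        alloc.full_available_size() + evictable == scheduler.full_max_total_num_tokens
    ), "full-attention KV cache tokens leaked"
    assert (
        alloc.swa_available_size() + evictable == scheduler.swa_max_total_num_tokens
    ), "swa KV cache tokens leaked"
```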
@@ -113,7 +120,7 @@ def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache)
    def clear(self):
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_pages = torch.arange(
-           1, self.size + 1, dtype=torch.int64, device=self.device
+           1, self.size + 1, dtype=torch.int32, device=self.device
what's the reason behind this change?
I am not quite sure about this part. I will change it back to int64.
I made all KV indices `torch.int64`. Those indices are later converted to `torch.int32` when building the `page_table`, in order to support `flash_attn_with_kvcache`.
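Illustrative only (not the exact PR code): the pattern described above keeps allocator-side indices in int64 and casts at the attention-backend boundary, since the page/block table consumed by `flash_attn_with_kvcache` is expected to be int32.

```python
import torch

# Allocator side: int64 indices for safe arithmetic over large pools.
kv_indices = torch.arange(0, 1024, dtype=torch.int64)

# Attention-backend side: cast to int32 when materializing the page_table.
page_table = kv_indices.view(1, -1).to(torch.int32)
```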
            device=device,
        )
        self.clear()
        self._kvcache.register_mapping(weakref.proxy(self.full_to_swa_index_mapping))
Same question as Gemini: better to explain the reason for the weakref.
f"#token: global={used_full}, swa={used_swa}, " | ||
f"token usage: global={used_full / self.size_full:.2f}, " |
let's be consistent with naming, either `full` or `global`
I will double-check the naming. Thank you so much for pointing out this problem.
    def log_usage(self, evictable_size: int = 0):
        used_full = self.size_full - (self.full_available_size() + evictable_size)
        used_swa = self.size_swa - self.swa_available_size()
We should pass in both `swa_evictable_size` and `full_evictable_size`. For SWAChunkCache, these values are always 0, but the logic here is cleaner.
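A hedged sketch of the suggested interface (parameter names come from the review comment, the surrounding attribute names from the diff above; the final code may differ). SWAChunkCache would simply pass zeros for both evictable sizes.

```python
def log_hybrid_usage(
    allocator, swa_evictable_size: int = 0, full_evictable_size: int = 0
) -> str:
    """Report usage for the full-attention and SWA pools separately,
    each with its own evictable size."""
    used_full = allocator.size_full - (
        allocator.full_available_size() + full_evictable_size
    )
    used_swa = allocator.size_swa - (
        allocator.swa_available_size() + swa_evictable_size
    )
    return (
        f"#token: full={used_full}, swa={used_swa}, "
        f"token usage: full={used_full / allocator.size_full:.2f}, "
        f"swa={used_swa / allocator.size_swa:.2f}"
    )
```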
            * self.attention_chunk_size
            / self.model_config.context_len
        )
        self.local_max_total_num_tokens = (
consistent naming please, either `local` or `swa`
        self.local_max_total_num_tokens = (
            4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
        )
        self.max_total_num_tokens = (
please add `full` or `global` for consistent naming
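As an aside, a hedged reading of the `4 * temp_ratio // (3 * temp_ratio + 1)` factor in the hunk above (this derivation is mine, not stated in the PR): if roughly 1/4 of the layers use full attention, 3/4 use local attention, and the local cache is `temp_ratio` times the full cache, then keeping total KV memory equal to the uniform budget gives the split below.

```python
def split_kv_budget(max_total_num_tokens: int, temp_ratio: float) -> tuple[int, int]:
    """Hedged sketch: solve (1/4)*full + (3/4)*r*full = max_total_num_tokens
    for full, assuming 1/4 full-attention layers, 3/4 local layers, and
    local = r * full with r = temp_ratio."""
    full_size = int(4 * max_total_num_tokens // (3 * temp_ratio + 1))
    local_size = int(4 * max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1))
    return full_size, local_size
```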
@@ -852,6 +859,39 @@ def profile_max_num_token(self, total_gpu_memory: int):
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def get_num_token_hybrid(self):
Suggested change:
-    def get_num_token_hybrid(self):
+    def set_num_token_hybrid(self):
        if self.token_to_kv_pool_allocator is None:
            if self.page_size == 1:
                if self.is_hybrid is None:
better to do `if self.is_hybrid:`, easier for future additions
@@ -61,6 +61,7 @@ class ServerArgs:
    is_embedding: bool = False
    enable_multimodal: Optional[bool] = None
    revision: Optional[str] = None
    hybrid_kvcache_ratio: Optional[float] = None
didn't find usage of it
It is for getting the mix ratio from the server argument parser.
`hybrid_kvcache_ratio == 0`: pure uniform: `swa_size / full_size = 1`.
`hybrid_kvcache_ratio == 1.0`: pure hybrid: `swa_size / full_size = local_attention_size / context_length`.
It is used in model_config.py around line 280.
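A hedged sketch of how such an interpolation could look (the exact formula lives in model_config.py and may differ; the function name here is hypothetical):

```python
def swa_to_full_ratio(
    hybrid_kvcache_ratio: float, local_attention_size: int, context_length: int
) -> float:
    """Linearly blend between a uniform split (ratio 1) and the fully hybrid
    split (local_attention_size / context_length)."""
    p = hybrid_kvcache_ratio
    return (1 - p) * 1.0 + p * (local_attention_size / context_length)

# Example: local_attention_size=8192, context_length=1_048_576, p=0.5
# -> swa_size / full_size ≈ 0.504
```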
I missed it.
I have modified my code according to your comments. Thank you very much for your helpful suggestions.
@@ -63,3 +66,32 @@ def dec_lock_ref(self, node: Any):

    def pretty_print(self):
        return ""


class SWAChunkCache(ChunkCache):
what does SWA mean?
I tried to make my naming consistent with the ones in #7367. SWA means sliding window attention (I guess...).
Co-authored-by: Hanming Lu <[email protected]>
@@ -431,6 +436,136 @@ def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
        )


class SWAKVPool(KVCache):
Can we add a docstring for this to indicate its usage and meaning?
added
Looks great! Thanks for addressing all the comments.
@@ -29,6 +29,7 @@
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
We shouldn't import this on non-nv devices
I rewrote this part.
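One hedged way to address the reviewer's point about non-NV devices (the PR's actual rewrite may be structured differently) is to guard or defer the CUDA-only import:

```python
# Sketch only: avoid importing the CUDA-only FlashAttention backend at module
# import time on devices without CUDA.
import torch

if torch.cuda.is_available():
    from sglang.srt.layers.attention.flashattention_backend import (
        FlashAttentionBackend,
    )
```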
Motivation
LLaMA 4 uses local attention in 3/4 of its layers. To accommodate this, we divide the KV cache into two parts: a global cache and a local cache. Determining the optimal ratio between their sizes is nontrivial, so we introduce a tunable parameter `p`, where `0 ≤ p ≤ 1`.
- When `p = 1`, the ratio of global to local cache sizes is equal to `context_length / attention_chunk_size` (e.g., 8192).
- When `p = 0`, the two caches are of equal size.
- The ratio moves between these two extremes as `p` varies from 0 to 1. By default, we set `p = 0.5`.
Currently, we disable the radix tree, so prefix matching is not a concern.
During local attention, certain KV cache entries can be safely removed:
- During chunked prefill, tokens before position `attention_chunk_size * (prelen // attention_chunk_size)` are no longer needed and can be evicted.
- During decoding, tokens before position `attention_chunk_size * ((seqlen - 1) // attention_chunk_size)` are similarly unused and can be discarded.
Modifications
- Add a server argument `hybrid_kvcache_ratio` with default value `0.5`. This turns on the hybrid KV cache mode and controls the global-to-local cache size ratio.
- In model_config.py: add `is_hybrid_model()` to determine whether the current model configuration satisfies the conditions to enable hybrid KV caching.
- In model_runner.py:
  - `get_num_token_hybrid()` to get the sizes of the global and local KV caches.
  - `token_to_kv_pool_allocator_local` to allocate local cache indices.
- In memory_pool.py:
  - In `ReqToTokenPool`, add a new attribute `req_to_token_local` to store each request's local KV cache indices.
  - `MHATokenToKVPool._create_buffer` to create global and local cache buffers.
- In schedule_batch.py:
  - In `prepare_for_extend()` and `prepare_for_decode()`, allocate `out_cache_loc_local` for local attention KV indices, and store them in `token_to_token_pool.req_to_token_local`.
  - Call `self.tree_cache.evict_hybrid()` right before allocating new indices.
- In chunk_cache.py:
  - `evict_hybrid()` is defined to apply the new eviction rule in chunked prefill and decoding.
  - `cache_finished_req()` frees local indices once the requests are finished.
- In flashattention_backend.py:
  - Set `cache_loc = forward_batch.out_cache_loc_local` in both normal decode and extend forward.
  - The `page_table`s in metadata are modified correspondingly.
  - Some essential modifications for memory computations.
- Every time `token_to_kv_pool_allocator.available_size()` is used, change it to `min(token_to_kv_pool_allocator.available_size(), token_to_kv_pool_allocator_local.available_size())` (see the sketch after this list).
- Some essential changes to support CUDA graph.
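A hedged sketch of the available-size pattern referenced in the list above (the function name is hypothetical; the PR applies the same `min()` inline wherever the allocator size is queried):

```python
def hybrid_available_size(allocator, allocator_local=None) -> int:
    """With the hybrid KV cache, the usable token budget is limited by
    whichever pool (global or local) has fewer free slots."""
    if allocator_local is None:
        return allocator.available_size()
    return min(allocator.available_size(), allocator_local.available_size())
```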
Experiments
Loogle Evaluation on H100:
Enabling hybrid KV cache increases the throughput by ~10% w.r.t. the baseline.
Context Length Improvements with Hybrid KV Cache
On H100: enabling hybrid KV cache significantly increases the maximum context length from 1.3M to 5M tokens.
On H200: with hybrid KV cache enabled, the maximum context length for LLaMA-4 reaches 10M tokens, compared to the 3.5M-token baseline.
TODO
- `page_size` > 1

Checklist