
Hybrid kv cache for LLaMA4 #6563


Merged
merged 45 commits into sgl-project:main on Jun 28, 2025

Conversation

Contributor
@tarinkk commented May 24, 2025

Motivation

LLaMA 4 uses local (chunked) attention in 3/4 of its layers. To exploit this, we split the KV cache into two parts: a global cache for the full-attention layers and a local cache for the local-attention layers. Determining the optimal ratio between their sizes is nontrivial, so we introduce a tunable parameter p, where 0 ≤ p ≤ 1.

  • When p = 1, the ratio of the global to the local cache size equals context_length / attention_chunk_size (attention_chunk_size is 8192 for LLaMA 4 Scout).
  • When p = 0, the two caches are the same size.
  • The ratio interpolates linearly as p varies from 0 to 1 (see the sketch below). By default, we set p = 0.5.
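
As a rough, hedged illustration of the interpolation above (this is not the exact formula used in model_runner.py; the function name global_to_local_ratio and its arguments are made up for exposition):

def global_to_local_ratio(p: float, context_length: int, attention_chunk_size: int) -> float:
    # Illustrative only: linearly interpolate the global-to-local cache size ratio.
    #   p = 0 -> the two caches are the same size (ratio 1)
    #   p = 1 -> ratio = context_length / attention_chunk_size
    full_ratio = context_length / attention_chunk_size
    return (1.0 - p) * 1.0 + p * full_ratio

# Example: context_length = 100_000, attention_chunk_size = 8192, p = 0.5
# -> ratio ≈ 0.5 * 1 + 0.5 * 12.2 ≈ 6.6, i.e. the global cache gets roughly
#    6.6x as many token slots as the local cache.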

Currently, we disable the radix tree, so prefix matching is not a concern.

During local attention, certain KV cache entries can be safely evicted (see the sketch after this list):

  • In chunked prefill: entries at positions below attention_chunk_size * (prelen // attention_chunk_size) are no longer needed and can be evicted.
  • In decoding: entries at positions below attention_chunk_size * ((seqlen - 1) // attention_chunk_size) are similarly unused and can be discarded.
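
A minimal sketch of the eviction boundary described above (illustrative only; local_evict_boundary, num_tokens, prelen, and seqlen are names chosen here for exposition, not the PR's code):

def local_evict_boundary(num_tokens: int, attention_chunk_size: int, decoding: bool) -> int:
    # Return the position below which local-attention KV entries fall outside
    # the current attention chunk and can therefore be evicted.
    # num_tokens is prelen during chunked prefill and seqlen during decoding.
    n = num_tokens - 1 if decoding else num_tokens
    return attention_chunk_size * (n // attention_chunk_size)

# With attention_chunk_size = 8192:
#   chunked prefill, prelen = 20000 -> boundary 16384: positions [0, 16384) evictable
#   decoding,        seqlen = 8193  -> boundary 8192:  positions [0, 8192)  evictable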

Modifications

  1. Add a server argument hybrid_kvcache_ratio with default value 0.5. This enables the hybrid KV cache mode and controls the global-to-local cache size ratio.

  2. In model_config.py: add is_hybrid_model() to determine whether the current model configuration satisfies the conditions for enabling hybrid KV caching.

  3. In model_runner.py:

    • Implement get_num_token_hybrid() to compute the sizes of the global and local KV caches
    • Initialize token_to_kv_pool_allocator_local to allocate local cache indices
  4. In memory_pool.py:

    • In ReqToTokenPool, add a new attribute req_to_token_local to store each request's local KV cache indices
    • Modify MHATokenToKVPool._create_buffer to create both the global and local cache buffers.
  5. In schedule_batch.py:

    • In prepare_for_extend() and prepare_for_decode(), allocate out_cache_loc_local for local-attention KV indices and store them in req_to_token_pool.req_to_token_local.
    • Apply the new eviction rule via self.tree_cache.evict_hybrid() right before allocating new indices.
  6. In chunk_cache.py:

    • evict_hybrid() is defined to apply the new eviction rule in chunked prefill and decoding.
    • Modify cache_finished_req() to free local indices once requests finish.
  7. In flashattention_backend.py:

    • When the hybrid cache is enabled, set cache_loc = forward_batch.out_cache_loc_local in both normal decode and extend forward.
    • The page_tables in the metadata are updated correspondingly.
  8. Essential modifications to memory computations (see the helper sketch after this list):

    Everywhere we use token_to_kv_pool_allocator.available_size(), replace it with

    min(token_to_kv_pool_allocator.available_size(), token_to_kv_pool_allocator_local.available_size())

  9. Some essential changes to support CUDA graphs.
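
For item 8, a minimal sketch of the intended guard (the helper name hybrid_available_size is hypothetical; the PR applies the min() inline rather than through a helper):

def hybrid_available_size(full_allocator, local_allocator=None) -> int:
    # With the hybrid cache, the schedulable token budget is bounded by
    # whichever pool (global or local) has fewer free slots.
    if local_allocator is None:
        return full_allocator.available_size()
    return min(full_allocator.available_size(), local_allocator.available_size())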

Experiments

LooGLE evaluation on H100:

Enabling the hybrid KV cache increases throughput by ~10% relative to the baseline.

  • With hybrid KV cache (total time ~694 s, throughput ~222 tokens/s):
python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --port 30002 --tp 8 --mem-fraction-static 0.8 --context-length 100000 --attention-backend fa3 --disable-radix-cache --hybrid-kvcache-ratio 0.95
  • Baseline (total time ~746 s, throughput ~204 tokens/s):
python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --port 30002 --tp 8 --mem-fraction-static 0.8 --context-length 100000 --attention-backend fa3 --disable-radix-cache

Context Length Improvements with Hybrid KV Cache

On H100:
Enabling the hybrid KV cache significantly increases the maximum context length, from 1.3M to 5M tokens:

  • With hybrid KV cache (5M context length):
python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --port 30002 --tp 8 --context-length 5000000 --attention-backend fa3 --disable-radix-cache --hybrid-kvcache-ratio 1 --cuda-graph-max-bs 16 --max-running-requests 16
  • Baseline (1.3M context length):
python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --port 30002 --tp 8 --context-length 1300000 --attention-backend fa3 --cuda-graph-max-bs 16 --max-running-requests 16

On H200:
With the hybrid KV cache enabled, the maximum context length for LLaMA 4 reaches 10M tokens, compared to the 3.5M-token baseline:

  • With hybrid KV cache (10M context length):
python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --port 30002 --tp 8 --context-length 10000000 --attention-backend fa3 --disable-radix-cache --hybrid-kvcache-ratio 1 --cuda-graph-max-bs 32 --max-running-requests 32
  • Baseline (3.5M context length):
python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --port 30002 --tp 8 --context-length 3500000 --attention-backend fa3 --cuda-graph-max-bs 32 --max-running-requests 32

TODO

  1. Enable support when page_size > 1
  2. Apply the eviction rule when the radix tree is enabled
  3. ...

@tarinkk changed the title [WIP]hybrid kv cache for LlaMa4 [WIP]hybrid kv cache for LlaMA4 on May 24, 2025
@tarinkk changed the title [WIP]hybrid kv cache for LlaMA4 [WIP]hybrid kv cache for LLaMA4 on May 24, 2025
@tarinkk force-pushed the llama4hybridCache branch from 5e66e89 to d053027 on May 25, 2025 00:29
@tarinkk marked this pull request as ready for review on May 25, 2025 00:29
@tarinkk force-pushed the llama4hybridCache branch from c9fa7ea to d1203cb on May 25, 2025 01:50
@@ -624,6 +626,9 @@ def forward_extend(
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
use_hybrid_loc = self.is_hybrid is not None and (
Collaborator

didn't find the usage of use_hybrid_loc

@@ -887,6 +892,9 @@ def forward_decode(
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
) -> torch.Tensor:
use_hybrid_loc = self.is_hybrid is not None and (
Collaborator

similar here

@@ -523,6 +526,8 @@ def __init__(
# Prefix info
# The indices to kv cache for the shared prefix.
self.prefix_indices: torch.Tensor = []
# The indices to local kv cache for the shared prefix.
self.prefix_indices_local: torch.Tensor = []
Collaborator

didn't find usage of prefix_indices_local

@@ -55,6 +57,11 @@ def __init__(
def debug_print(self) -> str:
return ""

def log_usage(self, evictable_size: int = 0):
num_used = self.size - (self.available_size() + evictable_size)
msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
Collaborator
@hanming-lu Jun 27, 2025

should we show both swa and full token usage?

Contributor Author

I define log_usage for SWA case around line 216 in allocator.py.

Collaborator
@hanming-lu left a comment

Overall looks great! Left some comments, all small changes, thanks!

available_token_size = self.token_to_kv_pool_allocator.full_available_size()
else:
available_token_size = self.token_to_kv_pool_allocator.available_size()
available_size = available_token_size + self.tree_cache.evictable_size()
Collaborator
@hanming-lu Jun 27, 2025

We should use self.full_max_total_num_tokens and self.swa_max_total_num_tokens here; I think you already have them. Each determines the max total per full-attention / SWA layer, respectively. Then compare full_available_size + 0 == max_total_full_num_tokens and swa_available_size + 0 == max_total_swa_num_tokens.

@@ -113,7 +120,7 @@ def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache)
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_pages = torch.arange(
1, self.size + 1, dtype=torch.int64, device=self.device
1, self.size + 1, dtype=torch.int32, device=self.device
Collaborator

what's the reason behind this change?

Contributor Author

I am not quite sure about this part. I will change it back to int64.

Contributor Author

I made all KV indices torch.int64; they are later converted to torch.int32 when building the page_table, in order to support flash_attn_with_kvcache.

device=device,
)
self.clear()
self._kvcache.register_mapping(weakref.proxy(self.full_to_swa_index_mapping))
Collaborator
@hanming-lu Jun 27, 2025

same question as gemini, better to explain the reason for weakref

Comment on lines 220 to 221
f"#token: global={used_full}, swa={used_swa}, "
f"token usage: global={used_full / self.size_full:.2f}, "
Collaborator

let's be consistent with naming, either full or global

Contributor Author

I will double-check the naming. Thank you so much for pointing out this problem.


def log_usage(self, evictable_size: int = 0):
used_full = self.size_full - (self.full_available_size() + evictable_size)
used_swa = self.size_swa - self.swa_available_size()
Collaborator

should pass in both swa_evictable_size and full_evictable_size. For SWAChunkCache, this value is always 0, but the logic here is cleaner.

* self.attention_chunk_size
/ self.model_config.context_len
)
self.local_max_total_num_tokens = (
Collaborator

consistent naming please, either local or swa

self.local_max_total_num_tokens = (
4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
)
self.max_total_num_tokens = (
Collaborator

please add full or global for consistent naming

@@ -852,6 +859,39 @@ def profile_max_num_token(self, total_gpu_memory: int):
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token

def get_num_token_hybrid(self):
Collaborator

Suggested change
def get_num_token_hybrid(self):
def set_num_token_hybrid(self):


if self.token_to_kv_pool_allocator is None:
if self.page_size == 1:
if self.is_hybrid is None:
Collaborator

better to do if self.is_hybrid:, easier for future additions

@@ -61,6 +61,7 @@ class ServerArgs:
is_embedding: bool = False
enable_multimodal: Optional[bool] = None
revision: Optional[str] = None
hybrid_kvcache_ratio: Optional[float] = None
Collaborator

didn't find usage of it

Contributor Author

It is used to read the mix ratio from the server argument parser.
hybrid_kvcache_ratio == 0: pure uniform: swa_size / full_size = 1.
hybrid_kvcache_ratio == 1.0: pure hybrid: swa_size / full_size = local_attention_size / context_length.
It is used in model_config.py around line 280.

Collaborator

I missed it.

Contributor Author

I have modified my code according to your comments. Thank you very much for your helpful suggestions.

@@ -63,3 +66,32 @@ def dec_lock_ref(self, node: Any):

def pretty_print(self):
return ""


class SWAChunkCache(ChunkCache):
Collaborator

what does SWA mean?

Contributor Author

I tried to keep my naming consistent with the names in #7367.
SWA means sliding window attention (I guess...).

@@ -431,6 +436,136 @@ def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
)


class SWAKVPool(KVCache):
Collaborator
@CatherineSue Jun 27, 2025

Can we add a docstring for this to indicate its usage and meaning?

Contributor Author

added

Collaborator
@hanming-lu left a comment

Looks great! Thanks for addressing all the comments.

@@ -29,6 +29,7 @@
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
Member

We shouldn't import this on non-nv devices

Contributor Author

I rewrote this part.

@zhyncs merged commit eb6c2c1 into sgl-project:main on Jun 28, 2025
102 of 132 checks passed