
# [WIP]hybrid kv cache for LLaMA4 #6563


Open · tarinkk wants to merge 6 commits into main from llama4hybridCache
Conversation

@tarinkk (Contributor) commented May 24, 2025

## Motivation

LLaMA 4 uses local (chunked) attention in 3/4 of its layers. To accommodate this, we divide the KV cache into two parts: a global cache and a local cache. Determining the optimal ratio between their sizes is nontrivial, so we introduce a tunable parameter p, where 0 ≤ p ≤ 1.

  • When p = 1, the global-to-local cache size ratio equals context_length / attention_chunk_size (attention_chunk_size is 8192 for LLaMA 4).
  • When p = 0, the two caches are the same size.
  • The ratio transitions linearly as p varies from 0 to 1 (see the sketch below). By default, we set p = 0.5.
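
As a rough sketch of this interpolation (the helper name below is hypothetical, not code from this PR):

```python
def hybrid_cache_ratio(p: float, context_length: int, attention_chunk_size: int) -> float:
    # Linearly interpolate the global-to-local size ratio between
    # 1 (equal sizes, p = 0) and context_length / attention_chunk_size (p = 1).
    max_ratio = context_length / attention_chunk_size
    return 1.0 + p * (max_ratio - 1.0)
```

For example, with context_length = 65536 and attention_chunk_size = 8192 (the settings used in the experiments below), p = 0.5 gives a global-to-local ratio of 4.5.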

Currently, we disable the radix tree, so prefix matching is not a concern.

During local attention, certain KV cache entries can be safely removed:

  • In chunked prefill: entries at positions before attention_chunk_size * (prelen // attention_chunk_size) are no longer needed and can be evicted.
  • In decoding: entries at positions before attention_chunk_size * ((seqlen - 1) // attention_chunk_size) are similarly unused and can be discarded (see the sketch below).
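
A minimal sketch of the boundary computation shared by both rules (the function name is hypothetical; in this PR the rule lives inside evict_hybrid()):

```python
def local_evict_boundary(num_tokens: int, attention_chunk_size: int, decoding: bool) -> int:
    # Start position of the local-attention chunk currently in use.
    # Entries at positions before this index can be freed from the local cache.
    # num_tokens is prelen during chunked prefill and seqlen during decoding.
    anchor = num_tokens - 1 if decoding else num_tokens
    return attention_chunk_size * (anchor // attention_chunk_size)
```

For instance, with attention_chunk_size = 8192 and seqlen = 16385 during decoding, the boundary is 16384, so the first two full chunks can be evicted from the local cache.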

## Modifications

  1. Add a server argument enable_hybrid_kvcache with default value 0.5. This turns on hybrid KV cache mode and controls the global-to-local cache size ratio.

  2. In model_config.py: add is_hybrid_model() to determine whether the current model configuration satisfies the conditions for enabling hybrid KV caching.

  3. In model_runner.py:

    • Implement get_num_token_hybrid() to compute the sizes of the global and local KV caches (see the sketch after this list).
    • Initialize token_to_kv_pool_allocator_local to allocate local cache indices.
  4. In memory_pool.py:

    • In ReqToTokenPool, add a new attribute req_to_token_local to store each request's local KV cache indices.
    • Modify MHATokenToKVPool._create_buffer to create both global and local cache buffers.
  5. In schedule_batch.py:

    • In prepare_for_extend() and prepare_for_decode(), allocate out_cache_loc_local for local-attention KV indices and store them in req_to_token_pool.req_to_token_local.
    • Apply the new eviction rule via self.tree_cache.evict_hybrid() right before allocating new indices.
  6. In chunk_cache.py:

    • evict_hybrid() applies the new eviction rule during chunked prefill and decoding.
    • Modify cache_finished_req() to free local indices once a request finishes.
  7. In flashattention_backend.py:

    • When the hybrid cache is enabled, set cache_loc = forward_batch.out_cache_loc_local in both normal decode and extend forward passes.
    • The page_tables in the metadata are modified correspondingly.
  8. Some essential modifications for memory computations:

    Every time we use token_to_kv_pool_allocator.available_size(), change it to

    min(token_to_kv_pool_allocator.available_size(), token_to_kv_pool_allocator_local.available_size())

  9. Some essential changes to support CUDA graph
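
To illustrate item 3, here is one plausible way a fixed token budget could be split given the ratio; the helper is an assumption for illustration, and the actual accounting in get_num_token_hybrid() (e.g., handling the 1/4 global vs. 3/4 local layer split) may differ:

```python
def split_kv_budget(total_tokens: int, ratio: float) -> tuple[int, int]:
    # Split the budget so that num_global : num_local == ratio
    # while the combined number of slots stays fixed.
    num_global = int(total_tokens * ratio / (ratio + 1.0))
    num_local = total_tokens - num_global
    return num_global, num_local
```

With ratio = 4.5 and a budget of 1,000,000 tokens, this yields 818,181 global and 181,819 local slots.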

## Experiments

```bash
python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --port 30002 --tp 8 --mem-fraction-static 0.8 --context-length 65536 --attention-backend fa3 --enable-hybrid-kvcache 0.5 --disable-radix-cache
```

```bash
lm_eval --model local-chat-completions --model_args model=meta-llama/Llama-4-Scout-17B-16E-Instruct,base_url=http://localhost:30002/v1/chat/completions,num_concurrent=128,timeout=999999,max_gen_toks=2048 --tasks mmlu_pro --apply_chat_template --num_fewshot 0
```

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| mmlu_pro | 2 | custom-extract | | exact_match | 0.7470 | ± 0.0039 |
| biology | 1 | custom-extract | 0 | exact_match | 0.8745 | ± 0.0124 |
| business | 1 | custom-extract | 0 | exact_match | 0.7719 | ± 0.0149 |
| chemistry | 1 | custom-extract | 0 | exact_match | 0.8021 | ± 0.0118 |
| computer_science | 1 | custom-extract | 0 | exact_match | 0.7683 | ± 0.0209 |
| economics | 1 | custom-extract | 0 | exact_match | 0.8318 | ± 0.0129 |
| engineering | 1 | custom-extract | 0 | exact_match | 0.6533 | ± 0.0153 |
| health | 1 | custom-extract | 0 | exact_match | 0.7372 | ± 0.0154 |
| history | 1 | custom-extract | 0 | exact_match | 0.6273 | ± 0.0248 |
| law | 1 | custom-extract | 0 | exact_match | 0.5204 | ± 0.0151 |
| math | 1 | custom-extract | 0 | exact_match | 0.8379 | ± 0.0100 |
| other | 1 | custom-extract | 0 | exact_match | 0.7186 | ± 0.0148 |
| philosophy | 1 | custom-extract | 0 | exact_match | 0.6493 | ± 0.0214 |
| physics | 1 | custom-extract | 0 | exact_match | 0.7983 | ± 0.0111 |
| psychology | 1 | custom-extract | 0 | exact_match | 0.7794 | ± 0.0147 |

| Groups | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| mmlu_pro | 2 | custom-extract | | exact_match | 0.747 | ± 0.0039 |

```bash
python3 bench_sglang.py --num-questions 1000 --port 30002
```

```
Accuracy: 0.922
Invalid: 0.000
Latency: 97.575 s
Output throughput: 1012.703 token/s
```

## TODO

  1. Enable support for page_size > 1
  2. Apply the eviction rule when the radix tree is enabled
  3. Run long-context-length tests
  4. ...

@tarinkk tarinkk changed the title from "[WIP]hybrid kv cache for LlaMa4" to "[WIP]hybrid kv cache for LlaMA4" on May 24, 2025
@tarinkk tarinkk changed the title from "[WIP]hybrid kv cache for LlaMA4" to "[WIP]hybrid kv cache for LLaMA4" on May 24, 2025
@tarinkk tarinkk force-pushed the llama4hybridCache branch from 5e66e89 to d053027 on May 25, 2025 00:29
@tarinkk tarinkk marked this pull request as ready for review May 25, 2025 00:29
@tarinkk tarinkk force-pushed the llama4hybridCache branch from c9fa7ea to d1203cb on May 25, 2025 01:50
```diff
@@ -692,6 +693,18 @@ def add_cli_args(parser: argparse.ArgumentParser):
         default=ServerArgs.page_size,
         help="The number of tokens in a page.",
     )
     parser.add_argument(
         "--enable-hybrid-kvcache",
```
A collaborator commented:
should we name this --hybrid-kvcache-ratio? enable kinda means turning it on. If user sets it to 0, it means disable it right?

@tarinkk (Contributor, Author) replied:

> should we name this --hybrid-kvcache-ratio?

Yes, it is much better.

> If user sets it to 0, it means disable it right?

Yes.
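
For context, a plausible completion of the truncated diff above, assuming the flag matches the PR description (a float ratio defaulting to 0.5); the type, default wiring, and help text are assumptions rather than the PR's exact code:

```python
parser.add_argument(
    "--enable-hybrid-kvcache",
    type=float,
    default=ServerArgs.enable_hybrid_kvcache,  # 0.5 per the PR description
    help="Enable the hybrid KV cache and set the global-to-local size ratio p in [0, 1].",
)
```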
