[WIP] hybrid kv cache for LLaMA4 #6563
Conversation
@@ -692,6 +693,18 @@ def add_cli_args(parser: argparse.ArgumentParser):
            default=ServerArgs.page_size,
            help="The number of tokens in a page.",
        )
        parser.add_argument(
            "--enable-hybrid-kvcache",
should we name this `--hybrid-kvcache-ratio`? `enable` kinda means turning it on. If the user sets it to 0, it means disabling it, right?
> should we name this `--hybrid-kvcache-ratio`?

Yes, it is much better.

> If user sets it to 0, it means disable it right?

Yes.
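For illustration, here is a minimal sketch of how a ratio-style flag along the lines the reviewer suggests could be declared. The name `--hybrid-kvcache-ratio`, the help text, and the parse example are assumptions for this sketch, not code from this PR (which currently uses `--enable-hybrid-kvcache` with default `0.5`).

```python
# Hypothetical sketch only: a ratio-style flag as suggested in the review.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hybrid-kvcache-ratio",
    type=float,
    default=0.5,
    help="Tunable parameter p (0 <= p <= 1) controlling the global-to-local "
         "KV cache size split; 0 disables the hybrid KV cache.",
)

# argparse maps --hybrid-kvcache-ratio to the attribute hybrid_kvcache_ratio.
args = parser.parse_args(["--hybrid-kvcache-ratio", "0.25"])
print(args.hybrid_kvcache_ratio)  # 0.25
```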
Motivation
LLaMA 4 uses local attention in 3/4 of its layers. To accommodate this, we divide the KV cache into two parts: a global cache and a local cache. Determining the optimal ratio between their sizes is nontrivial, so we introduce a tunable parameter `p`, where `0 ≤ p ≤ 1`:

- When `p = 1`, the ratio of global to local cache size is equal to `context_length / attention_chunk_size` (e.g., 8192).
- When `p = 0`, the two caches are of equal size.
- The ratio interpolates between these two extremes as `p` varies from 0 to 1. By default, we set `p = 0.5`.

Currently, we disable the radix tree, so prefix matching is not a concern.
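As a rough illustration of the two endpoints above, the following hypothetical helper splits a token budget between the global and local caches. The linear interpolation in `p`, the function name `split_kv_budget`, and the idea of a single shared token budget are assumptions for this sketch, not the PR's `get_num_token_hybrid()` logic.

```python
# Hypothetical sketch of how the global/local split could be derived from p.
# Only the two endpoints (p = 0 and p = 1) come from the description above;
# the linear interpolation in between is an assumption.

def split_kv_budget(total_tokens: int, p: float,
                    context_length: int = 65536,
                    attention_chunk_size: int = 8192) -> tuple:
    """Return (num_global_tokens, num_local_tokens) for a given token budget.

    p = 0 -> global and local caches are the same size.
    p = 1 -> global:local ratio equals context_length / attention_chunk_size.
    """
    max_ratio = context_length / attention_chunk_size  # e.g. 65536 / 8192 = 8
    ratio = 1.0 + p * (max_ratio - 1.0)                # global / local
    num_local = int(total_tokens / (1.0 + ratio))
    num_global = total_tokens - num_local
    return num_global, num_local


print(split_kv_budget(900_000, p=0.0))  # equal halves
print(split_kv_budget(900_000, p=1.0))  # roughly 8:1 global to local
print(split_kv_budget(900_000, p=0.5))  # in between
```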
During local attention, certain KV cache entries can be safely removed:

- During chunked prefill, local-attention KV entries before position `attention_chunk_size * (prelen // attention_chunk_size)` are no longer needed and can be evicted.
- During decoding, entries before position `attention_chunk_size * ((seqlen - 1) // attention_chunk_size)` are similarly unused and can be discarded.
Modifications

- Add a server argument `enable_hybrid_kvcache` with default value `0.5`. This turns on the hybrid KV cache mode and controls the global-to-local cache size ratio.
- In `model_config.py`: add `is_hybrid_model()` to determine whether the current model configuration satisfies the conditions to enable hybrid KV caching.
- In `model_runner.py`:
  - Add `get_num_token_hybrid()` to compute the sizes of the global and local KV caches.
  - Add `token_to_kv_pool_allocator_local` to allocate local cache indices.
- In `memory_pool.py`:
  - In `ReqToTokenPool`, add a new attribute `req_to_token_local` to store each request's local KV cache indices.
  - Modify `MHATokenToKVPool._create_buffer` to create global and local cache buffers.
- In `schedule_batch.py`:
  - In `prepare_for_extend()` and `prepare_for_decode()`, allocate `out_cache_loc_local` for local-attention KV indices and store them in `req_to_token_pool.req_to_token_local`.
  - Call `self.tree_cache.evict_hybrid()` right before allocating new indices.
- In `chunk_cache.py`:
  - Define `evict_hybrid()` to apply the new eviction rule during chunked prefill and decoding.
  - Modify `cache_finished_req()` to free local indices once a request is finished.
- In `flashattention_backend.py`:
  - Use `cache_loc = forward_batch.out_cache_loc_local` in both the normal decode and extend forward paths.
  - Adjust the `page_table`s in the metadata correspondingly.
  - Make some essential modifications to the memory computations.
- Everywhere we previously used `token_to_kv_pool_allocator.available_size()`, change it to `min(token_to_kv_pool_allocator.available_size(), token_to_kv_pool_allocator_local.available_size())` (see the sketch after this list).
- Make some essential changes to support CUDA graph.
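As referenced in the list above, here is a small sketch of the capacity check with two allocators. The dummy allocator class exists only so the snippet runs on its own; the only thing assumed from the real `token_to_kv_pool_allocator` objects is an `available_size()` method, as described in the bullet.

```python
# Sketch of the capacity check: with separate global and local allocators,
# the schedulable budget is the minimum of the two, so neither pool is
# oversubscribed.

def hybrid_available_size(token_to_kv_pool_allocator,
                          token_to_kv_pool_allocator_local) -> int:
    """Minimum of the global and local allocators' free token counts."""
    return min(
        token_to_kv_pool_allocator.available_size(),
        token_to_kv_pool_allocator_local.available_size(),
    )


class _DummyAllocator:
    """Stand-in for a KV pool allocator exposing available_size()."""

    def __init__(self, free_tokens: int):
        self._free = free_tokens

    def available_size(self) -> int:
        return self._free


global_alloc = _DummyAllocator(10_000)
local_alloc = _DummyAllocator(4_000)
print(hybrid_available_size(global_alloc, local_alloc))  # 4000
```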
Experiments
python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --port 30002 --tp 8 --mem-fraction-static 0.8 --context-length 65536 --attention-backend fa3 --enable-hybrid-kvcache 0.5 --disable-radix-cache
lm_eval --model local-chat-completions --model_args model=meta-llama/Llama-4-Scout-17B-16E-Instruct,base_url=http://localhost:30002/v1/chat/completions,num_concurrent=128,timeout=999999,max_gen_toks=2048 --tasks mmlu_pro --apply_chat_template --num_fewshot 0
python3 bench_sglang.py --num-questions 1000 --port 30002
Accuracy: 0.922
Invalid: 0.000
Latency: 97.575 s
Output throughput: 1012.703 token/s
TODO
- `page_size` > 1

Checklist