
[WIP] Hybrid KV cache for LLaMA4 #6563

Open · wants to merge 10 commits into base: main
26 changes: 26 additions & 0 deletions python/sglang/srt/configs/model_config.py
@@ -53,6 +53,7 @@ def __init__(
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
) -> None:

self.model_path = model_path
@@ -78,6 +79,12 @@ def __init__(
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None
)
self.is_hybrid = is_hybrid_model(
self.hf_config.architectures,
hybrid_kvcache_ratio=hybrid_kvcache_ratio,
context_length=context_length,
attention_chunk_size=self.attention_chunk_size,
)

if enable_multimodal is None:
mm_disabled_models = [
@@ -256,6 +263,7 @@ def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
**kwargs,
)

@@ -617,3 +625,21 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0


def is_hybrid_model(
model_architectures: List[str],
hybrid_kvcache_ratio: Optional[float],
context_length: Optional[int],
attention_chunk_size: Optional[int],
):
if hybrid_kvcache_ratio is None:
return None
elif (
hybrid_kvcache_ratio > 0
and model_architectures[0] == "Llama4ForConditionalGeneration"
and context_length > attention_chunk_size
):
return hybrid_kvcache_ratio
else:
return None
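
A minimal usage sketch of the gate above, with made-up values: the ratio is only propagated for Llama4 when the context length exceeds the attention chunk size; every other case returns None and the hybrid KV cache stays disabled. The concrete numbers below are illustrative, not taken from this PR.

    ratio = is_hybrid_model(
        model_architectures=["Llama4ForConditionalGeneration"],
        hybrid_kvcache_ratio=0.5,   # value that would come from server_args.hybrid_kvcache_ratio
        context_length=1_000_000,   # illustrative Llama4 context length
        attention_chunk_size=8192,  # illustrative local-attention chunk size
    )
    # ratio == 0.5 here; a non-Llama4 architecture, a ratio of None or 0, or
    # context_length <= attention_chunk_size yields None instead.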
89 changes: 76 additions & 13 deletions python/sglang/srt/layers/attention/flashattention_backend.py
@@ -42,6 +42,8 @@ class FlashAttentionMetadata:
window_size: tuple = (-1, -1)
# Page table, the index of KV Cache Tables/Blocks
page_table: torch.Tensor = None
# Page table for local attention
page_table_local: torch.Tensor = None

# Encoder metadata
# Cumulative sequence lengths for encoder key
@@ -320,6 +322,7 @@ def __init__(
self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
self.skip_prefill = skip_prefill
self.is_hybrid = model_runner.is_hybrid
self.topk = model_runner.server_args.speculative_eagle_topk or 0
self.speculative_num_steps = speculative_num_steps
self.speculative_num_draft_tokens = (
@@ -333,6 +336,11 @@
if hasattr(model_runner, "attention_chunk_size")
else None
)
self.req_to_token_local = (
model_runner.req_to_token_pool.req_to_token_local
if self.is_hybrid is not None
else None
)

def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
@@ -427,6 +435,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
if self.is_hybrid is not None:
metadata.page_table_local = (
forward_batch.req_to_token_pool.req_to_token_local[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
)
# TODO: we need to test this part for llama 4 eagle case
self._init_local_attn_metadata(metadata, device)
elif forward_batch.forward_mode.is_target_verify():
@@ -562,6 +576,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
if self.is_hybrid is not None:
metadata.page_table_local = (
forward_batch.req_to_token_pool.req_to_token_local[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
)

if (
any(forward_batch.extend_prefix_lens_cpu)
@@ -627,14 +647,20 @@ def forward_extend(
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
use_hybrid_loc = self.is_hybrid is not None and (
hasattr(layer, "use_irope") and layer.use_irope
)
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not use_hybrid_loc:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
else:
cache_loc = forward_batch.out_cache_loc_local
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
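
As an aside, a condensed sketch of the cache-location rule added to forward_extend (and mirrored in forward_decode below); the standalone helper and its name are illustrative, not part of the PR. When the hybrid cache is active, only layers flagged with use_irope, i.e. Llama4's chunked local-attention layers, write into the local out_cache_loc_local slots; cross-attention and global layers keep the existing behavior.

    def _select_cache_loc(layer, forward_batch, is_hybrid):
        # Hypothetical helper mirroring the branch above.
        use_hybrid_loc = is_hybrid is not None and getattr(layer, "use_irope", False)
        if use_hybrid_loc:
            # Local-attention (iRoPE) layers use the smaller, chunk-sized pool.
            return forward_batch.out_cache_loc_local
        if layer.is_cross_attention:
            return forward_batch.encoder_out_cache_loc
        return forward_batch.out_cache_loc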
@@ -890,14 +916,20 @@ def forward_decode(
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
) -> torch.Tensor:
use_hybrid_loc = self.is_hybrid is not None and (
hasattr(layer, "use_irope") and layer.use_irope
)
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not use_hybrid_loc:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
else:
cache_loc = forward_batch.out_cache_loc_local
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
@@ -1147,6 +1179,12 @@ def init_cuda_graph_state(self, max_bs: int):
dtype=torch.int32,
device=self.device,
),
"page_table_local": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
dtype=torch.int32,
device=self.device,
),
"page_table_draft_decode": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
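
The width of the new page_table_local buffer follows the same ceiling-division rule as the existing tables; a quick numeric sketch with illustrative values:

    max_context_len = 131_072   # illustrative
    page_size = 16              # illustrative
    num_page_columns = (max_context_len + page_size - 1) // page_size
    # num_page_columns == 8192: one column per KV-cache page, enough to cover
    # max_context_len tokens.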
@@ -1453,6 +1491,10 @@ def init_forward_metadata_capture_cuda_graph(
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
req_pool_indices, :
]
if self.is_hybrid is not None:
metadata.page_table_local = self.decode_cuda_graph_metadata[
"page_table_local"
][req_pool_indices, :]
# Precompute cumulative sequence lengths
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
@@ -1680,6 +1722,18 @@ def init_forward_metadata_replay_cuda_graph(
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
metadata.page_table[:, max_seq_pages:].fill_(0)
if self.is_hybrid is not None:
page_indices_local = self.req_to_token_local[
req_pool_indices[:, None],
self.decode_cuda_graph_metadata["strided_indices"][
:max_seq_pages
][None, :],
]
page_indices_local //= self.page_size
metadata.page_table_local[:, :max_seq_pages].copy_(
page_indices_local
)
metadata.page_table_local[:, max_seq_pages:].fill_(0)

self._update_local_attn_metadata_for_replay(metadata, bs)
elif forward_mode.is_target_verify():
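
For context, a standalone sketch of how the local page indices above are derived; tensor contents are illustrative only. req_to_token_local maps request slots to token slots in the local KV pool, one slot is sampled every page_size tokens via strided_indices, and integer division turns those token slots into page ids.

    import torch

    page_size = 16
    # Illustrative req_to_token_local: 2 request slots, 64 token positions each.
    req_to_token_local = torch.arange(2 * 64, dtype=torch.int32).reshape(2, 64)
    strided_indices = torch.arange(0, 64, page_size)          # [0, 16, 32, 48]
    req_pool_indices = torch.tensor([0, 1])
    max_seq_pages = 3

    page_indices_local = req_to_token_local[
        req_pool_indices[:, None], strided_indices[:max_seq_pages][None, :]
    ]
    page_indices_local = page_indices_local // page_size      # token slot -> page id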
@@ -1841,7 +1895,10 @@ def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):

cu_seqlens_q = metadata.cu_seqlens_q
cache_seqlens_int32 = metadata.cache_seqlens_int32
page_table = metadata.page_table
if self.is_hybrid is not None:
page_table = metadata.page_table_local
else:
page_table = metadata.page_table
if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
metadata.local_attn_metadata = None
return
@@ -1881,7 +1938,10 @@ def _update_local_attn_metadata_for_capture(
"""
seq_lens_capture = metadata.cache_seqlens_int32
max_seq_len = int(seq_lens_capture.max().item())
page_table_capture = metadata.page_table
if self.is_hybrid is not None:
page_table_capture = metadata.page_table_local
else:
page_table_capture = metadata.page_table

cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
seqlens_np = seq_lens_capture.cpu().numpy()
@@ -1958,7 +2018,10 @@ def _update_local_attn_metadata_for_replay(
# Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
# beyond the actual sequence length, leading to incorrect attention calculations
max_seq_len = int(seqlens.max().item())
sliced_page_table = metadata.page_table[:bs, :max_seq_len]
if self.is_hybrid is not None:
sliced_page_table = metadata.page_table_local[:bs, :max_seq_len]
else:
sliced_page_table = metadata.page_table[:bs, :max_seq_len]

cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
seqlens_np = seqlens.cpu().numpy()
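
A small sketch of the replay-time slicing above, with illustrative shapes: only the first max_seq_len columns of the batch's local page table are valid, so the local-attention metadata is rebuilt from that slice rather than from the full pre-allocated width.

    import torch

    bs, table_width = 4, 8192                      # illustrative
    page_table_local = torch.zeros(bs, table_width, dtype=torch.int32)
    seqlens = torch.tensor([7, 12, 3, 9], dtype=torch.int32)
    max_seq_len = int(seqlens.max().item())        # 12
    sliced_page_table = page_table_local[:bs, :max_seq_len]   # shape (4, 12)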