
[Executorch][llm] Enable leveraging ring kv cache via module swap #10611

Open · wants to merge 2 commits into base: gh/kimishpatel/188/base
Changes from all commits
18 changes: 13 additions & 5 deletions examples/models/llama/attention.py
@@ -150,6 +150,16 @@ def forward(
return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


def _create_causal_mask_for_ring_buffer(
cache_positions, window_size, start_pos, seq_len
):
pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
delta = pos_q - cache_positions
attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712
return attn_mask


class CacheUpdateStrategy(Enum):
RING_BUFFER = "RingBuffer"
INVALID = "Invalid"
@@ -283,12 +293,10 @@ def __init__(
self.is_ring_buffer = True

def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
- pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
cache_positions = self.cache_positions_manager.cache_positions
- delta = pos_q - cache_positions
- attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
- attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
- return attn_mask
+ return _create_causal_mask_for_ring_buffer(
+ cache_positions, self.window_size, start_pos, seq_len
+ )

def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
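Editor's illustration (not part of the diff): to make the new _create_causal_mask_for_ring_buffer helper concrete, the sketch below hand-builds a cache_positions tensor for an 8-slot ring buffer with a window of 4, assuming, as the (cache_positions >= 0) check suggests, that empty slots carry a negative marker; in the real code these values come from CachePositionsManager.

import torch

from executorch.examples.models.llama.attention import (
    _create_causal_mask_for_ring_buffer,
)

# 8 ring-buffer slots (2x the window), window_size = 4.
# Slots 0-5 hold absolute token positions 0-5; slots 6-7 are still empty.
cache_positions = torch.tensor([0, 1, 2, 3, 4, 5, -1, -1], dtype=torch.long)

# One query token at absolute position 5: it may attend to positions 2..5
# (inside the window and not in the future); everything else is masked out.
mask = _create_causal_mask_for_ring_buffer(
    cache_positions, window_size=4, start_pos=5, seq_len=1
)
print(mask)  # expected: tensor([[-inf, -inf, 0., 0., 0., 0., -inf, -inf]])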
191 changes: 190 additions & 1 deletion examples/models/llama/source_transformation/custom_kv_cache.py
@@ -10,7 +10,12 @@

import torch
import torch.nn as nn
- from executorch.examples.models.llama.attention import KVCache
+ from executorch.examples.models.llama.attention import (
+ _create_causal_mask_for_ring_buffer,
+ CachePositionsManager,
+ KVCache,
+ RingKVCache,
+ )

from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401

@@ -75,6 +80,7 @@ def __init__(
self.register_buffer(
"v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
)
self.cache_type = cache_type

def _quantize(self, value):
(
@@ -181,6 +187,7 @@ def update(self, input_pos, k_val, v_val, indices=None):
However the storage is [B, S, H, D] so we incur transpose in, transpose out
This shall be removed by subsequent post-export graph pass
"""

k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)

@@ -346,3 +353,185 @@ def _replace_kv_cache_with_custom_kv_cache(module):
else:
_replace_kv_cache_with_custom_kv_cache(child)
return module


class QuantizedRingKVCache(QuantizedKVCache):
def __init__(
self,
max_batch_size,
max_context_length,
n_heads,
head_dim,
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
use_custom_update_cache_op: bool = False,
):
# Look at attention.py for an explanation of why max_context_length * 2 is used
super().__init__(
max_batch_size,
max_context_length * 2,
n_heads,
head_dim,
cache_type,
use_custom_update_cache_op,
)
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
self.is_ring_buffer = True
self.window_size = max_context_length

def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
cache_positions = self.cache_positions_manager.cache_positions
return _create_causal_mask_for_ring_buffer(
cache_positions, self.window_size, start_pos, seq_len
)

def update(self, input_pos, k_val, v_val):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
However the storage is [B, S, H, D] so we incur transpose in, transpose out
This shall be removed by subsequent post-export graph pass
"""
# Need to transpose for two reasons
# 1. kv cache is stored as [B, S, H, D]
# 2. If seq_len = k_val.size(2), we won't be able to optimize
# away the transpose at the output of the k, v projection
seq_len = k_val.transpose(1, 2).size(1)
assert seq_len <= self.k_cache.size(
1
), f"Update sequence length ({seq_len}) for kv cache must not exceed the cache size ({self.k_cache.size(1)})"
indices = self.cache_positions_manager.calculate_positions_and_update_indices(
input_pos, seq_len
)
indices = indices.unsqueeze(0)

return super().update(input_pos, k_val, v_val, indices)

@classmethod
def from_quantized_kv_cache(
cls,
kv_cache,
sliding_window_size,
):
assert isinstance(
kv_cache, QuantizedKVCache
), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache"
max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
return cls(
max_batch_size,
sliding_window_size,
n_heads,
head_dim,
kv_cache.cache_type,
kv_cache.use_custom_update_cache_op,
)
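Editor's note: both ring caches in this file rely on CachePositionsManager to map absolute token positions onto ring-buffer slots and to remember which position currently occupies each slot, which is exactly what the mask helper consumes. As a rough, hypothetical sketch only (not the actual implementation in attention.py), that bookkeeping can be pictured like this:

import torch


class ToyRingPositions:
    """Illustrative stand-in for a ring-buffer positions tracker (hypothetical)."""

    def __init__(self, cache_len: int):
        self.cache_len = cache_len
        # A negative value marks an empty slot, matching the (cache_positions >= 0)
        # check in _create_causal_mask_for_ring_buffer.
        self.cache_positions = torch.full((cache_len,), -1, dtype=torch.long)

    def calculate_positions_and_update_indices(self, input_pos, seq_len):
        start = int(input_pos.item())
        positions = start + torch.arange(seq_len, dtype=torch.long)
        indices = positions % self.cache_len       # ring slot for each new token
        self.cache_positions[indices] = positions  # remember who occupies each slot
        return indices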


class CustomRingKVCache(CustomKVCache):
def __init__(
self,
max_batch_size,
max_context_length,
n_heads,
head_dim,
dtype=torch.float32,
):
# Look at attention.py for an explanation of why max_context_length * 2 is used
super().__init__(
max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
)
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
self.is_ring_buffer = True
self.window_size = max_context_length

def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
cache_positions = self.cache_positions_manager.cache_positions
return _create_causal_mask_for_ring_buffer(
cache_positions, self.window_size, start_pos, seq_len
)

def update(self, input_pos, k_val, v_val):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
However the storage is [B, S, H, D] so we incur transpose in, transpose out
This shall be removed by subsequent post-export graph pass
"""
# Need to transpose for two reasons
# 1. kv cache is stored as [B, S, H, D]
# 2. If seq_len = k_val.size(2), we won't be able to optimize
# away the transpose at the output of the k, v projection
seq_len = k_val.transpose(1, 2).size(1)
assert seq_len <= self.k_cache.size(
1
), f"Update sequence length ({seq_len}) for kv cache must not exceed the cache size ({self.k_cache.size(1)})"
indices = self.cache_positions_manager.calculate_positions_and_update_indices(
input_pos, seq_len
)
indices = indices.unsqueeze(0)

return super().update(input_pos, k_val, v_val, indices)

@classmethod
def from_custom_kv_cache(
cls,
kv_cache,
sliding_window_size,
):
max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
if isinstance(kv_cache, CustomKVCache):
# If replacing custom kv cache, then the shape is [B, S, H, D]
max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
return cls(
max_batch_size,
sliding_window_size,
n_heads,
head_dim,
dtype=kv_cache.k_cache.dtype,
)
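Editor's illustration (not part of the diff): a minimal sketch of driving CustomRingKVCache directly, with hypothetical sizes. Importing executorch.extension.llm.custom_ops first mirrors what replace_kv_cache_with_ring_kv_cache below does to make sure the custom cache-update op is registered.

import torch

from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    CustomRingKVCache,
)

# Sliding window of 128; internally the cache allocates 2x that many slots.
cache = CustomRingKVCache(
    max_batch_size=1, max_context_length=128, n_heads=8, head_dim=64
)

k = torch.randn(1, 8, 1, 64)  # [B, H, S, D], a single new token
v = torch.randn(1, 8, 1, 64)
input_pos = torch.tensor([0], dtype=torch.long)

# Per the docstring above, the updated cache is handed back in [B, H, S, D] layout.
out = cache.update(input_pos, k, v)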


def _replace_kv_cache_with_ring_kv_cache(attention, layer_size):
sliding_window_size = layer_size
assert (
getattr(attention, "kv_cache", None) is not None
), "Attention module must have kv_cache module"
kv_cache = attention.kv_cache
if isinstance(kv_cache, KVCache):
attention.kv_cache = RingKVCache(
kv_cache.max_batch_size,
sliding_window_size,
kv_cache.n_heads,
kv_cache.head_dim,
kv_cache.enable_dynamic_shape,
kv_cache.k_cache.dtype,
)
elif isinstance(kv_cache, CustomKVCache):
attention.kv_cache = CustomRingKVCache.from_custom_kv_cache(
kv_cache, layer_size
)
elif isinstance(kv_cache, QuantizedKVCache):
attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache(
kv_cache, layer_size
)


def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
# This is needed to ensure that custom ops are registered
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401

logging.info(
"Replacing kv cache with ring kv cache. This modifies the model in place."
)
assert len(layer_sizes) == len(
module.layers
), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
for i, transformer_block in enumerate(module.layers):
sliding_window_size = layer_sizes[i]
if sliding_window_size == 0:
continue
assert (
getattr(transformer_block, "attention", None) is not None
), f"Transfomer block must have attention module. Transformer block {transformer_block}"
attention = transformer_block.attention
_replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
return module
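Editor's usage sketch (not part of the diff): replace_kv_cache_with_ring_kv_cache expects one entry per transformer block in layer_sizes, where 0 leaves that layer's cache untouched and a positive value swaps in a ring cache with that sliding-window size. The helper below, including its name and the alternating-layer pattern, is purely illustrative.

from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_ring_kv_cache,
)


def enable_ring_kv_cache(model, window: int = 512):
    """Hypothetical helper: give every other layer a sliding-window (ring) KV cache."""
    # One entry per layer; 0 keeps that layer's existing global-attention cache.
    layer_sizes = [window if i % 2 == 0 else 0 for i in range(len(model.layers))]
    # Swaps KVCache / CustomKVCache / QuantizedKVCache for their ring counterparts in place.
    return replace_kv_cache_with_ring_kv_cache(model, layer_sizes)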
25 changes: 25 additions & 0 deletions examples/models/llama/tests/TARGETS
@@ -55,8 +55,33 @@ python_unittest(
srcs = [
"test_ring_attention.py",
],
preload_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/kernels/quantized:aot_lib",
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama:export_library",
"//executorch/examples/models/llama:llama_transformer",
"//executorch/examples/models/llama:custom_kv_cache",
"//executorch/examples/models/llama:sdpa",
],
)

python_unittest(
name = "test_replace_kv_cache",
srcs = [
"test_replace_kv_cache.py",
],
preload_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/kernels/quantized:aot_lib",
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama:export_library",
"//executorch/examples/models/llama:llama_transformer",
"//executorch/examples/models/llama:custom_kv_cache",
"//executorch/examples/models/llama:sdpa",
],
)