B073 #64

29 changes: 23 additions & 6 deletions docker/Dockerfile
@@ -66,31 +66,48 @@ ENV MAX_JOBS=${max_jobs}
ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads

RUN rm -rf /root/.cache/pip/*

RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-build.txt

WORKDIR /workspace

RUN git clone -b v0.1.3 https://github.com/bijouvj/LMCache.git
RUN git clone https://github.com/LMCache/torchac_cuda

WORKDIR /workspace/LMCache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/pip \
python3 setup.py bdist_wheel --dist-dir=dist_lmc_kvikio

WORKDIR /workspace/torchac_cuda
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/pip \
python3 setup.py bdist_wheel --dist-dir=dist_torchac_cuda
python3 setup.py bdist_wheel --dist-dir=/workspace/LMCache/dist_lmc_kvikio

WORKDIR /workspace

#################### vLLM installation IMAGE ####################
# Install torchac_cuda wheel into the vLLM image
FROM vllm/vllm-openai:v0.6.2 AS vllm-openai
RUN --mount=type=bind,from=build,src=/workspace/torchac_cuda/dist_torchac_cuda,target=/vllm-workspace/dist_torchac_cuda \
RUN --mount=type=bind,from=build,src=/workspace/LMCache/dist_lmc_kvikio,target=/vllm-workspace/dist_lmc_kvikio \
--mount=type=cache,target=/root/.cache/pip \
pip install dist_torchac_cuda/*.whl --verbose
pip install dist_lmc_kvikio/*.whl --verbose

#################### LMCache test SERVER ####################
# LMCache server setup using the vllm-install stage as base
FROM vllm-openai AS vllm-lmcache

ARG LMCACHE_VERSION=0.1.3
RUN pip install lmcache lmcache_vllm
WORKDIR /workspace
RUN git clone -b v0.6.2.2 https://github.com/LMCache/lmcache-vllm.git
WORKDIR /workspace/lmcache-vllm
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/pip \
python3 setup.py bdist_wheel --dist-dir=dist_lmcache_vllm
RUN pip install dist_lmcache_vllm/*.whl --verbose

RUN pip install kvikio-cu12
RUN python3 -m pip install --upgrade setuptools

ENTRYPOINT ["lmcache_vllm", "serve"]
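
To build and run the resulting server image, something along these lines should work; this is a sketch, assuming the repository root as the build context (the stage name vllm-lmcache and the max_jobs/nvcc_threads build args come from the Dockerfile above, while the image tag and port mapping are arbitrary):

docker build -f docker/Dockerfile --target vllm-lmcache \
    --build-arg max_jobs=8 --build-arg nvcc_threads=8 \
    -t lmcache-vllm-openai .
docker run --gpus all -p 8000:8000 lmcache-vllm-openai <serve-args>

The container's entrypoint is lmcache_vllm serve, so <serve-args> stands for whatever arguments that CLI accepts (typically the model to load); consult lmcache_vllm's documentation for the exact flags.
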
14 changes: 12 additions & 2 deletions lmcache_vllm/__init__.py
@@ -10,7 +10,7 @@
from lmcache.logging import init_logger
logger = init_logger(__name__)

EXPECTED_VLLM_VERSIONS = ["0.6.1.dev238+ge2c6e0a82"]
EXPECTED_VLLM_VERSIONS = ["0.6.1.dev238+ge2c6e0a82","0.7.3+04de634a.nv25.03"]
__version__ = "0.6.2.3"


@@ -23,6 +23,11 @@ def check_library_version(library_name, required_versions):
if lib.__version__ in required_versions:
return True
else:
# In case version starts with one of required versions but has extra suffix
for req_ver in required_versions:
if lib.__version__.startswith(req_ver):
logger.info(f"vLLM version {lib.__version__} matches required version {req_ver}")
return True
logger.error(f"Version mismatch: {lib.__version__} found, {required_versions} required.")
return False
else:
@@ -35,7 +40,12 @@ def check_library_version(library_name, required_versions):
def initialize_environment():
# Check vllm and it's version
logger.info(f"Initializing lmcache_vllm version {__version__}, supporting vllm versions: {EXPECTED_VLLM_VERSIONS}")
assert check_library_version("vllm", EXPECTED_VLLM_VERSIONS), f"vllm {EXPECTED_VLLM_VERSIONS} not found"

# Check if vLLM is installed and compatible
vllm_check = check_library_version("vllm", EXPECTED_VLLM_VERSIONS)
if not vllm_check:
logger.warning(f"vLLM version not in {EXPECTED_VLLM_VERSIONS}. LMCache may not work correctly.")

is_experimental = os.getenv("LMCACHE_USE_EXPERIMENTAL")
if is_experimental == 'True':
InitLMCacheExperimentalEnvironment()
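
The relaxed version check above accepts an installed vLLM whose version string either matches an entry of EXPECTED_VLLM_VERSIONS exactly or extends one with a local build suffix, and it now only warns instead of asserting. A minimal standalone sketch of that matching logic (the helper name is illustrative, not part of lmcache_vllm's API):

EXPECTED_VLLM_VERSIONS = ["0.6.1.dev238+ge2c6e0a82", "0.7.3+04de634a.nv25.03"]

def version_is_compatible(installed: str, required: list[str]) -> bool:
    # Exact match first, then fall back to prefix matching so that builds
    # carrying an extra local suffix are still accepted.
    if installed in required:
        return True
    return any(installed.startswith(req) for req in required)

print(version_is_compatible("0.7.3+04de634a.nv25.03", EXPECTED_VLLM_VERSIONS))            # True (exact)
print(version_is_compatible("0.6.1.dev238+ge2c6e0a82.d20240901", EXPECTED_VLLM_VERSIONS)) # True (prefix)
print(version_is_compatible("0.6.2", EXPECTED_VLLM_VERSIONS))                             # False -> warning only
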
166 changes: 95 additions & 71 deletions lmcache_vllm/attention/flash_attn.py
@@ -5,62 +5,81 @@
import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import AttentionType, AttentionLayer
from vllm.attention.backends.flash_attn import FlashAttentionImpl, FlashAttentionMetadata
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
from vllm._custom_ops import reshape_and_cache


def flash_attn_forward_for_cacheblend(
self,
impl_self: "FlashAttentionImpl",
layer: "AttentionLayer",
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: "FlashAttentionMetadata",
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.

Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
output: shape = [num_tokens, num_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
NOTE: It in-place updates the output tensor.
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")

# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
# Query handling
if query.ndim == 3:
num_tokens = query.shape[0]
hidden_size = impl_self.num_heads * impl_self.head_size
elif query.ndim == 2:
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, impl_self.num_heads, impl_self.head_size)
else:
raise ValueError(f"Unexpected query dimension: {query.ndim}")

# Key handling
if key.ndim != 3:
if key.ndim == 2:
key = key.view(num_tokens, impl_self.num_kv_heads, impl_self.head_size)
else:
raise ValueError(f"Unexpected key dimension: {key.ndim}")

# Value handling
if value.ndim != 3:
if value.ndim == 2:
value = value.view(num_tokens, impl_self.num_kv_heads, impl_self.head_size)
else:
raise ValueError(f"Unexpected value dimension: {value.ndim}")

# KV cache handling
key_cache = None
value_cache = None
if kv_cache.numel() > 0: # Only process if not empty tensor
key_cache = kv_cache[0]
value_cache = kv_cache[1]

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
torch.ops.vllm.reshape_and_cache_flash(
# Set k_scale and v_scale to 1.0 as tensors
k_scale_tensor = torch.ones((), device=key_cache.device, dtype=key_cache.dtype)
v_scale_tensor = torch.ones((), device=value_cache.device, dtype=value_cache.dtype)

# Call reshape_and_cache with the correctly shaped tensors
reshape_and_cache(
key,
value,
kv_cache,
key_cache.unsqueeze(0),
value_cache.unsqueeze(0),
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
k_scale,
v_scale,
impl_self.kv_cache_dtype,
k_scale_tensor,
v_scale_tensor,
)

num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -77,49 +96,48 @@ def flash_attn_forward_for_cacheblend(
prefill_meta = attn_metadata
assert prefill_meta is not None

if (kv_cache is None or prefill_meta.block_tables is None
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
prefill_output = torch.ops.vllm.flash_attn_varlen_func(
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.query_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
softmax_scale=impl_self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
window_size=impl_self.sliding_window,
alibi_slopes=impl_self.alibi_slopes,
softcap=impl_self.logits_soft_cap,
)
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa
prefill_output = flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
softmax_scale=impl_self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
alibi_slopes=impl_self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
softcap=impl_self.logits_soft_cap,
)

assert prefill_output is not None
return prefill_output.view(num_prefill_tokens, hidden_size)
# End of injection


assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens

@@ -138,57 +156,59 @@ def flash_attn_forward_for_cacheblend(

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache is None or prefill_meta.block_tables is None
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
prefill_output = torch.ops.vllm.flash_attn_varlen_func(
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_q=prefill_meta.query_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
softmax_scale=impl_self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
window_size=impl_self.sliding_window,
alibi_slopes=impl_self.alibi_slopes,
softcap=impl_self.logits_soft_cap,
)
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa
prefill_output = flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
softmax_scale=impl_self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
alibi_slopes=impl_self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
softcap=impl_self.logits_soft_cap,
)

if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
decode_output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
).squeeze(1)
# Only do decoding if we have a valid cache
if kv_cache.numel() > 0:
# Decoding run.
decode_output = flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=impl_self.scale,
causal=True,
alibi_slopes=impl_self.alibi_slopes,
softcap=impl_self.logits_soft_cap,
).squeeze(1)

if prefill_output is None:
assert decode_output is not None
@@ -201,4 +221,8 @@ def flash_attn_forward_for_cacheblend(

def inject_flash_attn():
import vllm.attention.backends.flash_attn
vllm.attention.backends.flash_attn.FlashAttentionImpl.forward = flash_attn_forward_for_cacheblend
# Ensure the original forward exists before patching
if hasattr(vllm.attention.backends.flash_attn.FlashAttentionImpl, 'forward'):
vllm.attention.backends.flash_attn.FlashAttentionImpl.forward = flash_attn_forward_for_cacheblend
else:
print("Warning: vllm.attention.backends.flash_attn.FlashAttentionImpl.forward not found for patching.")
21 changes: 15 additions & 6 deletions lmcache_vllm/experimental/vllm_adapter.py
@@ -20,11 +20,8 @@
from vllm.utils import get_kv_cache_torch_dtype

from lmcache.logging import init_logger
from lmcache.experimental.cache_engine import LMCacheEngine, LMCacheEngineBuilder
from lmcache.experimental.gpu_connector import VLLMPagedMemGPUConnector
from lmcache.experimental.config import LMCacheEngineConfig
from lmcache.config import LMCacheEngineMetadata

from lmcache.cache_engine import LMCacheEngine, LMCacheEngineBuilder
from lmcache.config import LMCacheEngineConfig, LMCacheEngineMetadata
from lmcache.utils import _lmcache_nvtx_annotate
from lmcache_vllm.lmcache_utils import ENGINE_NAME, lmcache_get_config
from lmcache_vllm.blend_adapter import remove_request_id_indices
@@ -734,7 +731,19 @@ def build_partial_prefill_input(
sampling_metadata=rebuilt_sampling_metadata,
is_prompt=model_input.is_prompt,
async_callback=model_input.async_callback,
seq_group_metadata_list=seq_group_metadata_list
)

return rebuilt_model_input

# Define a simple VLLMPagedMemGPUConnector class to replace the missing one
class VLLMPagedMemGPUConnector:
"""Simple connector for vLLM paged memory for GPU.
This is a placeholder implementation to support lmcache.
"""
def __init__(self, hidden_dim_size, num_layer):
self.hidden_dim_size = hidden_dim_size
self.num_layer = num_layer
logger.info(f"Initialized VLLMPagedMemGPUConnector with hidden_dim_size={hidden_dim_size}, num_layer={num_layer}")

def __str__(self):
return f"VLLMPagedMemGPUConnector(hidden_dim_size={self.hidden_dim_size}, num_layer={self.num_layer})"