Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm_kunlun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_MODULE_MAPPINGS = {
"vllm.compilation.wrapper": "vllm_kunlun.compilation.wrapper",
"vllm.v1.worker.utils": "vllm_kunlun.v1.worker.utils",
"vllm.v1.worker.mamba_utils": "vllm_kunlun.v1.worker.mamba_utils",
"vllm.model_executor.model_loader.bitsandbytes_loader": "vllm_kunlun.models.model_loader.bitsandbytes_loader",
"vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
"vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler",
Expand Down
245 changes: 245 additions & 0 deletions vllm_kunlun/v1/worker/mamba_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import itertools
from collections.abc import Callable
from typing import Any

import torch
from vllm.config import CacheConfig
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
from xspeedgate_ops import ops as xspeedgate_ops


def batch_memcpy(src_ptrs, dst_ptrs, sizes):
    """Run a batch of memory copies through the xspeedgate ``batch_memcpy`` op.

    The three arguments are forwarded unchanged and in this order. Presumably
    entry ``i`` of each describes one copy (source address, destination
    address, size) -- the exact contract is defined by ``xspeedgate_ops``.
    """
    xspeedgate_ops.batch_memcpy(src_ptrs, dst_ptrs, sizes)


def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]:
    """Locate the KV cache groups whose spec is a ``MambaSpec``.

    Returns:
        ``(group_ids, spec)`` where ``group_ids`` are the positions in
        ``kv_cache_config.kv_cache_groups`` that hold a ``MambaSpec`` and
        ``spec`` is the spec of the first such group.

    Raises:
        AssertionError: if no group holds a ``MambaSpec``, or if the mamba
            groups' specs do not all compare equal to the first one.
    """
    mamba_group_ids: list[int] = []
    mamba_specs: list[MambaSpec] = []
    for group_id, group in enumerate(kv_cache_config.kv_cache_groups):
        spec = group.kv_cache_spec
        if not isinstance(spec, MambaSpec):
            continue
        mamba_group_ids.append(group_id)
        mamba_specs.append(spec)

    assert len(mamba_group_ids) > 0, "no mamba layers in the model"
    # Only one spec is returned, so every mamba group must share it.
    assert all(mamba_specs[0] == spec for spec in mamba_specs)
    return mamba_group_ids, mamba_specs[0]


@dataclasses.dataclass
class MambaCopyBuffers:
    """Staging area for one batched mamba state copy.

    Copy descriptors are written into the three buffers at index ``offset``;
    ``offset`` is therefore the number of descriptors currently queued.
    """

    # Source address of each queued copy (allocated as int64 by ``create``).
    src_ptrs: CpuGpuBuffer
    # Destination address of each queued copy (allocated as int64).
    dst_ptrs: CpuGpuBuffer
    # Size of each queued copy (allocated as int32).
    sizes: CpuGpuBuffer
    # Indices of the mamba groups in ``kv_cache_config.kv_cache_groups``.
    mamba_group_ids: list[int]
    # The spec shared by those groups.
    mamba_spec: MambaSpec
    # Number of descriptors queued so far.
    offset: int = 0

    @classmethod
    def create(
        cls,
        max_num_reqs: int,
        kv_cache_config: KVCacheConfig,
        copy_funcs: tuple[MambaStateCopyFunc, ...],
        make_buffer: Callable[..., CpuGpuBuffer],
    ) -> "MambaCopyBuffers":
        """Allocate buffers big enough for every request in a full batch.

        Capacity is ``max_num_reqs`` times the per-request descriptor count,
        which is (total layers across all mamba groups) x ``len(copy_funcs)``.
        ``make_buffer`` is called as ``make_buffer(capacity, dtype=...)``.
        """
        group_ids, spec = get_mamba_groups(kv_cache_config)

        groups = kv_cache_config.kv_cache_groups
        num_mamba_layers = sum(len(groups[gid].layer_names) for gid in group_ids)
        descriptors_per_req = num_mamba_layers * len(copy_funcs)
        capacity = max_num_reqs * descriptors_per_req

        # Allocate in the same order as the fields: src, dst, sizes.
        src_buffer = make_buffer(capacity, dtype=torch.int64)
        dst_buffer = make_buffer(capacity, dtype=torch.int64)
        size_buffer = make_buffer(capacity, dtype=torch.int32)
        return cls(
            src_ptrs=src_buffer,
            dst_ptrs=dst_buffer,
            sizes=size_buffer,
            mamba_group_ids=group_ids,
            mamba_spec=spec,
        )


def collect_mamba_copy_meta(
    copy_bufs: MambaCopyBuffers,
    kv_cache_config: KVCacheConfig,
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    mamba_group_ids: list[int],
    src_block_idx: int,
    dest_block_idx: int,
    accept_token_bias: int,
    req_state: CachedRequestState,
    forward_context: dict[str, Any],
) -> None:
    """Queue the copy descriptors that move one request's mamba state.

    For every mamba group, every layer in that group, and every state tensor
    of that layer (paired positionally with ``mamba_state_copy_funcs``), one
    descriptor is appended to ``copy_bufs``:

    * source address: ``start_addr`` of the spec returned by the copy func,
      called as ``func(state, block_ids, src_block_idx, accept_token_bias + 1)``
    * destination address: start of ``state[block_ids[dest_block_idx]]``
    * size: the spec's ``num_elements`` times ``state.element_size()``

    Nothing is queued when source and destination block index are equal and
    ``accept_token_bias`` is 0. ``copy_bufs.offset`` is advanced past the new
    descriptors only after all of them have been written.
    """
    same_block = src_block_idx == dest_block_idx
    if same_block and accept_token_bias == 0:
        return

    src_addrs = copy_bufs.src_ptrs.np
    dst_addrs = copy_bufs.dst_ptrs.np
    byte_counts = copy_bufs.sizes.np
    cursor = copy_bufs.offset
    # Loop-invariant fourth argument of every copy func call.
    copy_func_arg = accept_token_bias + 1

    for group_id in mamba_group_ids:
        req_block_ids = req_state.block_ids[group_id]
        target_block_id = req_block_ids[dest_block_idx]
        group = kv_cache_config.kv_cache_groups[group_id]
        for layer_name in group.layer_names:
            layer = forward_context[layer_name]
            layer_states: list[torch.Tensor] = layer.kv_cache
            for state, copy_func in zip(layer_states, mamba_state_copy_funcs):
                spec = copy_func(state, req_block_ids, src_block_idx, copy_func_arg)

                src_addrs[cursor] = spec.start_addr
                dst_addrs[cursor] = state[target_block_id].data_ptr()
                byte_counts[cursor] = spec.num_elements * state.element_size()
                cursor += 1

    copy_bufs.offset = cursor


def do_mamba_copy_block(copy_bufs: MambaCopyBuffers):
    """Execute every copy queued in ``copy_bufs`` as one ``batch_memcpy``.

    The first ``copy_bufs.offset`` entries of each buffer are passed through
    ``copy_to_gpu`` and handed to ``batch_memcpy``. Does nothing when no
    descriptor is queued. ``copy_bufs.offset`` is not reset here.
    """
    num_queued = copy_bufs.offset
    if num_queued == 0:
        return

    src_view = copy_bufs.src_ptrs.copy_to_gpu(num_queued)
    dst_view = copy_bufs.dst_ptrs.copy_to_gpu(num_queued)
    size_view = copy_bufs.sizes.copy_to_gpu(num_queued)
    batch_memcpy(src_view, dst_view, size_view)


def preprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    mamba_state_idx: dict[str, int],
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    copy_bufs: MambaCopyBuffers,
):
    """
    Move each request's mamba state from the block used in the previous step
    to the block that holds the running state in this step, i.e. the last
    block before the ``num_speculative_blocks`` speculative ones.

    ``mamba_state_idx`` is updated in place with the new running-state block
    index of every request in the batch, and all required copies are issued
    as a single batched copy at the end.
    """
    mamba_group_ids = copy_bufs.mamba_group_ids
    mamba_spec = copy_bufs.mamba_spec
    num_speculative_blocks = mamba_spec.num_speculative_blocks
    # TODO(Chen): we need to optimize this function a lot
    assert cache_config.enable_prefix_caching
    block_size = mamba_spec.block_size

    # Forget the remembered block index of finished, preempted and resumed
    # requests. Resumed requests are included because a force-preempted
    # request (e.g. during reset_prefix_cache / KV cache flush) shows up in
    # resumed_req_ids without ever appearing in preempted_req_ids; its stale
    # index could point past the new, smaller block allocation.
    stale_req_ids = itertools.chain(
        scheduler_output.finished_req_ids,
        scheduler_output.preempted_req_ids or set(),
        scheduler_output.scheduled_cached_reqs.resumed_req_ids,
    )
    for stale_req_id in stale_req_ids:
        mamba_state_idx.pop(stale_req_id, None)

    copy_bufs.offset = 0
    for batch_idx, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]

        prev_state_idx = mamba_state_idx.get(req_id)
        if prev_state_idx is None:
            # New or resumed request: derive the index from the tokens already
            # computed. With 0 computed tokens this yields -1 (no state yet).
            prev_state_idx = (req_state.num_computed_tokens - 1) // block_size

        tokens_after_step = (
            req_state.num_computed_tokens
            + scheduler_output.num_scheduled_tokens[req_id]
        )
        num_blocks: int = cdiv(tokens_after_step, block_size) + num_speculative_blocks

        # The running state always lives in the last non-speculative block.
        # Corner case: block_size = 4, num_speculative_tokens = 2, request
        # [A, B, C] plus draft tokens [draft 1, draft 2] gives
        #   Block 0: [A, B, C, draft 1]
        #   Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
        #   Block 2: speculative block
        #   Block 3: speculative block
        # and block 1 is the one used for the running state.
        curr_state_idx = num_blocks - 1 - num_speculative_blocks
        mamba_state_idx[req_id] = curr_state_idx

        if prev_state_idx == -1 or prev_state_idx == curr_state_idx:
            # Either there is no previous state or it is already in place.
            continue

        collect_mamba_copy_meta(
            copy_bufs,
            kv_cache_config,
            mamba_state_copy_funcs,
            mamba_group_ids,
            prev_state_idx,
            curr_state_idx,
            input_batch.num_accepted_tokens_cpu[batch_idx] - 1,
            req_state,
            forward_context,
        )
        # The accepted-token count was passed to the copy above; reset it.
        input_batch.num_accepted_tokens_cpu[batch_idx] = 1

    do_mamba_copy_block(copy_bufs)


def postprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    mamba_state_idx: dict[str, int],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    copy_bufs: MambaCopyBuffers,
):
    """
    If a block went from partially filled to full in this step, copy the state
    from the running-state block into that newly completed block.

    All required copies are issued as a single batched copy at the end.
    """
    scheduled_by_req = scheduler_output.num_scheduled_tokens
    spec_tokens_by_req = scheduler_output.scheduled_spec_decode_tokens
    accepted_tokens = input_batch.num_accepted_tokens_cpu
    mamba_group_ids = copy_bufs.mamba_group_ids
    block_size = copy_bufs.mamba_spec.block_size

    copy_bufs.offset = 0
    for batch_idx, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        num_computed = req_state.num_computed_tokens
        num_draft = len(spec_tokens_by_req.get(req_id, []))
        num_scheduled = scheduled_by_req[req_id]
        num_accepted = accepted_tokens[batch_idx]

        # Tokens covered by the running state: everything scheduled except
        # the draft tokens.
        running_state_tokens = num_computed + num_scheduled - num_draft
        new_num_computed = running_state_tokens + num_accepted - 1
        # Round down to a whole number of blocks.
        aligned_tokens = new_num_computed // block_size * block_size

        # TODO: how to ensure all blocks that cache_blocks called are cached here?
        if aligned_tokens < running_state_tokens:
            # No block boundary reached at or after the running state.
            continue

        src_block_idx = mamba_state_idx[req_id]
        dest_block_idx = aligned_tokens // block_size - 1
        collect_mamba_copy_meta(
            copy_bufs,
            kv_cache_config,
            mamba_state_copy_funcs,
            mamba_group_ids,
            src_block_idx,
            dest_block_idx,
            aligned_tokens - running_state_tokens,
            req_state,
            forward_context,
        )
        if src_block_idx == dest_block_idx:
            accepted_tokens[batch_idx] = 1

    do_mamba_copy_block(copy_bufs)
Loading