Skip to content

Commit ab065ff

Browse files
authored
[BugFix]Mamba pooling&mtp (#10565)
### What this PR does / why we need it?

- Fixed a bug where the KV cache usage statistics could display a value less than zero.
- Fixed a bug where the hybrid attention pointers registered to the backend were not aligned to 2 MB.

- vLLM version: v0.23.0
- vLLM main: vllm-project/vllm@967c5c3

---------

Signed-off-by: Qingsong Zhang <1640410765@qq.com>
1 parent 709bd7b commit ab065ff

4 files changed

Lines changed: 158 additions & 17 deletions

File tree

tests/ut/distributed/ascend_store/test_config_data.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,74 @@ def test_update_invalid_type(self):
283283
with self.assertRaises(ValueError):
284284
tracker.update("invalid") # type: ignore[arg-type]
285285

286+
def test_update_mamba_with_tuple(self):
    """update() with a per-group tuple appends each group's ids as-is.

    No ``mamba_group_ids`` are passed to the tracker here, so none of the
    previously tracked ids are expected to be replaced with 0.
    """
    tracker = RequestTracker(
        req_id="r1",
        token_len=16,
        allocated_block_ids_by_group=[[1], [2], [3], [4]],
        block_sizes=[16] * 4,
    )

    newly_allocated = ([5, 6], [0, 7], [0, 8], [0, 9])
    tracker.update(newly_allocated)

    # One expected list per KV cache group, in group order.
    expected_by_group = [
        [1, 5, 6],
        [2, 0, 7],
        [3, 0, 8],
        [4, 0, 9],
    ]
    for group_id, expected in enumerate(expected_by_group):
        self.assertEqual(tracker.allocated_block_ids_by_group[group_id], expected)
295+
296+
def test_update_mamba_mtp_with_tuple_chunk2(self):
    """Mamba + speculative blocks, two-block step: one reused block is nulled.

    Each mamba group receives two new ids, so only the first of the three
    trailing speculative slots is expected to be overwritten with 0 before
    the new ids are appended.
    """
    previously_tracked = [
        [1, 2],
        [0, 3, 4, 5, 6],
        [0, 7, 8, 9, 10],
        [0, 11, 12, 13, 14],
    ]
    tracker = RequestTracker(
        req_id="r1",
        token_len=32,
        allocated_block_ids_by_group=previously_tracked,
        mamba_group_ids=[1, 2, 3],
        num_speculative_blocks=3,
        block_sizes=[16] * 4,
    )

    step_block_ids = ([15, 16], [4, 17], [8, 18], [12, 19])
    tracker.update(step_block_ids, 32)

    # Group 0 is not a mamba group and is simply extended.
    expected_by_group = [
        [1, 2, 15, 16],
        [0, 3, 0, 5, 6, 4, 17],
        [0, 7, 0, 9, 10, 8, 18],
        [0, 11, 0, 13, 14, 12, 19],
    ]
    for group_id, expected in enumerate(expected_by_group):
        self.assertEqual(tracker.allocated_block_ids_by_group[group_id], expected)
316+
317+
def test_update_mamba_mtp_with_tuple_chunk8(self):
    """Mamba + speculative blocks, eight-block step: all spec slots are nulled.

    Each mamba group receives eight new ids, so all three trailing
    speculative slots are expected to be overwritten with 0 before the new
    ids are appended.
    """
    leading_nulls = [0] * 7
    tracker = RequestTracker(
        req_id="r1",
        token_len=128,
        allocated_block_ids_by_group=[
            [1, 2, 3, 4, 5, 6, 7, 8],
            leading_nulls + [9, 10, 11, 12],
            leading_nulls + [13, 14, 15, 16],
            leading_nulls + [17, 18, 19, 20],
        ],
        mamba_group_ids=[1, 2, 3],
        num_speculative_blocks=3,
        block_sizes=[16] * 4,
    )

    step_block_ids = (
        [21, 22, 23, 24, 25, 26, 27, 28],
        [0, 0, 0, 0, 10, 11, 12, 29],
        [0, 0, 0, 0, 14, 15, 16, 30],
        [0, 0, 0, 0, 18, 19, 20, 31],
    )
    tracker.update(step_block_ids, 128)

    # Group 0 is not a mamba group and is simply extended.
    expected_by_group = [
        [1, 2, 3, 4, 5, 6, 7, 8, 21, 22, 23, 24, 25, 26, 27, 28],
        [0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 10, 11, 12, 29],
        [0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 14, 15, 16, 30],
        [0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 18, 19, 20, 31],
    ]
    for group_id, expected in enumerate(expected_by_group):
        self.assertEqual(tracker.allocated_block_ids_by_group[group_id], expected)
353+
286354

287355
class TestReqMeta(unittest.TestCase):
288356
def test_from_request_tracker_basic_save(self):

vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,13 @@ class RequestTracker:
542542

543543
last_block_key: str | None = None
544544

545+
mamba_group_ids: list[int] | None = None
546+
547+
# spec blocks for mamba cache group
548+
num_speculative_blocks: int = 0
549+
550+
block_sizes: list[int] | None = None
551+
545552
def __init__(
546553
self,
547554
req_id: str,
@@ -558,9 +565,14 @@ def __init__(
558565
ends: list[int] | None = None,
559566
sizes_per_chunk: list[list[int]] | None = None,
560567
last_block_key: str | None = None,
568+
mamba_group_ids: list[int] | None = None,
569+
num_speculative_blocks: int = 0,
570+
block_sizes: list[int] | None = None,
561571
) -> None:
562572
self.req_id = req_id
563573
self.token_len = token_len
574+
self.mamba_group_ids = mamba_group_ids
575+
self.num_speculative_blocks = num_speculative_blocks
564576
block_ids = allocated_block_ids_by_group
565577
if block_ids is None:
566578
block_ids = normalize_block_ids_by_group(allocated_block_ids or [])
@@ -575,6 +587,7 @@ def __init__(
575587
self.ends = ends
576588
self.sizes_per_chunk = sizes_per_chunk
577589
self.last_block_key = last_block_key
590+
self.block_sizes = block_sizes
578591

579592
@property
580593
def allocated_block_ids(self) -> list[int]:
@@ -601,6 +614,7 @@ def from_new_request(
601614
def update(
602615
self,
603616
new_block_ids: tuple[list[int], ...] | list[int],
617+
num_computed_tokens: int = 0,
604618
) -> None:
605619
"""Update the request tracker when a running request is scheduled again."""
606620
normalized = normalize_block_ids_by_group(new_block_ids)
@@ -609,8 +623,37 @@ def update(
609623
[[] for _ in range(len(normalized) - len(self.allocated_block_ids_by_group))]
610624
)
611625
for group_id, ids in enumerate(normalized):
626+
self.update_mamba_spec_blocks(ids, group_id, num_computed_tokens)
612627
self.allocated_block_ids_by_group[group_id].extend(ids)
613628

629+
def update_mamba_spec_blocks(self, block_ids: list[int], kv_cache_group_id: int, num_computed_tokens: int):
    """Null out already-tracked block ids of a mamba group before new ids are appended.

    For mamba align groups, each scheduling step will:

    - first, remove some previous blocks and append the necessary null blocks;
    - second, move the speculative blocks (all or some of them) to the last
      position for reuse;
    - finally, allocate a new block.

    So when a speculative block is moved to the last position and replaced
    with a null block, the id tracked for it earlier must also become 0.

    The tracked list is only ever overwritten in place; its length never
    changes here (the caller appends ``block_ids`` afterwards).

    Args:
        block_ids: Block ids reported for this group in the current step.
        kv_cache_group_id: Index of the KV cache group being updated.
        num_computed_tokens: Computed-token count reported for the request.

    Raises:
        AssertionError: If ``block_sizes`` is missing or has no entry for
            ``kv_cache_group_id``.
    """
    # Groups that are not mamba groups are left untouched.
    if not self.mamba_group_ids or kv_cache_group_id not in self.mamba_group_ids:
        return
    assert self.block_sizes is not None and len(self.block_sizes) > kv_cache_group_id

    group_block_ids = self.allocated_block_ids_by_group[kv_cache_group_id]
    tracked_count = len(group_block_ids)
    spec_blocks = self.num_speculative_blocks

    # Leading blocks that are already skipped are tracked as null (0).
    # NOTE(review): the speculative count is subtracted from a token count
    # here, exactly as in the original formula -- confirm the units upstream.
    num_skipped_blocks = max(num_computed_tokens - spec_blocks - 1, 0) // self.block_sizes[kv_cache_group_id]
    num_skipped_blocks = min(tracked_count, num_skipped_blocks)
    for index in range(num_skipped_blocks):
        group_block_ids[index] = 0

    if not block_ids or spec_blocks <= 0:
        return

    # The last entry of block_ids is excluded from the count; at most
    # `spec_blocks` slots are nulled.
    mask_spec_count = min(len(block_ids) - 1, spec_blocks)

    # Null the first `mask_spec_count` slots of the trailing speculative
    # window. Indices are clamped to the list bounds and written one by one:
    # a slice assignment with a fixed-size replacement would insert extra
    # zeros (growing the list) whenever fewer than `spec_blocks` ids are
    # tracked, because Python clamps the slice but not the replacement.
    window_start = tracked_count - spec_blocks
    first = max(window_start, 0)
    stop = max(window_start + mask_spec_count, 0)
    for index in range(first, stop):
        group_block_ids[index] = 0
656+
614657

615658
@dataclass(init=False)
616659
class ReqMeta:

vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def __init__(
9797
self.dcp_size = getattr(vllm_config.parallel_config, "decode_context_parallel_size", 1)
9898

9999
self.mamba_group_ids = self._infer_mamba_groups()
100+
self.num_speculative_blocks = (
101+
vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0
102+
)
100103
self.original_block_size = self._infer_group_block_sizes(vllm_config, kv_cache_config)
101104
cp_scale = self.pcp_size * self.dcp_size
102105
self.grouped_block_size = [block_size * cp_scale for block_size in self.original_block_size]
@@ -704,6 +707,9 @@ def _process_new_request(
704707
block_keys=(previous_tracker.block_keys.copy() if previous_tracker else []),
705708
block_gvas=(previous_tracker.block_gvas.copy() if previous_tracker else []),
706709
gva_block_offset=(previous_tracker.gva_block_offset if previous_tracker else 0),
710+
mamba_group_ids=self.mamba_group_ids,
711+
num_speculative_blocks=self.num_speculative_blocks,
712+
block_sizes=self.grouped_block_size,
707713
)
708714
self._request_trackers[request.req_id] = request_tracker
709715
num_blocks = num_tokens_to_compute // self._block_size
@@ -751,6 +757,9 @@ def _process_preempted_cached_request(
751757
block_keys=(previous_tracker.block_keys.copy() if previous_tracker else []),
752758
block_gvas=(previous_tracker.block_gvas.copy() if previous_tracker else []),
753759
gva_block_offset=(previous_tracker.gva_block_offset if previous_tracker else 0),
760+
mamba_group_ids=self.mamba_group_ids,
761+
num_speculative_blocks=self.num_speculative_blocks,
762+
block_sizes=self.grouped_block_size,
754763
)
755764
self._request_trackers[req_id] = request_tracker
756765
num_blocks = len(new_block_ids_by_group[0])
@@ -785,15 +794,15 @@ def _process_running_cached_request(
785794
raise ValueError(f"Request {req_id} is not in _request_trackers, but it is scheduled to be cached")
786795
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
787796
req_tuple = self._unfinished_requests.get(req_id)
788-
if req_tuple:
789-
request = req_tuple[0]
790-
num_current_tokens = request_tracker.token_len
791-
new_token_ids = request.all_token_ids[num_current_tokens : num_current_tokens + num_new_tokens]
792-
if request_tracker.token_ids is not None and new_token_ids:
793-
request_tracker.token_ids.extend(new_token_ids)
794-
request_tracker.token_len += num_new_tokens
795-
else:
797+
if not req_tuple:
796798
raise ValueError(f"Request {req_id} is not in _unfinished_requests, but it is scheduled to be cached")
799+
request = req_tuple[0]
800+
num_current_tokens = request_tracker.token_len
801+
new_token_ids = request.all_token_ids[num_current_tokens : num_current_tokens + num_new_tokens]
802+
if request_tracker.token_ids is not None and new_token_ids:
803+
request_tracker.token_ids.extend(new_token_ids)
804+
request_tracker.token_len += num_new_tokens
805+
797806
prev_token_count = request_tracker.token_len - num_new_tokens
798807
prev_hash_count = prev_token_count // self._block_size
799808
current_hash_count = request_tracker.token_len // self._block_size
@@ -813,7 +822,7 @@ def _process_running_cached_request(
813822
has_last_block=True,
814823
)
815824
if new_block_ids is not None:
816-
request_tracker.update(new_block_ids)
825+
request_tracker.update(new_block_ids, request.num_computed_tokens)
817826
load_spec = None
818827
return self._build_req_meta(
819828
request_tracker,
@@ -846,6 +855,9 @@ def _process_async_load_request(
846855
block_keys=(previous_tracker.block_keys.copy() if previous_tracker else []),
847856
block_gvas=(previous_tracker.block_gvas.copy() if previous_tracker else []),
848857
gva_block_offset=(previous_tracker.gva_block_offset if previous_tracker else 0),
858+
mamba_group_ids=self.mamba_group_ids,
859+
num_speculative_blocks=self.num_speculative_blocks,
860+
block_sizes=self.grouped_block_size,
849861
)
850862
self._request_trackers[request_id] = request_tracker
851863
num_blocks = num_tokens_to_compute // self._block_size
@@ -973,9 +985,9 @@ def update_connector_output(self, connector_output: KVConnectorOutput):
973985
hand the connector_output, free non-null mamba blocks and so on.
974986
"""
975987
meta = connector_output.kv_connector_worker_meta
976-
if not isinstance(meta, AscendStoreKVConnectorWorkerMetadata):
988+
if not isinstance(meta, AscendStoreKVConnectorWorkerMetadata) or self._block_pool is None:
977989
return
978-
to_free_block_ids: list[int] = []
990+
979991
for event_id, count in meta.completed_events.items():
980992
logger.debug("event %s update with %s", event_id, count)
981993
total = self.sending_events.get(event_id, -1)
@@ -984,16 +996,14 @@ def update_connector_output(self, connector_output: KVConnectorOutput):
984996
continue
985997
total = total + count
986998
if total >= self._expected_worker_count:
987-
to_free_block_ids.extend(self.sending_blocks.pop(event_id, []))
999+
to_free_block_ids = self.sending_blocks.pop(event_id, [])
9881000
self.sending_events.pop(event_id, None)
1001+
if to_free_block_ids:
1002+
logger.debug("event %s free blocks: %s", event_id, to_free_block_ids)
1003+
self._block_pool.free_blocks([self._block_pool.blocks[block_id] for block_id in to_free_block_ids])
9891004
else:
9901005
self.sending_events[event_id] = total
9911006

992-
if to_free_block_ids:
993-
logger.debug("free blocks: %s", to_free_block_ids)
994-
assert self._block_pool is not None
995-
self._block_pool.free_blocks([self._block_pool.blocks[block_id] for block_id in to_free_block_ids])
996-
9971007
def request_finished(
9981008
self,
9991009
request: "Request",

vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,25 @@ def _infer_cache_group_metadata(self, group_id: int, layer_names: list[str]):
500500
self.group_block_stride[group_id] = group_block_strides
501501
self.group_num_layers[group_id] = len(layer_names)
502502

503+
def _align_kv_ptrs(self, registered_regions: dict[int, tuple[int, int]]) -> None:
    """Down-align the start of every registered region to a 2 MB boundary, in place.

    In the hybrid scenario a KVCacheTensor is shared by multiple layers, but
    layers cannot always be distributed evenly among the groups, so the layers
    sharing a KVCacheTensor may not occupy all of its space. The computed start
    address is then no longer the originally aligned address, so it is rounded
    down to satisfy the 2 MB alignment requirement. End addresses are kept.

    Does nothing when ``self.use_hybrid`` is false.

    Args:
        registered_regions: Maps a storage key to its ``(start, end)`` address
            pair; values are rewritten in place.
            NOTE(review): the key is compared against the aligned start below,
            so it is presumably the raw tensor base address -- confirm against
            ``register_kv_caches``.

    Raises:
        AssertionError: If an aligned start would fall below its storage key.
    """
    if not self.use_hybrid:
        return
    alignment = 2 * 1024 * 1024
    # Iterate over a snapshot so the values can be rewritten while looping.
    for storage_key, (start, end) in list(registered_regions.items()):
        new_start = start // alignment * alignment
        # Because the addresses of raw tensors are aligned to 2MB, all shared
        # sub-tensors, when aligned downwards, should theoretically not exceed
        # the address bounds. Raised explicitly (instead of a bare `assert`)
        # so the guard is still enforced when Python runs with -O.
        if new_start < storage_key:
            raise AssertionError("invalid kv cache tensor, raw tensor ptr must be align to 2MB")
        registered_regions[storage_key] = (new_start, end)
521+
503522
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
504523
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
505524
first_kv_cache_tuple = self._as_cache_tuple(first_kv_cache_tuple)
@@ -553,6 +572,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
553572
else:
554573
registered_regions[storage_key] = (start, end)
555574

575+
self._align_kv_ptrs(registered_regions)
556576
ptrs = [start for start, _ in registered_regions.values()]
557577
lengths = [end - start for start, end in registered_regions.values()]
558578

0 commit comments

Comments
 (0)