forked from LMCache/LMCache-Ascend
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnpu_connector.py
More file actions
executable file
·1424 lines (1171 loc) · 56.5 KB
/
npu_connector.py
File metadata and controls
executable file
·1424 lines (1171 loc) · 56.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-License-Identifier: Apache-2.0
# Standard
from enum import Enum, auto
from typing import Any, List, Optional, Set, Tuple, Union
# Third Party
from lmcache.config import LMCacheEngineMetadata
from lmcache.integration.vllm.utils import ENGINE_NAME
from lmcache.logging import init_logger
from lmcache.utils import _lmcache_nvtx_annotate
from lmcache.v1.compute.blend.utils import LMCBlenderBuilder
from lmcache.v1.gpu_connector import (
VLLMBufferLayerwiseGPUConnector,
VLLMPagedMemGPUConnectorV2,
VLLMPagedMemLayerwiseGPUConnector,
)
from lmcache.v1.memory_management import GPUMemoryAllocator, MemoryFormat, MemoryObj
import torch
# First Party
from lmcache_ascend.v1.proxy_memory_obj import ProxyMemoryObj
from lmcache_ascend.v1.transfer_context import AscendBaseTransferContext
import lmcache_ascend.c_ops as lmc_ops
logger = init_logger(__name__)

# Memoized result of is_310p(); None means the SoC version has not been
# inspected yet.
_IS_310P = None


def is_310p():
    """Return True when the build targets an Ascend 310P SoC.

    The answer comes from ``lmcache_ascend._build_info.__soc_version__``
    (case-insensitive prefix ``"ascend310p"``). It is computed on the first
    call and cached in the module-level ``_IS_310P`` for later calls.
    """
    global _IS_310P
    if _IS_310P is not None:
        return _IS_310P

    # First Party
    from lmcache_ascend import _build_info

    soc_name = _build_info.__soc_version__.lower()
    _IS_310P = soc_name.startswith("ascend310p")
    return _IS_310P
class KVCacheFormat(Enum):
    """
    Layout of the vLLM KV cache handed to the NPU connectors.

    The numeric values are passed to the native kernels and MUST keep the
    same order as the KVCacheFormat definition in kernels/types.h so that
    Python and C++ agree on their meaning.
    """

    UNDEFINED = 0

    # Merged layout (e.g. vLLM 0.9.2): one tensor per layer,
    # [num_kv, num_blocks, block_size, num_heads, head_dim].
    MERGED_KV = 1

    # Separate layout (e.g. vLLM 0.11.0+): one (K_tensor, V_tensor) tuple
    # per layer, each shaped [num_blocks, block_size, num_heads, head_dim],
    # e.g. kvcaches[0] = (K, V).
    SEPARATE_KV = 2

    def is_separate_format(self) -> bool:
        """Return True for the per-layer (K, V) tuple layout."""
        return self is KVCacheFormat.SEPARATE_KV

    def is_merged_format(self) -> bool:
        """Return True for the single-tensor-per-layer layout."""
        return self is KVCacheFormat.MERGED_KV

    @staticmethod
    def detect(
        kvcaches: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
        use_mla: bool = False,
    ) -> "KVCacheFormat":
        """Infer the layout from the first layer of ``kvcaches``.

        * tuple -> SEPARATE_KV
        * tensor -> MERGED_KV when ``use_mla`` is set, when the shape looks
          like MLA (3-D, or 4-D with a leading 1), or when it is 5-D with
          a size-2 K/V axis in dim 0 (Flash Attention) or dim 1 (Flash Infer)
        * anything else (including an empty list) -> UNDEFINED
        """
        if not kvcaches:
            return KVCacheFormat.UNDEFINED

        head = kvcaches[0]
        if isinstance(head, tuple):
            return KVCacheFormat.SEPARATE_KV
        if not isinstance(head, torch.Tensor):
            return KVCacheFormat.UNDEFINED

        dims = tuple(head.shape)
        rank = len(dims)
        # MLA: [num_blocks, block_size, head_size] or [1, num_blocks, ...]
        looks_like_mla = rank == 3 or (rank == 4 and dims[0] == 1)
        # Flash Attention [2, nb, bs, nh, hs] / Flash Infer [nb, 2, bs, nh, hs]
        has_kv_axis = rank == 5 and 2 in dims[:2]

        if use_mla or looks_like_mla or has_kv_axis:
            return KVCacheFormat.MERGED_KV
        return KVCacheFormat.UNDEFINED
class VLLMBufferLayerwiseNPUConnector(VLLMBufferLayerwiseGPUConnector):
    """Layerwise NPU connector that stages KV data through device buffers.

    Reuses the buffer/stream machinery of ``VLLMBufferLayerwiseGPUConnector``
    and adds detection of the vLLM KV cache layout (``KVCacheFormat``), whose
    value is forwarded to ``lmc_ops.single_layer_kv_transfer``.
    """

    def __init__(
        self,
        hidden_dim_size: int,
        num_layers: int,
        use_gpu: bool = False,
        use_double_buffer: bool = True,
        **kwargs,
    ):
        super().__init__(
            hidden_dim_size, num_layers, use_gpu, use_double_buffer, **kwargs
        )
        # Set by _lazy_initialize_buffer from the first kv_caches it sees.
        self.kv_format: KVCacheFormat = KVCacheFormat.UNDEFINED
        self.use_mla = bool(kwargs.get("use_mla", False))
        # Resolved lazily in batched_to_gpu when cache_positions is enabled.
        self.fused_rotary_emb: Any = None

    def _lazy_initialize_buffer(self, kv_caches):
        """
        Lazily create the device buffer allocator if it does not exist yet.

        The buffer is sized from the first layer of ``kv_caches`` (room for
        the K and V halves of every slot in one layer), so the first request
        may be a bit slower. This also records ``kv_format``, ``kv_device``
        and ``vllm_two_major`` for the transfer kernels.

        :raises ValueError: If the KV cache layout cannot be detected.
        """
        if not self.use_gpu or self.gpu_buffer_allocator is not None:
            return

        logger.info("Lazily initializing GPU buffer.")
        # NOTE (Jiayi): We use the first layer to determine the gpu buffer size.
        # NOTE (Jiayi): Using the exact number of tokens in the first layer
        # is okay since fragmentation shouldn't exist in the `gpu_buffer_allocator`
        # in layerwise mode.
        self.kv_format = KVCacheFormat.detect(kv_caches)
        if self.kv_format == KVCacheFormat.UNDEFINED:
            raise ValueError("Could not detect KV cache format.")

        first_layer_cache = kv_caches[0]
        ref_tensor = (
            first_layer_cache[0]
            if self.kv_format.is_separate_format()
            else first_layer_cache
        )
        self.kv_device = ref_tensor.device

        if self.kv_format == KVCacheFormat.SEPARATE_KV:
            # (K, V) tuple, each [num_blocks, block_size, num_heads, head_size]
            key_tensor = first_layer_cache[0]
            value_tensor = first_layer_cache[1]
            assert key_tensor.shape == value_tensor.shape, (
                f"Key and Value tensors must have identical shapes, "
                f"got key={key_tensor.shape}, value={value_tensor.shape}"
            )
            k_cache_shape_per_layer = key_tensor.shape
            self.vllm_two_major = True
        elif self.kv_format == KVCacheFormat.MERGED_KV:
            merged_shape = first_layer_cache.shape
            assert merged_shape[0] == 2 or merged_shape[1] == 2, (
                "MERGED_KV format should have per-layer shape [2, num_blocks, "
                "block_size, num_heads, head_size] or "
                "[num_blocks, 2, block_size, num_heads, head_size]. "
                f"Got shape: {merged_shape}"
            )
            if merged_shape[0] == 2:
                # Flash Attention: [2, num_blocks, block_size, num_heads, head_size]
                k_cache_shape_per_layer = first_layer_cache[0].shape
                self.vllm_two_major = True
            else:
                # Flash Infer: [num_blocks, 2, block_size, num_heads, head_size].
                # Select the K half along dim 1; indexing dim 0 would pick a
                # single block and size the buffer for one block only.
                k_cache_shape_per_layer = first_layer_cache[:, 0].shape
                self.vllm_two_major = False
        else:
            raise ValueError(f"Unsupported KV cache format: {self.kv_format}")

        max_tokens = k_cache_shape_per_layer[0] * k_cache_shape_per_layer[1]
        # x2: the buffer holds both the K and the V half of one layer.
        num_elements = k_cache_shape_per_layer.numel() * 2
        gpu_buffer_size = num_elements * self.element_size
        logger.info(
            f"Lazily initializing GPU buffer:\n"
            f"  - Format: {self.kv_format.name}\n"
            f"  - Key cache shape per layer: {k_cache_shape_per_layer}\n"
            f"  - Max tokens: {max_tokens}\n"
            f"  - gpu_buffer_size: {gpu_buffer_size / (1024 * 1024)} MB"
        )
        self.gpu_buffer_allocator = GPUMemoryAllocator(
            gpu_buffer_size, device=self.device
        )

    def _prepare_transfer_context(self, kwargs) -> torch.Tensor:
        """
        Validate transfer kwargs, lazily init the buffer and return
        ``kwargs["slot_mapping"]``.

        :raises ValueError: If kvcaches are unavailable or 'slot_mapping'
            is missing from kwargs.
        """
        self.initialize_kvcaches_ptr(**kwargs)
        if self.kvcaches is None:
            raise ValueError("kvcaches should be provided in kwargs or initialized.")
        if "slot_mapping" not in kwargs:
            raise ValueError("'slot_mapping' should be provided in kwargs.")
        self._lazy_initialize_buffer(self.kvcaches)
        return kwargs["slot_mapping"]

    def _get_full_slot_mapping(
        self,
        slot_mapping: torch.Tensor,
        starts: List[int],
        ends: List[int],
        mode: str = "slice",
    ) -> tuple[torch.Tensor, int]:
        """
        Build the slot mapping covering all chunks and return it with its length.

        ``"slice"`` takes the contiguous range ``starts[0]:ends[-1]``
        (including any gaps between chunks); ``"concat"`` concatenates only
        the ``[start, end)`` pieces.

        :raises ValueError: For an unknown ``mode``.
        """
        if mode == "slice":
            slot_mapping_full = slot_mapping[starts[0] : ends[-1]]
        elif mode == "concat":
            slot_mapping_chunks = [
                slot_mapping[s:e] for s, e in zip(starts, ends, strict=False)
            ]
            slot_mapping_full = torch.cat(slot_mapping_chunks, dim=0)
        else:
            raise ValueError(
                f"Unsupported slot mapping mode: {mode}, only 'slice'/'concat' allowed"
            )
        num_tokens = len(slot_mapping_full)
        return slot_mapping_full, num_tokens

    def _allocate_gpu_buffers(
        self, num_tokens: int, count: int = 1
    ) -> Union[object, list[object]]:
        """
        Allocate ``count`` KV_2TD buffers shaped by ``self.get_shape(num_tokens)``.

        Returns a single buffer object when ``count == 1``, otherwise a list.
        Allocation failures are reported via assertions.
        """
        buffer_shape = self.get_shape(num_tokens)
        assert self.gpu_buffer_allocator is not None, (
            "GPU buffer allocator not initialized"
        )
        buffers = []
        for _ in range(count):
            buf_obj = self.gpu_buffer_allocator.allocate(
                buffer_shape, self.dtype, MemoryFormat.KV_2TD
            )
            assert buf_obj is not None, "Failed to allocate GPU buffer in GPUConnector"
            assert buf_obj.tensor is not None, "GPU buffer object has no valid tensor"
            buffers.append(buf_obj)
        return buffers[0] if count == 1 else buffers

    @_lmcache_nvtx_annotate
    def batched_to_gpu(self, starts: List[int], ends: List[int], **kwargs):
        """
        Generator that moves KV cache from memory objects into paged device memory.

        In iteration i it (1) loads layer i from the memory objects (sent in
        via ``generator.send``) into a staging buffer, (2) applies the
        positional-encoding fix-up (when ``cache_positions`` is set) and gap
        zeroing to layer i-1 in the other buffer, and (3) scatters layer i-2
        from its buffer into the paged KV cache. In total, the generator
        yields num_layers + 2 times.

        :param starts: The starting indices of the KV cache in the corresponding
            token sequence.
        :param ends: The ending indices of the KV cache in the corresponding
            token sequence.
        """
        slot_mapping = self._prepare_transfer_context(kwargs)

        if self.fused_rotary_emb is None and self.cache_positions:
            # TODO(Jiayi): Make this more elegant
            self.lmc_model = LMCBlenderBuilder.get(ENGINE_NAME).layerwise_model
            self.fused_rotary_emb = self.lmc_model.fused_rotary_emb

        slot_mapping_full, num_all_tokens = self._get_full_slot_mapping(
            slot_mapping, starts, ends, mode="slice"
        )

        # Positions in the sliced range that are not covered by any chunk.
        gap_mask = torch.ones(
            num_all_tokens, dtype=torch.bool, device=slot_mapping_full.device
        )
        buf_offset = starts[0]
        for start, end in zip(starts, ends, strict=False):
            gap_mask[start - buf_offset : end - buf_offset] = False
        self.current_gap_positions = torch.where(gap_mask)[0]

        load_gpu_buffer_obj: Any = None
        compute_gpu_buffer_obj: Any = None
        compute_gpu_buffer_obj, load_gpu_buffer_obj = self._allocate_gpu_buffers(
            num_all_tokens, count=2
        )

        if self.cache_positions:
            new_positions_full = torch.arange(
                starts[0], ends[-1], dtype=torch.int64, device=self.kv_device
            )
            old_positions_full = torch.zeros(
                (num_all_tokens,), dtype=torch.int64, device=self.kv_device
            )

        for layer_id in range(self.num_layers + 2):
            if layer_id > 1:
                # buffer -> paged KV cache for layer_id - 2
                lmc_ops.single_layer_kv_transfer(
                    self.buffer_mapping[layer_id - 2].tensor,
                    self.kvcaches[layer_id - 2],
                    slot_mapping_full,
                    False,
                    self.kv_format.value,
                    False,  # shape is [2, num_tokens, hidden_dim]
                    self.vllm_two_major,
                )
                del self.buffer_mapping[layer_id - 2]
                logger.debug(f"Finished loading layer {layer_id - 2} into paged memory")

            if layer_id > 0 and layer_id <= self.num_layers:
                # NOTE: wait until both compute and load streams are done
                torch.cuda.synchronize()

                # ping-pong the buffers
                compute_gpu_buffer_obj, load_gpu_buffer_obj = (
                    load_gpu_buffer_obj,
                    compute_gpu_buffer_obj,
                )

                if self.cache_positions:
                    assert compute_gpu_buffer_obj.tensor is not None
                    compute_gpu_buffer_obj.tensor[0] = self.fused_rotary_emb(
                        old_positions_full,
                        new_positions_full,
                        compute_gpu_buffer_obj.tensor[0],
                    )

                # gap zeroing after RoPE
                if self.current_gap_positions.numel():
                    compute_gpu_buffer_obj.tensor[:, self.current_gap_positions] = 0.0

                self.buffer_mapping[layer_id - 1] = compute_gpu_buffer_obj
                logger.debug(f"Finished loading layer {layer_id - 1} into buffer")

            if layer_id < self.num_layers:
                memory_objs_layer = yield

                # memobj -> gpu_buffer
                with torch.cuda.stream(self.load_stream):
                    for start, end, memory_obj in zip(
                        starts, ends, memory_objs_layer, strict=False
                    ):
                        assert memory_obj.metadata.fmt == MemoryFormat.KV_2TD
                        assert load_gpu_buffer_obj.tensor is not None
                        load_gpu_buffer_obj.tensor[0][
                            start - buf_offset : end - buf_offset
                        ].copy_(memory_obj.tensor[0], non_blocking=True)
                        load_gpu_buffer_obj.tensor[1][
                            start - buf_offset : end - buf_offset
                        ].copy_(memory_obj.tensor[1], non_blocking=True)
                        if self.cache_positions and layer_id == 0:
                            old_positions_full[
                                start - buf_offset : end - buf_offset
                            ] = memory_obj.metadata.cached_positions
            elif layer_id == self.num_layers:
                yield

        # free the buffer memory
        load_gpu_buffer_obj.ref_count_down()
        compute_gpu_buffer_obj.ref_count_down()
        assert len(self.buffer_mapping) == 0, (
            "There are still layers in the buffer mapping after "
            "releasing the GPU buffers."
        )
        yield

    # TODO(Jiayi): Reduce repetitive operations in `batched_to_gpu`
    # and `batched_from_gpu`.
    @_lmcache_nvtx_annotate
    def batched_from_gpu(
        self,
        memory_objs: Union[List[List[MemoryObj]], List[MemoryObj]],
        starts: List[int],
        ends: List[int],
        **kwargs,
    ):
        """
        Generator that moves KV cache from paged device memory into memory objects.

        Each iteration gathers one layer (paged KV cache -> staging buffer ->
        memory objects) on ``store_stream``, yields, and then waits for that
        store to finish before moving on. The final iteration releases the
        staging buffer. In total, the generator yields num_layers + 1 times.

        :param memory_objs: The memory objects to store the KV cache. The first
            dimension is the number of layers, and the second dimension is the
            number of memory objects (i.e., number of chunks) for each layer.
        :param starts: The starting indices of the KV cache in the corresponding
            token sequence.
        :param ends: The ending indices of the KV cache in the corresponding
            token sequence.
        :raises ValueError: If 'kvcaches' is not provided in kwargs.
        :raises ValueError: If 'slot_mapping' is not provided in kwargs.
        """
        slot_mapping = self._prepare_transfer_context(kwargs)

        # Chunk i occupies [buf_starts_ends[i][0], buf_starts_ends[i][1]) in
        # the concatenated staging buffer.
        buf_start = 0
        buf_starts_ends = []
        old_positions_chunks = []
        for start, end in zip(starts, ends, strict=False):
            buf_end = buf_start + end - start
            buf_starts_ends.append((buf_start, buf_end))
            buf_start = buf_end
            if self.cache_positions:
                old_positions_chunks.append(
                    torch.arange(start, end, device=self.kv_device, dtype=torch.int64)
                )

        slot_mapping_full, num_tokens = self._get_full_slot_mapping(
            slot_mapping, starts, ends, mode="concat"
        )
        tmp_gpu_buffer_obj = self._allocate_gpu_buffers(num_tokens, count=1)

        current_stream = torch.cuda.current_stream()
        for layer_id in range(self.num_layers):
            memory_objs_layer = memory_objs[layer_id]
            # kvcaches -> gpu_buffer -> memobj
            with torch.cuda.stream(self.store_stream):
                self.store_stream.wait_stream(current_stream)
                lmc_ops.single_layer_kv_transfer(
                    tmp_gpu_buffer_obj.tensor,
                    self.kvcaches[layer_id],
                    slot_mapping_full,
                    True,
                    self.kv_format.value,
                    False,  # shape is [2, num_tokens, hidden_dim]
                    self.vllm_two_major,
                )
                # Iterate chunks independently of old_positions_chunks: that
                # list is empty unless cache_positions is set, and zipping
                # with it would silently skip every copy.
                for chunk_idx, ((chunk_start, chunk_end), memory_obj) in enumerate(
                    zip(buf_starts_ends, memory_objs_layer, strict=False)
                ):
                    assert memory_obj.tensor is not None
                    memory_obj.tensor[0].copy_(
                        tmp_gpu_buffer_obj.tensor[0][chunk_start:chunk_end],
                        non_blocking=True,
                    )
                    memory_obj.tensor[1].copy_(
                        tmp_gpu_buffer_obj.tensor[1][chunk_start:chunk_end],
                        non_blocking=True,
                    )
                    if self.cache_positions:
                        memory_obj.metadata.cached_positions = old_positions_chunks[
                            chunk_idx
                        ]
            # Yield outside the stream context so the caller resumes on its
            # own stream rather than on store_stream.
            yield
            self.store_stream.synchronize()
            logger.debug(f"Finished offloading layer {layer_id}")

        # free the buffer memory
        tmp_gpu_buffer_obj.ref_count_down()
        yield
class VLLMPagedMemNPUConnectorV2(VLLMPagedMemGPUConnectorV2):
def __init__(
    self,
    hidden_dim_size: int,
    num_layers: int,
    use_gpu: bool = False,
    **kwargs,
):
    """
    If use_gpu is true, it will create a gpu intermediate buffer. In this
    case, it requires the following kwargs:
    - chunk_size: The MAX size of the chunk to be copied to GPU.
    - dtype: The data type of the intermediate buffer.

    On Ascend 310P builds the kwargs must also contain ``num_kv_head``,
    ``head_size``, ``dtype`` and ``device``; they are stored on the
    instance because the 310P transfer path passes them to its kernel and
    allocates staging buffers with them.
    """
    super().__init__(hidden_dim_size, num_layers, use_gpu, **kwargs)
    # Detected from the actual kv_caches in _initialize_pointers.
    self.kv_format: KVCacheFormat = KVCacheFormat.UNDEFINED
    if is_310p():
        # Plain string messages: the previous "(msg,)" form turned the
        # assertion message into a tuple repr.
        assert "num_kv_head" in kwargs, "num_kv_head should be provided in 310p"
        assert "head_size" in kwargs, "head_size should be provided in 310p"
        self.num_kv_head = kwargs["num_kv_head"]
        self.head_size = kwargs["head_size"]
        self.dtype = kwargs["dtype"]
        self.device = kwargs["device"]
@classmethod
def from_metadata(
    cls,
    metadata: LMCacheEngineMetadata,
    use_gpu: bool = False,
    device: Optional[torch.device] = None,
) -> "VLLMPagedMemGPUConnectorV2":
    """Build a connector whose geometry is taken from ``metadata.kv_shape``.

    ``kv_shape`` is read as (num_layer, 2 or 1, chunk_size, num_kv_head,
    head_size); dtype and the MLA flag also come from ``metadata``.

    Args:
        metadata: The LMCache engine metadata containing model configuration.
        use_gpu: Whether to use GPU intermediate buffer.
        device: The device to use for the connector.

    Returns:
        A new instance of ``cls``.
    """
    kv_shape = metadata.kv_shape
    layer_count = kv_shape[0]
    tokens_per_chunk = kv_shape[2]
    kv_heads = kv_shape[3]
    per_head_size = kv_shape[4]

    return cls(
        hidden_dim_size=kv_heads * per_head_size,
        num_layers=layer_count,
        use_gpu=use_gpu,
        chunk_size=tokens_per_chunk,
        dtype=metadata.kv_dtype,
        device=device,
        use_mla=metadata.use_mla,
        num_kv_head=kv_heads,
        head_size=per_head_size,
    )
def _initialize_pointers(self, kv_caches: List[torch.Tensor]) -> torch.Tensor:
    """Return the device-resident table of per-layer KV cache base pointers.

    Detects ``kv_format`` and records ``kvcaches_device``. On the first call
    for a given device index, it also builds the pointer table: one
    ``data_ptr`` per layer for merged layouts, or interleaved K, V pointers
    for the separate layout. That table is copied to the device and cached
    in ``kv_cache_pointers_on_gpu``. The first call also sets
    ``page_buffer_size``, and ``block_size`` on 310P. Later calls for the
    same device index return the cached table immediately.

    :raises ValueError: If the KV cache format cannot be detected.
    :raises AssertionError: If the KV caches are not on an NPU device.
    """
    detected = KVCacheFormat.detect(kv_caches, use_mla=self.use_mla)
    self.kv_format = detected
    if detected == KVCacheFormat.UNDEFINED:
        raise ValueError(
            "Undefined KV cache format detected. "
            "Unable to determine the format of input kv_caches."
        )

    is_separate = detected.is_separate_format()
    first_tensor = kv_caches[0][0] if is_separate else kv_caches[0]
    self.kvcaches_device = first_tensor.device
    assert self.kvcaches_device.type == "npu", "The device should be Ascend NPU."

    device_idx = self.kvcaches_device.index
    if device_idx in self.kv_cache_pointers_on_gpu:
        return self.kv_cache_pointers_on_gpu[device_idx]

    if detected == KVCacheFormat.SEPARATE_KV:
        self.kv_size = 2
        # Interleave as [K0, V0, K1, V1, ...].
        pointers_list = [
            ptr for k, v in kv_caches for ptr in (k.data_ptr(), v.data_ptr())
        ]
    else:
        self.kv_size = 1
        pointers_list = [layer.data_ptr() for layer in kv_caches]

    # kv_size == 1 makes this the same (num_layers,) table as before.
    self.kv_cache_pointers = torch.empty(
        self.num_layers * self.kv_size, dtype=torch.int64, device="cpu"
    )
    self.kv_cache_pointers.numpy()[:] = pointers_list

    device_table = torch.empty(
        self.kv_cache_pointers.shape, dtype=torch.int64, device=self.kvcaches_device
    )
    device_table.copy_(self.kv_cache_pointers)
    self.kv_cache_pointers_on_gpu[device_idx] = device_table

    if self.use_mla:
        # [num_pages, page_size, head_size] or
        # [1, num_pages, page_size, head_size] (vllm-Ascend)
        mla_dims = kv_caches[0].shape
        self.page_buffer_size = mla_dims[-3] * mla_dims[-2]
    elif detected == KVCacheFormat.SEPARATE_KV:
        # per-tensor layout of the (K, V) tuple:
        # 310P: [num_blocks, num_kv_heads * head_size // 16, block_size, 16]
        # 910B: [num_blocks, block_size, num_kv_heads, head_size]
        assert first_tensor.dim() >= 2
        dims = first_tensor.shape
        if is_310p():
            self.block_size = dims[-2]
            self.page_buffer_size = dims[0] * self.block_size
        else:
            self.page_buffer_size = dims[0] * dims[1]
    elif detected == KVCacheFormat.MERGED_KV:
        # 310P: [2, num_blocks, num_kv_heads * head_size // 16, block_size, 16]
        # 910B: [2, num_blocks, block_size, num_kv_heads, head_size]
        assert first_tensor.dim() == 5
        dims = first_tensor.shape
        if is_310p():
            self.block_size = dims[-2]
            self.page_buffer_size = dims[1] * self.block_size
        else:
            self.page_buffer_size = dims[1] * dims[2]

    return self.kv_cache_pointers_on_gpu[device_idx]
def to_gpu_310p(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
    """Scatter one memory object into the paged KV cache (Ascend 310P path).

    The memory object's tensor is first copied into a freshly allocated
    staging tensor (``self.dtype`` on ``self.device``), which is then handed
    to ``lmc_ops.multi_layer_kv_transfer_310p`` together with the 310P head
    geometry and ``block_size``.

    ``kvcaches`` must be given in kwargs or initialized earlier.
    ``slot_mapping`` is a full-sequence mapping, and
    ``slot_mapping[start:end]`` selects this chunk's slots. With prefix
    caching, the mapping starts with -1 entries for the matched prefix.
    ``start``/``end`` must not overlap that region, so the kernel never
    sees -1.

    :raises AssertionError: If the memory object has no tensor or kvcaches
        are unavailable.
    :raises ValueError: If the memory format does not match the MLA setting,
        or if 'slot_mapping' is not provided in kwargs.
    """
    assert memory_obj.tensor is not None

    self.initialize_kvcaches_ptr(**kwargs)
    assert self.kvcaches is not None, (
        "kvcaches should be provided in kwargs or initialized beforehand."
    )

    if self.use_mla:
        wanted_fmt = MemoryFormat.KV_MLA_FMT
        bad_fmt_msg = (
            "The memory object should be in KV_MLA_FMT format in"
            " order to be processed by VLLMPagedMemNPUConnector."
        )
    else:
        wanted_fmt = MemoryFormat.KV_2LTD
        bad_fmt_msg = (
            "The memory object should be in KV_2LTD format "
            "in order to be processed by VLLMPagedMemNPUConnector."
        )
    if memory_obj.metadata.fmt != wanted_fmt:
        raise ValueError(bad_fmt_msg)

    if "slot_mapping" not in kwargs:
        raise ValueError("'slot_mapping' should be provided in kwargs.")
    full_slots: torch.Tensor = kwargs["slot_mapping"]

    pointer_table = self._initialize_pointers(self.kvcaches)

    staging = torch.empty(
        memory_obj.tensor.size(), dtype=self.dtype, device=self.device
    )
    staging.copy_(memory_obj.tensor)

    copy_from_device = False  # direction: memory object -> paged KV cache
    lmc_ops.multi_layer_kv_transfer_310p(
        staging,
        pointer_table,
        full_slots[start:end],
        self.kvcaches_device,
        self.page_buffer_size,
        copy_from_device,
        self.use_mla,
        self.num_kv_head,
        self.head_size,
        self.block_size,
        self.kv_format.value,  # 1:MERGED_KV / 2:SEPARATE_KV
    )
def from_gpu_310p(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
    """Gather one chunk from the paged KV cache into a memory object (310P path).

    The kernel ``lmc_ops.multi_layer_kv_transfer_310p`` writes into a freshly
    allocated staging tensor (``self.dtype`` on ``self.device``), which is
    then copied into ``memory_obj.tensor``. When ``use_mla`` is set, the
    memory object's format is set to ``MemoryFormat.KV_MLA_FMT``; otherwise
    the format is left unchanged.

    ``kvcaches`` must be given in kwargs or initialized earlier.
    ``slot_mapping`` is a full-sequence mapping, and
    ``slot_mapping[start:end]`` selects this chunk's slots. With prefix
    caching, the mapping starts with -1 entries for the matched prefix.
    ``start``/``end`` must not overlap that region, so the kernel never
    sees -1.

    :raises AssertionError: If the memory object has no tensor, kvcaches are
        unavailable, or the staging buffer is not on the KV cache device.
    :raises ValueError: If 'slot_mapping' is not provided in kwargs.
    """
    assert memory_obj.tensor is not None
    self.initialize_kvcaches_ptr(**kwargs)
    assert self.kvcaches is not None, (
        "kvcaches should be provided in kwargs or initialized beforehand."
    )
    if "slot_mapping" not in kwargs:
        raise ValueError("'slot_mapping' should be provided in kwargs.")
    slot_mapping: torch.Tensor = kwargs["slot_mapping"]
    kv_cache_pointers = self._initialize_pointers(self.kvcaches)

    tmp_gpu_buffer = torch.empty(
        memory_obj.tensor.size(), dtype=self.dtype, device=self.device
    )
    # Validate the staging buffer that is actually used. The optional
    # preallocated ``self.gpu_buffer`` is not used here and is None when
    # use_gpu=False, so dereferencing it crashed with AttributeError.
    assert tmp_gpu_buffer.device == self.kvcaches_device, (
        f"Staging buffer on {tmp_gpu_buffer.device} but KV cache on "
        f"{self.kvcaches_device}"
    )
    lmc_ops.multi_layer_kv_transfer_310p(
        tmp_gpu_buffer,
        kv_cache_pointers,
        slot_mapping[start:end],
        self.kvcaches_device,
        self.page_buffer_size,
        True,
        self.use_mla,
        self.num_kv_head,
        self.head_size,
        self.block_size,
        self.kv_format.value,  # 1:MERGED_KV / 2:SEPARATE_KV
    )
    memory_obj.tensor.copy_(tmp_gpu_buffer)
    if self.use_mla:
        memory_obj.metadata.fmt = MemoryFormat.KV_MLA_FMT
def to_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
    """Scatter one memory object directly into the paged KV cache.

    Calls ``lmc_ops.multi_layer_kv_transfer`` with ``memory_obj.tensor`` as
    the source, with no staging copy.

    ``kvcaches`` must be given in kwargs or initialized earlier.
    ``slot_mapping`` is a full-sequence mapping, and
    ``slot_mapping[start:end]`` selects this chunk's slots. With prefix
    caching, the mapping starts with -1 entries for the matched prefix.
    ``start``/``end`` must not overlap that region, so the kernel never
    sees -1.

    :raises AssertionError: If the memory object has no tensor or kvcaches
        are unavailable.
    :raises ValueError: If the memory format does not match the MLA setting,
        or if 'slot_mapping' is not provided in kwargs.
    """
    assert memory_obj.tensor is not None

    self.initialize_kvcaches_ptr(**kwargs)
    assert self.kvcaches is not None, (
        "kvcaches should be provided in kwargs or initialized beforehand."
    )

    if self.use_mla:
        wanted_fmt = MemoryFormat.KV_MLA_FMT
        bad_fmt_msg = (
            "The memory object should be in KV_MLA_FMT format in"
            " order to be processed by VLLMPagedMemNPUConnector."
        )
    else:
        wanted_fmt = MemoryFormat.KV_2LTD
        bad_fmt_msg = (
            "The memory object should be in KV_2LTD format in "
            " order to be processed by VLLMPagedMemNPUConnector."
        )
    if memory_obj.metadata.fmt != wanted_fmt:
        raise ValueError(bad_fmt_msg)

    if "slot_mapping" not in kwargs:
        raise ValueError("'slot_mapping' should be provided in kwargs.")
    full_slots: torch.Tensor = kwargs["slot_mapping"]

    pointer_table = self._initialize_pointers(self.kvcaches)

    copy_from_device = False  # direction: memory object -> paged KV cache
    lmc_ops.multi_layer_kv_transfer(
        memory_obj.tensor,
        pointer_table,
        full_slots[start:end],
        self.kvcaches_device,
        self.page_buffer_size,
        copy_from_device,
        self.use_mla,
        self.kv_format.value,  # 1:MERGED_KV / 2:SEPARATE_KV
    )
def from_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
    """Gather one chunk from the paged KV cache into a memory object.

    The gather is enqueued on ``store_stream``. If no preallocated
    ``gpu_buffer`` exists, or the chunk length differs from its token
    dimension, ``lmc_ops.multi_layer_kv_transfer`` writes directly into
    ``memory_obj.tensor``. Otherwise ``lmc_ops.fused_multi_layer_kv_transfer``
    stages through a slice of ``gpu_buffer``. When ``use_mla`` is set, the
    memory object's format is set to ``MemoryFormat.KV_MLA_FMT``.

    ``kvcaches`` must be given in kwargs or initialized earlier.
    ``slot_mapping`` is a full-sequence mapping, and
    ``slot_mapping[start:end]`` selects this chunk's slots. With prefix
    caching, the mapping starts with -1 entries for the matched prefix.
    ``start``/``end`` must not overlap that region, so the kernel never
    sees -1.

    :raises AssertionError: If the memory object has no tensor or kvcaches
        are unavailable.
    :raises ValueError: If 'slot_mapping' is not provided in kwargs or the
        KV cache format is undefined.
    """
    assert memory_obj.tensor is not None
    self.initialize_kvcaches_ptr(**kwargs)
    assert self.kvcaches is not None, (
        "kvcaches should be provided in kwargs or initialized beforehand."
    )
    if "slot_mapping" not in kwargs:
        raise ValueError("'slot_mapping' should be provided in kwargs.")
    slot_mapping: torch.Tensor = kwargs["slot_mapping"]
    kv_cache_pointers = self._initialize_pointers(self.kvcaches)
    if self.kv_format == KVCacheFormat.UNDEFINED:
        raise ValueError("KV cache format is not initialized!")

    # Capture the producer stream before switching streams. Without this
    # ordering, store_stream could read paged KV entries that work already
    # enqueued on the caller's stream has not finished writing.
    producer_stream = torch.cuda.current_stream()
    with torch.cuda.stream(self.store_stream):
        self.store_stream.wait_stream(producer_stream)
        # No staging buffer or token count mismatch
        if self.gpu_buffer is None or end - start != self.gpu_buffer.shape[2]:
            lmc_ops.multi_layer_kv_transfer(
                memory_obj.tensor,
                kv_cache_pointers,
                slot_mapping[start:end],
                self.kvcaches_device,
                self.page_buffer_size,
                True,
                self.use_mla,
                self.kv_format.value,  # 1:MERGED_KV / 2:SEPARATE_KV
            )
        else:
            assert self.gpu_buffer.device == self.kvcaches_device
            tmp_gpu_buffer = self.gpu_buffer[:, :, : end - start, :]
            lmc_ops.fused_multi_layer_kv_transfer(
                memory_obj.tensor,  # dst: CPU buffer
                tmp_gpu_buffer,  # staging cache
                kv_cache_pointers,  # src: paged KV cache
                slot_mapping[start:end],
                self.kvcaches_device,
                self.page_buffer_size,
                True,  # from_gpu
                self.use_mla,
                self.kv_format.value,  # 1:MERGED_KV / 2:SEPARATE_KV
            )

    if not memory_obj.tensor.is_cuda:
        # Force a synchronize if the target buffer is NOT CUDA device
        # NOTE: for better performance, we may not want to sync for every
        # memory object
        self.store_stream.synchronize()

    if self.use_mla:
        memory_obj.metadata.fmt = MemoryFormat.KV_MLA_FMT
def batched_to_gpu(self, memory_objs, starts, ends, **kwargs):
    """Load a batch of memory objects into the paged KV cache.

    If any entry is a ``ProxyMemoryObj`` (a deferred P2P fetch), the whole
    batch goes through ``_remote_batched_to_gpu``, which 310P does not
    support. Otherwise each object is scattered on ``load_stream`` via
    ``to_gpu_310p`` or ``to_gpu``, depending on the SoC, and the method
    blocks until ``load_stream`` has finished.
    """
    if not any(isinstance(obj, ProxyMemoryObj) for obj in memory_objs):
        with torch.cuda.stream(self.load_stream):
            for obj, chunk_start, chunk_end in zip(
                memory_objs, starts, ends, strict=False
            ):
                loader = self.to_gpu_310p if is_310p() else self.to_gpu
                loader(obj, chunk_start, chunk_end, **kwargs)
        self.load_stream.synchronize()
        return

    assert not is_310p(), "Batched P2P transfer is not supported on 310P."
    self._remote_batched_to_gpu(memory_objs, starts, ends, **kwargs)
def _clear_proxy_batch(self, batch) -> None:
"""Clear the backing objects of the proxy batch."""
for proxy, _, _ in batch:
proxy.clear_backing_obj()
return None
def _scatter_proxy_batch(self, batch, event, **kwargs):
    """Scatter a batch of fetched proxies into the KV cache on ``load_stream``.

    If ``event`` is given, ``load_stream`` first waits on it (typically the
    read-done event of the fetch). Each ``(proxy, start, end)`` entry is then
    loaded with ``to_gpu`` from ``proxy.backing_obj``. Work is only enqueued;
    the caller records a scatter-done event afterwards if cross-stream
    synchronization is needed.
    """
    if event is not None:
        self.load_stream.wait_event(event)

    with torch.cuda.stream(self.load_stream):
        for proxy_obj, chunk_start, chunk_end in batch:
            backing = proxy_obj.backing_obj
            self.to_gpu(backing, chunk_start, chunk_end, **kwargs)
def _remote_batched_to_gpu(self, memory_objs, starts, ends, **kwargs):
"""Handle batched_to_gpu when ProxyMemoryObjs are present.
Uses a ping-pong pipeline with **event-based** cross-stream
synchronization to overlap remote data fetching (on the HCCL
transport stream) with KV cache scatter (on the load stream).
Two pools of PIPELINE_DEPTH buffers are allocated from the
transfer context's registered memory and alternated (ping-pong).
This limits peak memory to 2 x PIPELINE_DEPTH chunks regardless
of the total number of proxy objects.
After all proxy objects are processed, sends the Done signal
to release the remote peer's pinned resources.
"""
transfer_contexts: Set[AscendBaseTransferContext] = set()
# Separate proxy and non-proxy items
proxy_items = []
non_proxy_items = []
for memory_obj, start, end in zip(memory_objs, starts, ends, strict=False):
if isinstance(memory_obj, ProxyMemoryObj):
transfer_contexts.add(memory_obj.transfer_context)
proxy_items.append((memory_obj, start, end))
else:
non_proxy_items.append((memory_obj, start, end))
if proxy_items:
# Get the transfer context for buffer allocation
first_ctx = proxy_items[0][0].transfer_context
# Derive pipeline depth from NPU buffer capacity so that
# two full ping-pong pools fit in registered memory.
pipeline_depth = first_ctx.max_pipeline_depth
logger.debug(
"P2P pipeline depth = %d (proxy_items=%d)",
pipeline_depth,
len(proxy_items),
)
# Allocate ping-pong buffer pools.
# Initialized to None so the finally block can safely skip
# release if allocation itself fails.
pool_size = min(pipeline_depth, len(proxy_items))
pool_a = None
pool_b = None
try:
pool_a = first_ctx.allocate_buffers(pool_size)
pool_b = first_ctx.allocate_buffers(pool_size)
pools = [pool_a, pool_b]
current_pool = 0
# Group proxy items into micro-batches
micro_batches = [
proxy_items[i : i + pipeline_depth]
for i in range(0, len(proxy_items), pipeline_depth)
]
prev_read_event = None
prev_batch = None
# Per-pool scatter-done events: prevent the next RDMA
# write into a pool from racing with a scatter that is
# still reading from the same pool on load_stream.
# Events are pre-allocated and re-recorded each iteration.
channel = proxy_items[0][0]._transfer_channel
transport_stream = getattr(
channel, "transport_stream", None
)
pool_scatter_events = [
torch.npu.Event(),
torch.npu.Event(),
]
pool_scatter_recorded = [False, False]
for batch_idx, batch in enumerate(micro_batches):
pool = pools[current_pool]
# Ensure the previous scatter from this pool has
# finished before RDMA overwrites the pool buffers.