-
Notifications
You must be signed in to change notification settings - Fork 723
Expand file tree
/
Copy pathgpu_model_runner.py
More file actions
2969 lines (2671 loc) · 141 KB
/
gpu_model_runner.py
File metadata and controls
2969 lines (2671 loc) · 141 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import copy
import os
import queue
import time
from concurrent.futures import Future
from threading import Thread
from typing import Any, Dict, List, Optional, cast
import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.request import ImagePosition, Request, RequestType
from fastdeploy.model_executor.graph_optimization.utils import (
profile_run_guard,
sot_warmup_guard,
)
from fastdeploy.model_executor.guided_decoding import (
LogitsProcessorBase,
get_guided_backend,
)
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.append_attn_backend import (
allocate_launch_related_buffer,
)
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
)
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
RoutingReplayManager,
)
from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.platforms import current_platform
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode
if current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
recover_decode_task,
set_data_ipc,
set_value_by_flags_and_idx,
)
share_external_data = None
elif current_platform.is_dcu():
from fastdeploy.model_executor.ops.gpu import set_value_by_flags_and_idx
recover_decode_task = None
share_external_data = None
else:
from fastdeploy.model_executor.ops.gpu import (
recover_decode_task,
set_value_by_flags_and_idx,
share_external_data,
speculate_schedule_cache,
set_data_ipc,
unset_data_ipc,
)
import zmq
from fastdeploy import envs
from fastdeploy.engine.tasks import PoolingTask
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.logger.deterministic_logger import DeterministicLogger
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
from fastdeploy.model_executor.pre_and_post_process import (
async_set_value,
post_process,
pre_process,
rebuild_padding,
save_output_normal,
)
from fastdeploy.output.pooler import PoolerOutput
from fastdeploy.worker.model_runner_base import (
DistributedOut,
DistributedStatus,
ModelRunnerBase,
)
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput
class GPUModelRunner(ModelRunnerBase):
def __init__(
    self,
    fd_config: FDConfig,
    device: str,  # logic device
    device_id: int,  # physical device id
    rank: int,
    local_rank: int,
):
    """
    Set up the per-rank runner state: sampling / logprob options, shared
    input buffers, attention backends, the optional async ZMQ output thread
    and the caches used by overlap scheduling.

    Args:
        fd_config: Global FastDeploy configuration (forwarded to ``ModelRunnerBase``).
        device: Logical device string (forwarded to ``ModelRunnerBase``).
        device_id: Physical device id.
        rank: Rank of this worker (stored as ``self.rank``).
        local_rank: Local rank of this worker; also embedded in the ZMQ
            output channel name when ``FD_USE_GET_SAVE_OUTPUT_V1`` is set.
    """
    super().__init__(fd_config=fd_config, device=device)
    # 2**63 - 2; presumably the upper bound for int64 sampling seeds — TODO confirm.
    self.MAX_INFER_SEED = 9223372036854775806
    self.enable_mm = self.model_config.enable_mm
    self.rank = rank
    self.local_rank = local_rank
    self.device_id = device_id
    self.spec_method = self.fd_config.speculative_config.method
    self.speculative_decoding = self.spec_method is not None
    self.enable_logprob = fd_config.model_config.enable_logprob
    self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
    self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling"
    self.ori_vocab_size = self.fd_config.model_config.ori_vocab_size
    self.max_logprobs = None
    if self.enable_logprob:
        # max_logprobs == -1 is mapped to the full (original) vocabulary size.
        self.max_logprobs = (
            self.ori_vocab_size
            if fd_config.model_config.max_logprobs == -1
            else fd_config.model_config.max_logprobs
        )
        self.temp_scaled_logprobs = True
        self.top_p_normalized_logprobs = True
    self.prompt_logprobs_reqs: dict[str, Request] = {}
    self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {}
    # One slot per possible batch position, filled in later.
    self.forward_batch_reqs_list: list[Request] = [None for _ in range(self.scheduler_config.max_num_seqs)]
    self.cache_kvs_map: dict = {}
    self.exist_prefill_flag = False
    if self.speculative_decoding:
        # Pinned host buffer + event, presumably for async device->host copy of the
        # real output token count — TODO confirm against the speculative path.
        self._real_output_token_num_host = paddle.empty([1], dtype="int32").pin_memory()
        self.output_token_num_event = paddle.device.cuda.Event()
    # VL model config:
    if self.enable_mm:
        if "ernie" in self.fd_config.model_config.model_type:
            self._init_image_preprocess()
        # Op lists presumably consumed by AMP auto-cast (black = keep fp32, white = low precision).
        self.amp_black = [
            "reduce_sum",
            "c_softmax_with_cross_entropy",
            "elementwise_div",
            "sin",
            "cos",
            "sort",
            "multinomial",
        ]
        self.amp_white = [
            "lookup_table",
            "lookup_table_v2",
            "flash_attn",
            "matmul",
            "matmul_v2",
            "fused_gemm_epilogue",
        ]
    # Encoder (vision) feature cache keyed by mm_hash; disabled when max_encoder_cache <= 0.
    if self.cache_config.max_encoder_cache > 0:
        self.encoder_cache: dict[str, paddle.Tensor] = {}
    else:
        self.encoder_cache = None
    # Sampler
    if not self.speculative_decoding:
        self.sampler = Sampler(fd_config)
    else:
        self.sampler = SpeculativeSampler(fd_config)
    self.guided_backend = None
    if self.fd_config.structured_outputs_config.guided_decoding_backend != "off":
        self.guided_backend = get_guided_backend(fd_config=self.fd_config)
        self.sampler.set_reasoning_parser(self.guided_backend.get_reasoning_parser())
    # Lazy initialize kv cache after model loading
    # self.kv_caches: list[paddle.Tensor] = []
    # CUDA Graph
    self.use_cudagraph = self.graph_opt_config.use_cudagraph
    # Capture sizes are stored largest-first (reversed from the config order).
    self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
    self.cudagraph_capture_sizes_prefill = list(reversed(self.graph_opt_config.cudagraph_capture_sizes_prefill))
    self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
    self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill
    # Initialize input batch
    self.share_inputs = InputBatch(self.fd_config)
    self.share_inputs.init_share_inputs()
    # Per-step seed increment: 4, or 4 * (num_speculative_tokens + 1) with speculative decoding.
    self.increment_value = (
        4 if not self.speculative_decoding else (self.speculative_config.num_speculative_tokens + 1) * 4
    )
    self.infer_seed_increment = paddle.full(
        shape=[self.scheduler_config.max_num_seqs, 1], fill_value=self.increment_value, dtype="int64", device="cpu"
    )
    self.restore_chunked_prefill_request = dict()
    # Initialize deterministic logger (only when deterministic debugging is enabled)
    self.deterministic_logger = (
        DeterministicLogger(self.share_inputs)
        if envs.FD_DETERMINISTIC_MODE and envs.FD_DETERMINISTIC_LOG_MODE
        else None
    )
    # Initialize attention Backend
    # NOTE(gonshaotian): Currently, all attention layers share one attention backend instance.
    # In the future, we will expand it as a list.
    self.attn_backends: list[AttentionBackend] = []
    # self.attn_metadatas: list[AttentionMetadata] = []
    self._initialize_attn_backend()
    # Forward meta store the global meta information of the forward
    self.forward_meta: ForwardMeta = None
    # Postprocess Env params
    os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.local_engine_worker_queue_port)
    logger.info(f"queue id is {str(self.parallel_config.local_engine_worker_queue_port)}")
    # Rollout routing replay config
    self.routing_replay_manager = None
    self.zmq_client = None
    self.async_output_queue = None
    if envs.FD_USE_GET_SAVE_OUTPUT_V1:
        # Outputs are pushed over a ZMQ PUSH socket by a daemon thread fed from async_output_queue.
        port = self.fd_config.parallel_config.local_engine_worker_queue_port
        logger.info(f"zmq client get_save_output_rank{local_rank}_{port}")
        self.zmq_client = ZmqIpcClient(name=f"get_save_output_rank{local_rank}_{port}", mode=zmq.PUSH)
        self.zmq_client.connect()
        # Send timeout of 3000 ms.
        self.zmq_client.socket.SNDTIMEO = 3000
        self.async_output_queue: queue.Queue = queue.Queue()
        self.async_output_copy_thread = Thread(
            target=self._async_output_busy_loop,
            daemon=True,
            name="WorkerAsyncOutputCopy",
        )
        self.async_output_copy_thread.start()
    self.enable_entropy = self.model_config.enable_entropy
    # init signal
    # Attaches (create=False) to an existing shared signal with one int32 slot per TP rank.
    cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
    self.cache_ready_signal = IPCSignal(
        name="cache_ready_signal",
        array=cache_ready_signal_data,
        dtype=np.int32,
        suffix=self.parallel_config.local_engine_worker_queue_port,
        create=False,
    )
    # for overlap
    self._cached_model_output_data = None
    self._cached_sampler_output = None
    self._cached_post_process_event = None
    # Cached token count for next batch prediction in overlap scheduling.
    # Used to avoid synchronization overhead when preparing inputs for the next batch.
    self._cached_launch_token_num = -1
    # Overlap scheduling is never enabled together with speculative decoding.
    self.enable_overlap_schedule = fd_config.scheduler_config.enable_overlap_schedule and (
        not self.speculative_decoding
    )
    if self.enable_overlap_schedule:
        logger.info("Using overlap schedule")
    self.current_launch_token_num = 0
def _async_output_busy_loop(self):
    """
    Entrypoint for the thread which handles outputs asynchronously.

    Loops forever: blocks on ``self.async_output_queue.get()`` and pushes each
    item through ``self.zmq_client.send_pyobj``. A failure while handling one
    item is logged and the loop continues with the next item, so a single bad
    send does not terminate the thread.
    """
    import traceback  # only needed on the error path

    while True:
        try:
            output = self.async_output_queue.get()
            self.zmq_client.send_pyobj(output)
        except Exception as e:
            # The module logger is used with single pre-formatted messages
            # everywhere else in this file. A logging-style call with %-args or
            # ``logger.exception`` may not be supported by it. If logging itself
            # raised here, the exception would escape the loop and silently kill
            # the output thread.
            logger.error(f"Exception in async output loop: {e}\n{traceback.format_exc()}")
def exist_prefill(self):
    """
    Report whether the current batch contains a prefill-stage request.

    Returns:
        The value of ``self.exist_prefill_flag`` (no tensor inspection is done here).
    """
    has_prefill = self.exist_prefill_flag
    return has_prefill
def exist_decode(self):
    """
    Report whether any batch slot is currently in the decode stage.

    A slot counts as decoding when its ``seq_lens_decoder`` entry is positive
    and its ``stop_flags`` entry is not set. ``stop_flags`` is squeezed on
    axis 1 (presumably shaped ``[max_num_seqs, 1]`` — confirm).

    Returns:
        A Python bool obtained by copying the reduction result to the host.
    """
    inputs = self.share_inputs
    in_decode_phase = inputs["seq_lens_decoder"] > 0
    still_running = ~inputs["stop_flags"].squeeze(1)
    any_active = (in_decode_phase & still_running).any()
    return any_active.cpu().numpy().item()
def _resolve_current_launch_token_num(
    self, cached_token_num: int, token_num_event, is_dummy_or_profile_run: bool
) -> int:
    """
    Decide how many tokens the current batch launches.

    The value predicted for this batch by the previous step
    (``cached_token_num``) is reused only when all of these hold:
    not a dummy/profile run, overlap scheduling enabled, no prefill request
    in the batch, and a positive cached value. Otherwise it waits on
    ``token_num_event`` and sums ``share_inputs["seq_lens_this_time_cpu"]``.

    Args:
        cached_token_num: Token count predicted by the previous step (<= 0 if none).
        token_num_event: Event synchronized before reading the host-side lengths.
        is_dummy_or_profile_run: Forces a fresh, synchronized count when True.

    Returns:
        The token count for the current launch.
    """
    can_reuse_cache = (
        not is_dummy_or_profile_run
        and self.enable_overlap_schedule
        and not self.exist_prefill()
        and cached_token_num > 0
    )
    if can_reuse_cache:
        return cached_token_num
    token_num_event.synchronize()
    host_lengths = self.share_inputs["seq_lens_this_time_cpu"].numpy()
    return host_lengths.sum().item()
def _predict_next_launch_token_num(self) -> int:
    """
    Predict the token count of the next batch for overlap scheduling.

    While the current batch runs, the next batch can reuse this estimate and
    skip a host/device synchronization. The estimate is the sum of
    ``seq_lens_this_time_cpu`` plus the sum of ``is_block_step_cpu``.

    Returns:
        The predicted count, or -1 when the current batch contains a prefill
        request. This method does not itself check whether overlap
        scheduling is enabled.
    """
    if not self.exist_prefill():
        inputs = self.share_inputs
        scheduled = inputs["seq_lens_this_time_cpu"].numpy().sum().item()
        blocked = inputs["is_block_step_cpu"].numpy().sum().item()
        return scheduled + blocked
    return -1
def only_prefill(self):
    """
    Report whether the batch contains no decode-stage work.

    In a mixed-role expert-parallel deployment, every rank contributes its
    local "no decode" vote through ``paddle.distributed.all_gather_object``.
    The result is True only if all ranks vote True and this rank has no
    decode work. Otherwise only the local ``exist_decode()`` check is used.

    Returns:
        bool: True when no decode work exists (globally, in mixed EP mode).
    """
    parallel_cfg = self.fd_config.parallel_config
    if parallel_cfg.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
        decode_exists = self.exist_decode()
        votes = []
        paddle.distributed.all_gather_object(votes, not decode_exists)
        return all(votes) and not decode_exists
    return not self.exist_decode()
def collect_distributed_status(self):
    """
    Collect distributed status

    Builds this rank's ``DistributedStatus`` (whether it is decode-only and,
    with chunked MoE enabled, its MoE chunk count). In mixed-role
    expert-parallel deployments it then all-gathers the status across ranks.

    Side effect: when chunked MoE is enabled, sets
    ``self.forward_meta.moe_num_chunk`` for the local batch.

    Returns:
        DistributedOut: default-constructed outside mixed-role EP. In
        mixed-role EP it carries ``if_only_decode`` (True only if every
        gathered status reports decode-only and this rank has no prefill)
        and ``max_moe_num_chunk`` (max chunk count over the gathered
        statuses, or None when chunked MoE is disabled).
    """
    dist_status_list = []
    dist_status_obj = DistributedStatus()
    dist_out = DistributedOut()
    prefill_exists = None
    if_only_decode = True
    # mix ep in single node
    if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
        prefill_exists = self.exist_prefill()
        dist_status_obj.only_decode = not prefill_exists
    # whether chunked moe
    if self.fd_config.parallel_config.enable_chunked_moe:
        chunk_size = self.fd_config.parallel_config.chunked_moe_size
        token_num = self.share_inputs["ids_remove_padding"].shape[0]
        if token_num > chunk_size:
            # ceil(token_num / chunk_size)
            self.forward_meta.moe_num_chunk = (token_num + chunk_size - 1) // chunk_size
        else:
            self.forward_meta.moe_num_chunk = 1
        dist_status_obj.moe_num_chunk = self.forward_meta.moe_num_chunk
    # only ep need to collect and sync distributed status
    if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
        # call once to gather all status
        # (all_gather_object is a collective — presumably every rank in the group must reach it)
        paddle.distributed.all_gather_object(dist_status_list, dist_status_obj)
        # Update Batch type for cuda graph for if_only_decode
        if_only_decode = all(dist_status.only_decode for dist_status in dist_status_list)
        # prefill_exists is always set on this path; the fallback mirrors only_decode().
        if_only_decode = if_only_decode and not (
            prefill_exists if prefill_exists is not None else self.exist_prefill()
        )
        max_moe_num_chunk = None
        if self.fd_config.parallel_config.enable_chunked_moe:
            max_moe_num_chunk = max(dist_status.moe_num_chunk for dist_status in dist_status_list)
        dist_out = DistributedOut(
            if_only_decode=if_only_decode,
            max_moe_num_chunk=max_moe_num_chunk,
        )
    return dist_out
def only_decode(self):
    """
    Report whether the batch consists solely of decode-stage work.

    The result feeds the CUDA-graph batch-type decision. In a mixed-role
    expert-parallel deployment, each rank shares its local "no prefill"
    vote via ``paddle.distributed.all_gather_object``. The result is True
    only if all ranks vote True and this rank has no prefill. Otherwise
    only the local ``exist_prefill()`` check is used.

    Returns:
        bool: True when no prefill work exists (globally, in mixed EP mode).
    """
    parallel_cfg = self.fd_config.parallel_config
    if parallel_cfg.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
        votes = []
        prefill_exists = self.exist_prefill()
        paddle.distributed.all_gather_object(votes, not prefill_exists)
        return all(votes) and not prefill_exists
    return not self.exist_prefill()
def _init_speculative_proposer(self):
    """
    Create the speculative-decoding proposer, if a method is configured.

    Sets ``self.proposer`` to None when ``self.spec_method`` is None.
    Otherwise it builds the proposer through ``spec_method.create_proposer``,
    sharing this runner's model and ``share_inputs``. For MTP, the
    ``seq_lens_this_time`` entry is first re-pointed at
    ``seq_lens_this_time_buffer``.
    """
    method = self.spec_method
    if method is not None:
        if method == SpecMethod.MTP:
            # MTP reads/writes the buffer tensor, so alias it under the regular key.
            self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
        self.proposer = method.create_proposer(
            self.fd_config,
            main_model=self.get_model(),
            local_rank=self.local_rank,
            device_id=self.device_id,
            share_inputs=self.share_inputs,
        )
    else:
        self.proposer = None
def _init_logits_processor(self, request) -> "tuple[Future[LogitsProcessorBase], tuple[str, Any]]":
    """
    Init logits processor for guided decoding.

    The first non-None constraint among ``guided_json``, ``guided_regex``,
    ``guided_grammar`` and ``structural_tag`` (checked in that order)
    selects the schema handed to the guided-decoding backend.

    Args:
        request: Request carrying the guided-decoding constraint fields.

    Returns:
        A ``(logits_processor, schemata_key)`` pair. ``schemata_key`` is a
        ``(kind, schema)`` tuple, and ``logits_processor`` is whatever
        ``self.guided_backend.get_logits_processor`` returns.

    Raises:
        AssertionError: if no guided-decoding backend was configured.
        ValueError: if ``request`` carries no guided-decoding constraint.
            Previously this case surfaced as an ``UnboundLocalError``.
    """
    assert self.guided_backend is not None, (
        "guided_backend is None, use " "--guided-decoding-backend to specify the backend at server startup."
    )
    if request.guided_json is not None:
        schemata_key = ("json", request.guided_json)
    elif request.guided_regex is not None:
        schemata_key = ("regex", request.guided_regex)
    elif request.guided_grammar is not None:
        schemata_key = ("grammar", request.guided_grammar)
    elif request.structural_tag is not None:
        schemata_key = ("structural_tag", request.structural_tag)
    else:
        raise ValueError(
            f"Request {getattr(request, 'request_id', '<unknown>')} has no guided-decoding constraint "
            "(guided_json / guided_regex / guided_grammar / structural_tag)."
        )
    return (
        self.guided_backend.get_logits_processor(
            schemata_key=schemata_key,
            enable_thinking=False,  # TODO cfg
        ),
        schemata_key,
    )
def _process_mm_features(self, request_list: List[Request]):
    """
    Process and cache vision features from model
    - add image_features, extract and cache vision features from model
    - add rope_emb, rotate position embeddings

    For every PREFILL request in ``request_list`` this method:
      * evicts the request's ``evict_mm_hashes`` from ``self.encoder_cache`` (when the cache is enabled),
      * records the request's 3D-RoPE position ids for its batch slot,
      * for requests with images, gathers image tensors, grid_thw and feature
        positions, skipping images whose ``mm_hash`` is already cached.
    It then calls ``extract_vision_features`` once over all gathered images and
    slices the result into ``self.share_inputs["image_features_list"][idx]``.
    Finally it writes the 3D rotary embeddings into ``self.share_inputs["rope_emb"]``.
    Returns immediately when multimodal support is disabled.

    Args:
        request_list: Scheduled requests; non-PREFILL requests are skipped.
    """
    if not self.enable_mm:
        return
    # One entry per batch slot; -1 stays for slots that get no image features below.
    self.share_inputs["image_features_list"] = [-1] * self.scheduler_config.max_num_seqs
    img_index = 0
    # batch slot idx -> index into the per-request image_features_list built below (-1: no image)
    req_idx_img_index_map = {}
    multi_vision_inputs = {
        "images_lst": [],
        "grid_thw_lst": [],
        "vit_position_ids_lst": [],
        "cu_seqlens": [0],
        "encoder_cache_info": [],
        "feature_position_list": [],
        "grid_thw_lst_batches": [],
        "feature_position_list_batches": [],
    }
    rope_3d_position_ids = {
        "position_ids_idx": [],
        "position_ids_lst": [],
        "position_ids_offset": [0],
        "max_tokens_lst": [],
    }
    for request in request_list:
        if request.task_type.value != RequestType.PREFILL.value:
            continue
        if self.encoder_cache is not None:
            evict_mm_hashes = request.get("evict_mm_hashes", None)
            if evict_mm_hashes:
                for mm_hash in evict_mm_hashes:
                    self.encoder_cache.pop(mm_hash, None)
        position_ids = request.multimodal_inputs["position_ids"]
        idx = self.share_inputs.get_index_by_batch_id(request.idx)
        rope_3d_position_ids["position_ids_idx"].append(idx)
        req_idx_img_index_map[idx] = -1
        rope_3d_position_ids["position_ids_lst"].append(position_ids)
        # Running prefix sum of position_ids lengths: offsets into the packed array built after the loop.
        rope_3d_position_ids["position_ids_offset"].append(
            position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1]
        )
        if self.is_pooling_model:
            rope_3d_position_ids["max_tokens_lst"].append(0)
        else:
            rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
        if request.with_image:
            req_idx_img_index_map[idx] = img_index
            img_index = img_index + 1
            inputs = request.multimodal_inputs
            if self.encoder_cache is not None:
                # Encoder-cache path: only images whose mm_hash is not cached are sent to the vision encoder.
                if envs.FD_ENABLE_MAX_PREFILL:
                    if "vit_seqlen" in inputs:
                        vit_seqlen_list = inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
                    if "vit_position_ids" in inputs:
                        vit_position_ids_list = inputs["vit_position_ids"][
                            request.num_image_start : request.num_image_end
                        ]
                grid_thw_list = inputs["grid_thw"][request.num_image_start : request.num_image_end]
                mm_hashes_list = inputs["mm_hashes"][request.num_image_start : request.num_image_end]
                feature_positions = self._get_feature_positions(
                    mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end],
                    prefill_start_index=request.prefill_start_index,
                    prefill_end_index=request.prefill_end_index,
                )
                # NOTE(review): image_start_idx starts from num_image_start (an image count) but is advanced
                # by np.prod(grid_thw) (a patch count), and is not advanced for cache hits. The non-cache
                # branch below slices with request.image_start instead. Confirm which indexing
                # inputs["images"] uses here.
                image_start_idx = request.num_image_start
                logger.debug(
                    f"request {request.request_id} start process encoder info, image_start_idx: {image_start_idx} "
                    f"grid_thw_list: {grid_thw_list}, feature_positions: {feature_positions}, mm_hashes_list: {mm_hashes_list}"
                )
                encoder_cache_info_per_req = []
                grid_thw_lst_per_req = []
                for i, mm_hash in enumerate(mm_hashes_list):
                    image_offset = np.prod(grid_thw_list[i])
                    logger.debug(
                        f"run idx {i} with mm_hash {mm_hash} image_offset: {image_offset} grid_thw: {grid_thw_list[i]}"
                    )
                    if mm_hash in self.encoder_cache:
                        # Cache hit: remember it; its features are read from the cache after encoding.
                        encoder_cache_info_per_req.append((mm_hash, feature_positions[i], True))
                        continue
                    encoder_cache_info_per_req.append((mm_hash, feature_positions[i], False))
                    if envs.FD_ENABLE_MAX_PREFILL:
                        multi_vision_inputs["images_lst"].append(
                            inputs["images"][image_start_idx : image_start_idx + image_offset].to(self.device)
                        )
                        multi_vision_inputs["grid_thw_lst"].append(paddle.to_tensor(grid_thw_list[i]))
                        grid_thw_lst_per_req.append(paddle.to_tensor(grid_thw_list[i], dtype=paddle.int64))
                        multi_vision_inputs["cu_seqlens"].append(vit_seqlen_list[i])
                        multi_vision_inputs["vit_position_ids_lst"].append(vit_position_ids_list[i])
                    else:
                        multi_vision_inputs["images_lst"].append(
                            paddle.to_tensor(
                                inputs["images"][image_start_idx : image_start_idx + image_offset],
                                dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
                            )
                        )
                        multi_vision_inputs["grid_thw_lst"].append(
                            paddle.to_tensor(grid_thw_list[i], dtype=paddle.int64)
                        )
                        grid_thw_lst_per_req.append(paddle.to_tensor(grid_thw_list[i], dtype=paddle.int64))
                    image_start_idx += image_offset
                # grid_thw_lst_per_req holds only the cache misses, in order; thw_idx below walks it.
                multi_vision_inputs["grid_thw_lst_batches"].append(grid_thw_lst_per_req)
                multi_vision_inputs["encoder_cache_info"].append(encoder_cache_info_per_req)
            else:
                # No encoder cache: every image in the request's [num_image_start, num_image_end) range is encoded.
                if envs.FD_ENABLE_MAX_PREFILL:
                    multi_vision_inputs["images_lst"].append(
                        inputs["images"][request.image_start : request.image_end].to(self.device)
                    )
                    multi_vision_inputs["grid_thw_lst"].extend(
                        paddle.to_tensor(inputs["grid_thw"][request.num_image_start : request.num_image_end])
                    )
                    multi_vision_inputs["grid_thw_lst_batches"].append(
                        paddle.to_tensor(inputs["grid_thw"][request.num_image_start : request.num_image_end])
                    )
                    multi_vision_inputs["cu_seqlens"].extend(
                        inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
                    )
                    multi_vision_inputs["vit_position_ids_lst"].extend(
                        inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
                    )
                else:
                    multi_vision_inputs["images_lst"].append(
                        paddle.to_tensor(
                            inputs["images"][request.image_start : request.image_end],
                            dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
                        )
                    )
                    multi_vision_inputs["grid_thw_lst"].extend(
                        paddle.to_tensor(
                            inputs["grid_thw"][request.num_image_start : request.num_image_end],
                            dtype=paddle.int64,
                        )
                    )
                    multi_vision_inputs["grid_thw_lst_batches"].append(
                        paddle.to_tensor(
                            inputs["grid_thw"][request.num_image_start : request.num_image_end],
                            dtype=paddle.int64,
                        )
                    )
                multi_vision_inputs["feature_position_list"].extend(
                    self._get_feature_positions(
                        mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end],
                        prefill_start_index=request.prefill_start_index,
                        prefill_end_index=request.prefill_end_index,
                    )
                )
                multi_vision_inputs["feature_position_list_batches"].append(
                    self._get_feature_positions(
                        mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end],
                        prefill_start_index=request.prefill_start_index,
                        prefill_end_index=request.prefill_end_index,
                    )
                )
    if self.encoder_cache is not None:
        if len(multi_vision_inputs["images_lst"]) > 0 or len(multi_vision_inputs["encoder_cache_info"]) > 0:
            image_features_output = None
            if len(multi_vision_inputs["images_lst"]) > 0:
                image_features_output = self.extract_vision_features(multi_vision_inputs)
            logger.debug(f"encoder_cache_info: {multi_vision_inputs['encoder_cache_info']}")
            # feature_idx walks the concatenated encoder output across all requests' cache misses.
            feature_idx = 0
            image_features_list = []
            for index, encoder_cache_info in enumerate(multi_vision_inputs["encoder_cache_info"]):
                merge_image_features, thw_idx = [], 0
                for mm_hash, feature_position, use_cache in encoder_cache_info:
                    if use_cache:
                        assert mm_hash in self.encoder_cache, f"{mm_hash} not in encoder cache"
                        mm_feature = self.encoder_cache[mm_hash].cuda()
                    else:
                        assert (
                            image_features_output is not None
                        ), f"image_features_output is None, images_lst length: {len(multi_vision_inputs['images_lst'])}"
                        grid_thw = multi_vision_inputs["grid_thw_lst_batches"][index][thw_idx]
                        # NOTE(review): `inputs` is the multimodal_inputs of the last image request seen in
                        # the loop above; this assumes mm_num_token_func is identical across requests.
                        mm_token_lenght = inputs["mm_num_token_func"](grid_thw=grid_thw)
                        mm_feature = image_features_output[feature_idx : feature_idx + mm_token_lenght]
                        # add feature to encoder cache
                        self.encoder_cache[mm_hash] = mm_feature.detach().cpu()
                        feature_idx += mm_token_lenght
                        thw_idx += 1
                    # Keep only the part of the image that falls inside this prefill chunk.
                    feature_start = feature_position.offset
                    feature_end = feature_position.offset + feature_position.length
                    merge_image_features.append(mm_feature[feature_start:feature_end])
                image_features_list.append(paddle.concat(merge_image_features, axis=0))
            for idx, index in req_idx_img_index_map.items():
                if index != -1:
                    self.share_inputs["image_features_list"][idx] = image_features_list[index]
    elif len(multi_vision_inputs["images_lst"]) > 0:
        image_features_output = self.extract_vision_features(multi_vision_inputs)
        image_features_list = []
        feature_idx = 0
        for index, feature_position_item in enumerate(multi_vision_inputs["feature_position_list_batches"]):
            grid_thw_lst = multi_vision_inputs["grid_thw_lst_batches"][index]
            assert len(feature_position_item) == len(grid_thw_lst), f"{feature_position_item} != {grid_thw_lst}"
            merge_image_features, thw_idx = [], 0
            for feature_position in feature_position_item:
                grid_thw = grid_thw_lst[thw_idx]
                # NOTE(review): same last-request `inputs` assumption as in the cache branch above.
                mm_token_lenght = inputs["mm_num_token_func"](grid_thw=grid_thw)
                mm_feature = image_features_output[feature_idx : feature_idx + mm_token_lenght]
                feature_start = feature_position.offset
                feature_end = feature_position.offset + feature_position.length
                merge_image_features.append(mm_feature[feature_start:feature_end])
                feature_idx += mm_token_lenght
                thw_idx += 1
            image_features_list.append(paddle.concat(merge_image_features, axis=0))
        for idx, index in req_idx_img_index_map.items():
            if index != -1:
                self.share_inputs["image_features_list"][idx] = image_features_list[index]
    if len(rope_3d_position_ids["position_ids_idx"]) > 0:
        # All requests' position ids are packed into one int64 tensor; offsets delimit each request.
        packed_position_ids = paddle.to_tensor(
            np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64"
        )
        rope_3d_lst = self.prepare_rope3d(
            packed_position_ids,
            rope_3d_position_ids["max_tokens_lst"],
            rope_3d_position_ids["position_ids_offset"],
        )
        for i, idx in enumerate(rope_3d_position_ids["position_ids_idx"]):
            self.share_inputs["rope_emb"][idx : idx + 1, :] = rope_3d_lst[i]
def _get_feature_positions(
    self, mm_positions: List[ImagePosition], prefill_start_index: int, prefill_end_index: int
):
    """
    Clip multimodal positions to the token window of one prefill chunk.

    A position spans ``[offset, offset + length)``. Positions that do not
    intersect ``[prefill_start_index, prefill_end_index)`` are dropped. For
    the rest, a deep copy is returned whose ``offset`` is re-expressed
    relative to the start of the image's own features (0 unless the image
    began before the chunk). Its ``length`` is trimmed to the part that lies
    inside the chunk. The input objects are not modified.

    Args:
        mm_positions: Positions to clip.
        prefill_start_index: Inclusive start of the prefill window.
        prefill_end_index: Exclusive end of the prefill window.

    Returns:
        List of clipped copies, in input order.
    """
    feature_positions = []
    for pos in mm_positions:
        begin = pos.offset
        end = pos.offset + pos.length
        if end <= prefill_start_index or begin >= prefill_end_index:
            continue  # no overlap with this chunk
        clipped = copy.deepcopy(pos)
        if begin < prefill_start_index:
            # Image started in an earlier chunk: skip the already-prefilled head.
            clipped.offset = prefill_start_index - begin
            clipped.length = min(end, prefill_end_index) - prefill_start_index
        else:
            clipped.offset = 0
            if end > prefill_end_index:
                # Image continues into a later chunk: keep only the head.
                clipped.length = prefill_end_index - begin
        feature_positions.append(clipped)
    logger.debug(
        f"get feature_positions, original positions: {mm_positions}, filtered positions: {feature_positions}"
    )
    return feature_positions
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
    """
    Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1.

    Each request is mapped to its slot ``idx`` through
    ``self.share_inputs.get_index_by_batch_id(request.idx)``. The slot's entries in
    ``self.share_inputs`` are then updated according to ``request.task_type``:

    * PREFILL: input ids, block table, sequence lengths, thinking/reply limits and
      sampling parameters are written into the slot.
    * DECODE: only the block table and ``preempted_idx`` are refreshed.
    * any other type (preempted task): the slot is stopped and its per-request
      bookkeeping (prompt logprobs, forward batch entry, routing replay) is cleared.

    After the loop, ``_process_mm_features`` is called on the whole batch.
    ``seq_lens_this_time`` is then re-sliced from its buffer. With MTP, the proposer
    receives the same tasks.

    Args:
        req_dicts: A list of Request objects.
        num_running_requests: batch_size. Despite the ``None`` default, it is passed
            to ``range()`` below, so callers must supply an int.
    """
    batch_pooling_params = []
    self.share_inputs["num_running_requests"] = num_running_requests
    self.share_inputs["running_requests_ids"] = range(num_running_requests)
    for request in req_dicts:
        idx = self.share_inputs.get_index_by_batch_id(request.idx)
        self.share_inputs["req_ids"][idx] = str(request.request_id)
        if hasattr(request, "pooling_params") and request.pooling_params is not None:
            batch_pooling_params.append(request.pooling_params)
        logits_info = None
        prefill_tokens = []
        if request.task_type.value == RequestType.PREFILL.value:  # prefill task
            self.share_inputs["preempted_idx"][idx : idx + 1, :] = 0
            # guided decoding
            if (
                request.guided_json is not None
                or request.guided_regex is not None
                or request.structural_tag is not None
                or request.guided_grammar is not None
            ):
                logits_info, schemata_key = self._init_logits_processor(request)
                request.schemata_key = schemata_key
            # On a decode-role node the prefill range can extend past the prompt into
            # tokens already held in output_token_ids. Those tokens are handed to
            # apply_logits_processor at the end of this iteration.
            if (
                self.scheduler_config.splitwise_role == "decode"
                and hasattr(request, "prefill_end_index")
                and hasattr(request, "prompt_token_ids")
                and request.prefill_end_index > len(request.prompt_token_ids)
                and hasattr(request, "output_token_ids")
            ):
                prefill_tokens.extend(request.output_token_ids)
            prefill_start_index = request.prefill_start_index
            prefill_end_index = request.prefill_end_index
            # Number of tokens fed in this (possibly chunked) prefill step.
            length = prefill_end_index - prefill_start_index
            if not self.is_pooling_model:
                if request.get("enable_thinking") is not None:
                    enable_thinking = bool(request.get("enable_thinking"))
                    logger.debug(f"request {request.request_id} with {enable_thinking=} at idx {idx}")
                    self.share_inputs["enable_thinking"][idx : idx + 1, :] = enable_thinking
                    if enable_thinking:
                        self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
                        if request.get("reasoning_max_tokens") is not None:
                            # Enable thinking with a reasoning-length limit
                            self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get(
                                "reasoning_max_tokens"
                            )
                        else:
                            self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
                        if request.get("response_max_tokens") is not None:
                            # Limit the reply length that follows the thinking section
                            self.share_inputs["max_reply_lens"][idx : idx + 1, :] = request.get(
                                "response_max_tokens"
                            )
                        else:
                            self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1
                    else:
                        # Disable thinking
                        self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
                        self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1
                        self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
            if isinstance(request.prompt_token_ids, np.ndarray):
                prompt_token_ids = request.prompt_token_ids.tolist()
            else:
                prompt_token_ids = request.prompt_token_ids
            input_ids = prompt_token_ids + request.output_token_ids
            prompt_len = len(prompt_token_ids)
            # Prompt tokens first, then -1 for the rest of the row.
            self.share_inputs["token_ids_all"][idx : idx + 1, :prompt_len] = np.array(
                prompt_token_ids, dtype="int64"
            )
            self.share_inputs["token_ids_all"][idx : idx + 1, prompt_len:] = -1
            # Log complete input_ids for input determinism verification
            # Note: Only current request info is logged here; batch info is logged during forward
            if self.deterministic_logger is not None:
                self.deterministic_logger.log_prefill_input(
                    request.request_id, idx, prefill_start_index, prefill_end_index, input_ids
                )
            logger.debug(
                f"Handle prefill request {request} at idx {idx}, "
                f"{prefill_start_index=}, {prefill_end_index=}, "
                f"need_prefilled_token_num={len(input_ids)}, "
                f"prompt_len={prompt_len}"
            )
            self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
                input_ids[prefill_start_index:prefill_end_index]
            )
            encoder_block_num = len(request.block_tables)
            self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
            self.share_inputs["block_tables"][idx : idx + 1, :] = -1
            self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                request.block_tables, dtype="int32"
            )
            self.share_inputs["stop_flags"][idx : idx + 1] = False
            self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
            self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
            self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
            self.exist_prefill_flag = True
            self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
            self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
            self.share_inputs["is_block_step"][idx : idx + 1] = False
            # True while this chunk does not reach the end of input_ids (more prefill chunks follow).
            self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids)
            self.share_inputs["step_idx"][idx : idx + 1] = (
                len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
            )
            # pooling model request.sampling_params is None
            if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
                self.prompt_logprobs_reqs[request.request_id] = request
            self.forward_batch_reqs_list[idx] = request
            if self.speculative_decoding and self.spec_method == SpecMethod.SUFFIX and self.proposer is not None:
                # prompt_token_ids was already normalized to a list above.
                self.proposer.start_request(idx, request.request_id, prompt_token_ids)
            # Routing Replay
            if self.fd_config.routing_replay_config.enable_routing_replay:
                # 1. prefix task (needs registration) 2. chunk-end task (no registration needed)
                self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id)
            if (
                self.fd_config.scheduler_config.splitwise_role == "decode"
            ):  # In PD disaggregation, decoding continues after the P node generates the first token
                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                self.exist_prefill_flag = False
                self._cached_launch_token_num = -1
                if self.speculative_decoding:
                    # D speculate decode, seq_lens_this_time = length + 1
                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + 1
                    self.share_inputs["draft_tokens"][idx : idx + 1, 0 : length + 1] = paddle.to_tensor(
                        request.draft_token_ids[0 : length + 1],
                        dtype="int64",
                    )
        elif request.task_type.value == RequestType.DECODE.value:  # decode task
            logger.debug(f"Handle decode request {request} at idx {idx}")
            encoder_block_num = len(request.block_tables)
            self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
            self.share_inputs["block_tables"][idx : idx + 1, :] = -1
            if current_platform.is_cuda():
                async_set_value(
                    self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
                )
            else:
                self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.block_tables, dtype="int32"
                )
            self.share_inputs["preempted_idx"][idx : idx + 1, :] = 0
            continue
        else:  # preempted task
            logger.info(f"Handle preempted request {request} at idx {idx}")
            self.share_inputs["preempted_idx"][idx : idx + 1, :] = 1
            self.share_inputs["block_tables"][idx : idx + 1, :] = -1
            self.share_inputs["stop_flags"][idx : idx + 1] = True
            self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0
            self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
            self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
            self.share_inputs["is_block_step"][idx : idx + 1] = False
            self.prompt_logprobs_reqs.pop(request.request_id, None)
            self.in_progress_prompt_logprobs.pop(request.request_id, None)
            self.forward_batch_reqs_list[idx] = None
            # Routing Replay
            if self.fd_config.routing_replay_config.enable_routing_replay:
                self.routing_replay_manager.clear_request(batch_id=idx)
            continue

        # Only prefill requests reach this point: the other branches `continue` above.
        assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
        # eos_token_id is a single buffer (not indexed by idx); every prefill request overwrites it.
        self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
        self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
        self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
        self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
        self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
        self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)
        self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
        self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
        self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
        self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
        self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = request.get("temp_scaled_logprobs", False)
        self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = request.get(
            "top_p_normalized_logprobs", False
        )
        self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
        self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
            "max_tokens", self.model_config.max_model_len
        )
        if request.get("seed") is not None:
            self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")
        if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
            bad_words_len = len(request.get("bad_words_token_ids"))
            self.share_inputs["bad_tokens_len"][idx] = bad_words_len
            self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
                request.get("bad_words_token_ids"), dtype="int64"
            )
        else:
            self.share_inputs["bad_tokens_len"][idx] = 1
            self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")
        if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
            stop_seqs_num = len(request.get("stop_seqs_len"))
            # Pad with zeros up to max_stop_seqs_num; this mutates the request's
            # sampling_params in place.
            for _ in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
                request.sampling_params.stop_seqs_len.append(0)
            self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
                request.sampling_params.stop_seqs_len, dtype="int32"
            )
            self.share_inputs["stop_seqs"][
                idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
            ] = np.array(request.get("stop_token_ids"), dtype="int64")
        else:
            self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
        self.pooling_params = batch_pooling_params
        # For logits processors
        self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {}
        self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens)
    self._process_mm_features(req_dicts)
    # Expose only the first num_running_requests entries of the buffer.
    self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests]
    if self.spec_method == SpecMethod.MTP:
        self.proposer.insert_tasks_v1(req_dicts, num_running_requests, self.share_inputs.index_to_batch_id)
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
    """
    Legacy (non-V1 KV-cache scheduler) task insertion path.

    Not supported on GPU: this always raises ``NotImplementedError``. The path
    used with ENABLE_V1_KVCACHE_SCHEDULER=1 is ``insert_tasks_v1``.

    Raises:
        NotImplementedError: always.
    """
    unsupported_msg = "GPUs only support KVCACHE SCHEDULER V1 in versions 2.6 and above."
    raise NotImplementedError(unsupported_msg)
def get_input_length_list(
self, num_tokens: int, batch_size: int, expected_decode_len: int, capture_prefill: bool = False
):
"""
Generates some list for _dummy_prefill_inputs, when capture pure prefill or mtp,
the list should be carefully constructed.
This function addresses a specific problem: in the pure prefill stage, variable
input lengths (e.g., `prompt[160, 0]` vs. `prompt[80, 80]`) can lead to different
CUDA Grid dimensions for kernels like `split_q_block`. This prevents CUDA Graph
reuse.
The `split_q_block` kernel calculates the total number of blocks, which directly
determines the `griddim.x` launch parameter for the `multi_query_append_attention_kernel`.
The blocks for a single sequence are determined by the formula:
`num_blocks = ceil((sequence_length * group_size) / block_shape_q)`