-
Notifications
You must be signed in to change notification settings - Fork 722
Expand file tree
/
Copy pathcommon_engine.py
More file actions
2405 lines (2151 loc) · 115 KB
/
common_engine.py
File metadata and controls
2405 lines (2151 loc) · 115 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import asyncio
import copy
import json
import multiprocessing
import os
import re
import signal
import subprocess
import sys
import threading
import time
import traceback
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple
import numpy as np
import paddle
import requests
import zmq
from tqdm import tqdm
import fastdeploy.metrics.trace as tracing
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import (
ControlRequest,
ControlResponse,
Request,
RequestOutput,
RequestStatus,
RequestType,
)
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.engine.sched.scheduler_metrics_logger import SchedulerMetricsLogger
from fastdeploy.eplb.utils import init_eplb_signals
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (
EngineCacheQueue,
EngineWorkerQueue,
IPCLock,
IPCSignal,
ZmqIpcServer,
ZmqTcpServer,
)
from fastdeploy.inter_communicator.fmq import FMQ
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.plugins.token_processor import load_token_processor_plugins
from fastdeploy.router.utils import check_service_health
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.trace.constants import LoggingEventName
from fastdeploy.trace.trace_logger import print as trace_print
from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger
try:
    TokenProcessor = load_token_processor_plugins()
    llm_logger.info(f"TokenProcessor plugin {TokenProcessor} loaded")
except Exception:
    # Loading a TokenProcessor plugin failed (presumably none is registered --
    # confirm in fastdeploy.plugins); fall back to the built-in implementation.
    # ``except Exception`` rather than a bare ``except:`` so KeyboardInterrupt
    # and SystemExit raised during import are not swallowed.
    from fastdeploy.output.token_processor import TokenProcessor
class EngineService:
"""
Base class containing common engine functionality
"""
def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False):
    """
    Initialize the engine service from the provided configuration.

    Creates, in order: the control-response queues, the scheduler, the
    KV-cache resource manager (V1 or V0), the engine-worker queue (and, on the
    master node, the queue servers), the splitwise connector, the token
    processor, the scheduler metrics logger, the IPC monitor signals and,
    when enabled, the guided-decoding schema checker.

    Args:
        cfg (FDConfig): Config object containing all the configuration parameters.
        start_queue (bool): forwarded to ``start_worker_queue_service``; when
            true and this host is the master, the engine-worker queue server
            is started.
        use_async_llm (bool): when true, the worker-management attributes used
            by ``start_worker_service`` are initialized here.
    """
    self.cfg = cfg
    self.use_async_llm = use_async_llm
    # With data parallelism, each DP rank logs to its own file.
    if self.cfg.parallel_config.data_parallel_size > 1:
        self.llm_logger = get_logger(
            "fastdeploy", f"fastdeploy_dprank{self.cfg.parallel_config.local_data_parallel_id}.log"
        )
    else:
        self.llm_logger = llm_logger
    self.is_paused = False  # pause request generation
    # The V1 request-fetch loop waits on this condition until ``is_paused`` is False.
    self._pause_cond = threading.Condition()
    # Consumer-side FMQ queues carrying control responses back to the engine,
    # keyed by queue name. There is one worker queue per TP rank of this DP
    # group, plus one cache queue per rank when a CPU / storage KV cache is
    # configured. The rank in each name is the global rank
    # ``tp_rank + tp_size * dp_index``.
    # NOTE(review): the ports embedded in these names are read before
    # ``start_worker_queue_service`` runs, and that call may rewrite
    # ``local_engine_worker_queue_port`` / ``local_cache_queue_port`` when an
    # anonymous port is used. Confirm that producers build the same names.
    self._ctrl_output_queues = {}
    tp_size = cfg.parallel_config.tensor_parallel_size
    dp_index = cfg.parallel_config.local_data_parallel_id
    for tp_rank in range(tp_size):
        # create worker control response queue
        engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
        name = f"ctrl_w2e_rank{tp_rank+tp_size*dp_index}_{engine_worker_queue_port}"
        self.llm_logger.info(f"Init Worker Control Output Queue: {name} (consumer)")
        self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")
        # create cache control response queue
        if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
            engine_cache_queue_port = self.cfg.cache_config.local_cache_queue_port
            name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_index}_{engine_cache_queue_port}"
            self.llm_logger.info(f"Init Cache Control Output Queue: {name} (consumer)")
            self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")
    self.scheduler = cfg.scheduler_config.scheduler()
    self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
    # Select the KV-cache resource manager implementation (V1 vs V0 scheduler).
    if envs.ENABLE_V1_KVCACHE_SCHEDULER:
        self.llm_logger.info("Use V1 KVCache Scheduler")
        self.resource_manager = ResourceManagerV1(
            cfg.scheduler_config.max_num_seqs,
            cfg,
            cfg.parallel_config.tensor_parallel_size,
            cfg.scheduler_config.splitwise_role,
            cfg.parallel_config.local_data_parallel_id,
        )
    else:
        self.llm_logger.info("Use V0 KVCache Scheduler")
        self.resource_manager = ResourceManager(
            cfg.scheduler_config.max_num_seqs,
            cfg,
            cfg.parallel_config.tensor_parallel_size,
            cfg.scheduler_config.splitwise_role,
            cfg.parallel_config.local_data_parallel_id,
        )
    # Runs before INFERENCE_MSG_QUEUE_ID is exported because it may replace
    # an anonymous worker-queue port with the real one.
    self.start_worker_queue_service(start_queue)
    os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.parallel_config.local_engine_worker_queue_port)
    self.llm_logger.info(f"INFERENCE_MSG_QUEUE_ID: {str(self.cfg.parallel_config.local_engine_worker_queue_port)}")
    self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager)
    self.token_processor = TokenProcessor(
        cfg=cfg,
        cached_generated_tokens=self.scheduler,
        engine_worker_queue=self.engine_worker_queue,
        split_connector=self.split_connector,
    )
    self.token_processor.set_resource_manager(self.resource_manager)
    # A single metrics logger is shared by the resource manager and the token processor.
    self.scheduler_metrics_logger = SchedulerMetricsLogger(
        enabled=True,
        dp_rank=self.cfg.parallel_config.local_data_parallel_id,
    )
    self.resource_manager.scheduler_metrics_logger = self.scheduler_metrics_logger
    self.token_processor.set_scheduler_metrics_logger(self.scheduler_metrics_logger)
    # partial_chunked_tokens[k] is the per-request prefill chunk size used when
    # k requests are chunk-prefilled together: max_num_batched_tokens // k,
    # rounded down to a multiple of block_size. Index 0 is unused and stays 0.
    self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
    for idx in range(1, self.cfg.max_num_partial_prefills + 1):
        self.partial_chunked_tokens[idx] = (
            (self.cfg.scheduler_config.max_num_batched_tokens // idx)
            // self.cfg.cache_config.block_size
            * self.cfg.cache_config.block_size
        )
    self.bos_client = None
    self.mm_max_tokens_per_item = None
    self.guided_decoding_checker = None
    if self.cfg.structured_outputs_config.guided_decoding_backend != "off":
        self.guided_decoding_checker = schema_checker(
            self.cfg.structured_outputs_config.guided_decoding_backend,
            disable_any_whitespace=self.cfg.structured_outputs_config.disable_any_whitespace,
        )
    self._init_worker_monitor_signals()
    # Pass the GPU KV cache lock to cache_manager for mutual exclusion
    # between the CPU transfer process and the worker process.
    self.resource_manager.cache_manager.gpu_cache_lock = self.gpu_cache_lock
    if self.cfg.eplb_config.enable_eplb:
        current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
        init_eplb_signals(cfg, current_suffix)
    if self.use_async_llm:
        # Add worker management attributes
        self.worker_proc = None
        # Profile to determine the GPU block count unless it is overridden in the config.
        self.do_profile = 1 if self.cfg.cache_config.num_gpu_blocks_override is None else 0
        self.ipc_signal_suffix = None
        self.cache_manager_processes = None
    # NOTE(review): weakref.finalize is given a bound method of ``self``, which
    # keeps ``self`` reachable. The callback therefore will not fire on garbage
    # collection; it runs only when called explicitly or at interpreter exit
    # (atexit). Confirm this is intended.
    self._finalizer = weakref.finalize(self, self._exit_sub_services)
def start(self, async_llm_pid=None):
    """
    Start the engine service.

    In AsyncLLM mode, the worker processes are launched first. The method then
    starts a daemon thread running the request-scheduling loop (V1 or V0,
    depending on ``ENABLE_V1_KVCACHE_SCHEDULER``) and the token processor. In
    the "decode" splitwise role it also starts splitwise request processing.
    Finally it registers this instance with the router.

    Args:
        async_llm_pid: PID of the AsyncLLM front-end. It is forwarded to
            ``start_worker_service`` when ``use_async_llm`` is enabled.
    """
    self.running = True
    console_logger.debug("Start engineService...")

    if self.use_async_llm:
        # NOTE(review): start_worker_service returns False when workers fail to
        # launch, but that result is not checked here.
        self.start_worker_service(async_llm_pid)

    scheduling_loop = (
        self._schedule_request_to_worker_v1
        if envs.ENABLE_V1_KVCACHE_SCHEDULER
        else self._schedule_request_to_worker
    )
    self.insert_task_to_worker_thread = threading.Thread(target=scheduling_loop, daemon=True)
    self.insert_task_to_worker_thread.start()

    self.token_processor.tasks_queue = self.engine_worker_queue
    self.token_processor.run()

    is_decode_role = self.cfg.scheduler_config.splitwise_role == "decode"
    if is_decode_role:
        self._decode_process_splitwise_requests()

    self._register_to_router()
def start_worker_service(self, async_llm_pid=None):
    """
    Launch the worker processes (and cache-manager processes when needed) for
    AsyncLLM mode, then wait for the model to finish loading.

    Cache-manager placement:
    - If the GPU block count is fixed (no profiling) and the deployment is
      splitwise (role != "mixed"), the cache manager is started before the workers.
    - If the block count must be profiled, profiling is stopped after the
      model has loaded.
    - Otherwise, in the "mixed" role with prefix caching enabled, the cache
      manager is started after the model has loaded.

    Args:
        async_llm_pid: PID of the AsyncLLM front-end. When truthy, a ZMQ
            service for talking to it is started once the workers are up.

    Returns:
        bool: ``True`` when the workers were launched successfully, ``False``
        when worker initialization failed. Previously success returned an
        implicit ``None``, which is falsy like the failure value.
    """
    # Initialize IPC signals for worker management
    self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
    self._init_worker_signals()

    # Create data processor if not exists
    if not hasattr(self, "data_processor"):
        self.create_data_processor()

    # Launch components: scheduler, cache_manager, expert_service, etc.
    self.launch_components()

    # If block number is specified and model is deployed in splitwise mode, start cache manager first
    if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed":
        device_ids = self.cfg.parallel_config.device_ids.split(",")
        self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)

    # Start worker processes
    self.worker_proc = self._start_worker_service()
    time.sleep(5)
    self.worker_init_status = dict()
    result_container = {}

    def check_worker_initialize_status_func(res: dict):
        # Runs in a background thread; records in ``res`` whether worker init succeeded.
        res["worker_is_alive"] = True
        if not self.check_worker_initialize_status():
            self.llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
            res["worker_is_alive"] = False

    self.check_worker_initialize_status_func_thread = threading.Thread(
        target=check_worker_initialize_status_func, args=(result_container,), daemon=True
    )
    self.check_worker_initialize_status_func_thread.start()

    # Wait model loading
    while self.loaded_model_signal.value[0] == 0:
        # Make sure worker process is alive.
        # NOTE(review): this also returns False if the status thread finished
        # successfully before loaded_model_signal was set. Confirm that
        # ordering cannot happen.
        if not self.check_worker_initialize_status_func_thread.is_alive():
            return False
        time.sleep(1)

    # If block number is not specified, let workers do profiling to determine the block number,
    # and then start the cache manager
    if self.do_profile:
        self._stop_profile()
    elif self.cfg.scheduler_config.splitwise_role == "mixed" and self.cfg.cache_config.enable_prefix_caching:
        device_ids = self.cfg.parallel_config.device_ids.split(",")
        self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)

    # Worker launched
    self.check_worker_initialize_status_func_thread.join()
    if not result_container["worker_is_alive"]:
        self.llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
        return False

    # Start ZMQ service for communication with AsyncLLM
    if async_llm_pid:
        self.start_zmq_service(async_llm_pid)
    return True
def create_data_processor(self):
    """
    Build the input preprocessor and its data processor from the config.

    The processor is also asked for the per-item multimodal token budget.
    When it reports one, the cache config is post-processed with the
    resulting maximum chunk size and ``scheduler_config.max_num_seqs``.
    """
    cfg = self.cfg
    self.input_processor = InputPreprocessor(
        cfg.model_config,
        cfg.structured_outputs_config.reasoning_parser,
        cfg.limit_mm_per_prompt,
        cfg.mm_processor_kwargs,
        cfg.tool_parser,
    )
    self.data_processor = self.input_processor.create_processor()

    per_item_budget = self.data_processor.get_mm_max_tokens_per_item(cfg.model_config.max_model_len)
    self.mm_max_tokens_per_item = per_item_budget
    if per_item_budget is None:
        # No multimodal budget reported: the cache config is left unchanged.
        return

    chunk_limit = cfg.get_max_chunk_tokens(per_item_budget)
    cfg.cache_config.postprocess(chunk_limit, cfg.scheduler_config.max_num_seqs)
def _init_worker_monitor_signals(self):
    """
    Create the IPC signals and lock shared between the engine and its
    worker / cache processes.

    Each signal wraps a zero-initialised int32 array and is created
    (``create=True``) under the suffix
    ``parallel_config.local_engine_worker_queue_port``:

    - exist_task_signal: lets worker processes see whether there are new tasks to process.
    - exist_swapped_task_signal: lets the engine see whether workers hold swapped tasks.
    - exist_prefill_task_signal: lets worker processes know whether to run prefill.
    - worker_healthy_live_signal: lets the engine see whether each worker is
      alive by recording the time of each step. It has
      ``min(worker_num_per_node, tensor_parallel_size)`` slots.
    - cache_ready_signal, swap_space_ready_signal, cache_transfer_inited_signal:
      one slot per tensor-parallel rank.
    - model_weights_status, prefix_tree_status, kv_cache_status: single-slot status flags.

    Finally, a file-based ``gpu_cache_lock`` is created for mutual exclusion
    between the worker and the CPU transfer process when they access the GPU KV cache.
    """
    suffix = self.cfg.parallel_config.local_engine_worker_queue_port
    self.llm_logger.info(f"current_suffix: {suffix}")

    tp_size = self.cfg.parallel_config.tensor_parallel_size
    live_slots = min(self.cfg.worker_num_per_node, tp_size)

    # (attribute set on self, IPC signal name, number of int32 slots), in creation order.
    signal_specs = (
        ("exist_task_signal", "exist_task_signal", 1),
        ("exist_swapped_task_signal", "exist_swapped_task_signal", 1),
        ("exist_prefill_task_signal", "exist_prefill_task_signal", 1),
        ("worker_healthy_live_signal", "worker_healthy_live_signal", live_slots),
        ("cache_ready_signal", "cache_ready_signal", tp_size),
        ("swap_space_ready_signal", "swap_space_ready_signal", tp_size),
        ("cache_transfer_inited_signal", "cache_transfer_inited_signal", tp_size),
        ("model_weights_status_signal", "model_weights_status", 1),
        ("prefix_tree_status_signal", "prefix_tree_status", 1),
        ("kv_cache_status_signal", "kv_cache_status", 1),
    )
    for attr_name, signal_name, slots in signal_specs:
        ipc_signal = IPCSignal(
            name=signal_name,
            array=np.zeros([slots], dtype=np.int32),
            dtype=np.int32,
            suffix=suffix,
            create=True,
        )
        setattr(self, attr_name, ipc_signal)

    self.gpu_cache_lock = IPCLock(
        name="gpu_cache_lock",
        suffix=suffix,
        create=True,
    )
def start_worker_queue_service(self, start_queue):
    """
    Set up engine <-> worker task-queue communication.

    The queue address is a Unix socket path under ``/dev/shm`` when
    ``FD_ENGINE_TASK_QUEUE_WITH_SHM`` is set, and ``(master_ip, port)``
    otherwise.

    On the master node (``host_ip == master_ip`` or ``master_ip == "0.0.0.0"``):
    - if ``start_queue`` is true, an ``EngineWorkerQueue`` server is started.
      Over TCP, the port the server actually bound is written back to
      ``parallel_config.local_engine_worker_queue_port``, so an anonymous
      port is resolved.
    - if prefix caching is enabled or the splitwise role is not "mixed", an
      ``EngineCacheQueue`` server is started and its port is written back to
      ``cache_config.local_cache_queue_port``.

    In every case, the engine then connects to the worker queue as client 0.

    Args:
        start_queue: whether this engine should start the worker-queue server.
    """
    parallel_cfg = self.cfg.parallel_config
    if envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
        address = f"/dev/shm/fd_task_queue_{parallel_cfg.local_engine_worker_queue_port}.sock"
    else:
        address = (self.cfg.master_ip, parallel_cfg.local_engine_worker_queue_port)

    on_master = self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"
    if on_master:
        if start_queue:
            self.llm_logger.info(f"Starting engine worker queue server service at {address}")
            self.engine_worker_queue_server = EngineWorkerQueue(
                address=address,
                is_server=True,
                num_client=parallel_cfg.tensor_parallel_size,
                local_data_parallel_size=parallel_cfg.data_parallel_size,
            )
            if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
                # Propagate the port the server actually bound (anonymous-port case).
                parallel_cfg.local_engine_worker_queue_port = self.engine_worker_queue_server.get_server_port()
                address = (self.cfg.master_ip, parallel_cfg.local_engine_worker_queue_port)

        wants_cache_queue = (
            self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed"
        )
        if wants_cache_queue:
            cache_cfg = self.cfg.cache_config
            self.llm_logger.info(f"Starting engine cache queue server service at {cache_cfg.local_cache_queue_port}")
            self.cache_task_queue = EngineCacheQueue(
                address=(self.cfg.master_ip, cache_cfg.local_cache_queue_port),
                authkey=b"cache_queue_service",
                is_server=True,
                num_client=parallel_cfg.tensor_parallel_size,
                client_id=-1,
                local_data_parallel_size=parallel_cfg.data_parallel_size,
            )
            cache_cfg.local_cache_queue_port = self.cache_task_queue.get_server_port()

    self.engine_worker_queue = EngineWorkerQueue(
        address=address,
        is_server=False,
        num_client=parallel_cfg.tensor_parallel_size,
        client_id=0,
        local_data_parallel_size=parallel_cfg.data_parallel_size,
        local_data_parallel_id=parallel_cfg.local_data_parallel_id,
    )
def insert_tasks(self, tasks: List[Request], current_id=-1):
    """
    Allocate resources for new tasks and put them on the engine worker queue.

    Used by the V0 KV-cache scheduler. Steps:

    1. Propagate trace context. In the "prefill" splitwise role, drop tasks
       whose decode-side allocation could not be confirmed; each dropped task
       gets a 500 error result via the scheduler.
    2. Truncate the batch to ``np.sum(resource_manager.stop_flags)`` slots.
    3. Allocate KV-cache resources.
    4. Notify the splitwise connector, update token counters and traces,
       attach chunked-prefill info (unless this is a prefill-side splitwise
       batch), and put the batch on ``engine_worker_queue``.

    Args:
        tasks: a Request or a list of Requests. When a list is passed, tasks
            rejected in step 1 are removed from it in place.
        current_id: forwarded to ``split_connector.send_cache_info_to_messager``
            in the "prefill" role.

    Returns:
        bool: ``True`` when the batch was put on the worker queue; ``False``
        when every task was rejected in step 1, leaving nothing to insert.

    Raises:
        EngineError: if resource allocation returns no tasks.
    """
    if not isinstance(tasks, list):
        tasks = [tasks]
    self.resource_manager.check_and_free_block_tables()

    need_delete_tasks = []
    for task in tasks:
        rid = task.request_id.split("_")[0]
        trace_carrier = task.trace_carrier
        if trace_carrier:
            tracing.trace_set_proc_propagate_context(rid, trace_carrier)
            task.trace_carrier = tracing.trace_get_proc_propagate_context(rid)
        if self.cfg.scheduler_config.splitwise_role == "prefill":
            status, msg = self.split_connector.check_decode_allocated(task)
            if status:
                task.metrics.ask_decode_resource_finish_time = time.time()
            else:
                self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
                self.scheduler.put_results(
                    [
                        RequestOutput(
                            request_id=task.request_id,
                            finished=True,
                            error_code=500,
                            error_msg=msg,
                        )
                    ]
                )
                need_delete_tasks.append(task)
    # Remove rejected tasks in place, so a caller-provided list shrinks too.
    for tmp_task in need_delete_tasks:
        tasks.remove(tmp_task)
    if need_delete_tasks and not tasks:
        # Every task was rejected above and has already received an error
        # result. Stop here rather than hand an empty batch to the allocator;
        # an empty result from it would trigger the misleading
        # "exceed the limit" EngineError below.
        return False

    for item in tasks:
        trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, item.request_id, getattr(item, "user", ""))

    # Never insert more tasks than there are free slots; the excess is dropped.
    available_batch = np.sum(self.resource_manager.stop_flags)
    if len(tasks) > available_batch:
        self.llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
        self.llm_logger.error("The exceeded part will be ignored!")
        tasks = tasks[:available_batch]

    req_ids = [t.request_id for t in tasks]
    tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
    if not tasks:
        error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
        self.llm_logger.error(error_msg)
        raise EngineError(error_msg, error_code=500)

    self.token_processor.number_of_tasks += len(tasks)

    # Detect whether any task carries disaggregation info, and in which role.
    is_decode = False
    is_prefill = False
    for task in tasks:
        if task.disaggregate_info is not None:
            if self.cfg.scheduler_config.splitwise_role == "decode":
                is_decode = True
            else:
                is_prefill = True
        self.token_processor.number_of_input_tokens += task.prompt_token_ids_len

    if self.cfg.scheduler_config.splitwise_role == "prefill":
        self.split_connector.send_cache_info_to_messager(tasks, current_id)
    elif self.cfg.scheduler_config.splitwise_role == "decode":
        self.split_connector.send_cache_info_to_prefill(tasks)

    if not is_decode:
        self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
        for task in tasks:
            if not getattr(task, "has_been_preempted_before", False):
                task.metrics.inference_start_time = time.time()
                tracing.trace_report_span(
                    tracing.TraceSpanName.SCHEDULE,
                    task.request_id.split("_")[0],
                    int(task.metrics.scheduler_recv_req_time * 1e9),
                    int(task.metrics.inference_start_time * 1e9),
                    thread_finish_flag=True,
                )
                trace_print(LoggingEventName.RESOURCE_ALLOCATE_END, task.request_id, getattr(task, "user", ""))
                trace_print(LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", ""))
                trace_print(LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", ""))
            else:
                trace_print(
                    LoggingEventName.RESCHEDULED_INFERENCE_START, task.request_id, getattr(task, "user", "")
                )
    if not is_prefill:
        if not self.cfg.model_config.enable_mm:
            self.update_requests_chunk_size(tasks)
        else:
            self.update_mm_requests_chunk_size(tasks)
    self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
    return True
def _insert_prefilled_requests(self, request_outputs: List[RequestOutput]):
    """
    Decode side: schedule requests whose prefill finished on the prefill instance.

    Used by the V0 KV-cache scheduler. Each output's pre-allocated slot is
    looked up and removed from ``resource_manager.req_dict``. The slot is
    released instead of scheduled in two cases:
    - ``FD_ENABLE_INTERNAL_ADAPTER`` is set and the prefill produced no token;
    - the prefill reported ``error_code != 200``. The error output is then
      forwarded to the scheduler.
    All remaining requests are put on the engine worker queue as one batch.

    Args:
        request_outputs: RequestOutput objects sent by the prefill instance.

    Returns:
        bool: always True.
    """

    def release_slot(slot_idx, request, request_id):
        # Mark the slot free, drop the task, recycle its KV blocks and
        # forget its generated-token counter.
        self.resource_manager.stop_flags[slot_idx] = True
        self.resource_manager.tasks_list[slot_idx] = None
        self.resource_manager._recycle_block_tables(request)
        if request_id in self.token_processor.tokens_counter:
            del self.token_processor.tokens_counter[request_id]

    ready = []
    for out in request_outputs:
        rid = out.request_id
        slot_idx = self.resource_manager.req_dict.pop(rid)
        request = self.resource_manager.tasks_list[slot_idx]

        if envs.FD_ENABLE_INTERNAL_ADAPTER and not out.outputs.token_ids:
            # The first token was EOS in prefill: recycle the resource, nothing to decode.
            release_slot(slot_idx, request, rid)
            self.llm_logger.warning(f"{rid} need not decode after first token")
            continue

        # NOTE(review): the first token is read before error_code is checked.
        # If failed prefill outputs can carry no tokens, this raises IndexError.
        request.prompt_token_ids[0] = out.outputs.token_ids[0]
        request.num_cached_tokens = out.num_cached_tokens

        # Keep the decode-side timestamps, then adopt the prefill-side metrics object.
        out.metrics.decode_recv_req_time = request.metrics.decode_recv_req_time
        out.metrics.decode_preallocate_req_time = request.metrics.decode_preallocate_req_time
        request.metrics = out.metrics
        request.metrics.decode_inference_start_time = time.time()

        mtp_on_decode = (
            self.cfg.speculative_config.method == SpecMethod.MTP
            and self.cfg.scheduler_config.splitwise_role == "decode"
        )
        if mtp_on_decode:
            request.draft_token_ids = copy.deepcopy(out.outputs.draft_token_ids)

        if out.error_code != 200:
            release_slot(slot_idx, request, rid)
            self.scheduler.put_results([out])
            self.llm_logger.warning(f"{rid} prefill failed with msg:{out.error_msg}, recycle resource.")
            continue

        self.token_processor.tokens_counter[rid] = 1
        ready.append(request)

    if ready:
        self.engine_worker_queue.put_tasks((ready, self.resource_manager.real_bsz))
        self.llm_logger.debug(f"put requests to engine worker queue, task:{ready}")
    return True
def task_is_finished(self, index):
    """
    Return whether the task in slot ``index`` is finished.

    Args:
        index: slot index into ``resource_manager.stop_flags``.

    Returns:
        The stop flag stored for that slot; it is truthy when the slot's task is finished.

    Raises:
        IndexError: if ``index`` is not smaller than the number of slots.
    """
    stop_flags = self.resource_manager.stop_flags
    # Explicit check rather than ``assert``, so the validation still runs under ``python -O``.
    if index >= len(stop_flags):
        raise IndexError(f"task index {index} out of range for {len(stop_flags)} slots")
    return stop_flags[index]
def all_tasks_finished(self):
    """
    Return whether every slot's stop flag is set.

    The check compares the sum of ``resource_manager.stop_flags`` with the
    number of slots, so an empty flag list counts as all finished.
    """
    flags = self.resource_manager.stop_flags
    set_count = np.sum(flags)
    return set_count == len(flags)
def update_requests_chunk_size(self, requests):
    """
    Split each request's prompt into prefill chunks and store the chunk
    lengths on the request as ``prefill_chunk_info``.

    Does nothing when chunked prefill is disabled or ``requests`` is empty.

    Chunks are assigned in rounds, each with a budget of
    ``scheduler_config.max_num_batched_tokens`` tokens:

    1. Every unfinished request, in order, gets a chunk of at most
       ``partial_chunked_tokens[n]`` tokens, where ``n`` is the number of
       requests still unfinished at that moment.
    2. While at least one block of budget remains, the request with the
       fewest remaining tokens has its latest chunk extended by a
       block-aligned amount, so leftover budget goes to shorter requests first.
    """
    if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
        return

    block_size = self.cfg.cache_config.block_size
    round_budget = self.cfg.scheduler_config.max_num_batched_tokens
    chunk_limits = self.partial_chunked_tokens

    remaining = [req.prompt_token_ids_len for req in requests]
    chunks = [[] for _ in requests]
    unfinished = len(remaining)

    while unfinished >= 1:
        budget = round_budget

        # Pass 1: one new chunk for every request that still has tokens left.
        for pos, left in enumerate(remaining):
            if left <= 0:
                continue
            size = min(left, chunk_limits[unfinished])
            chunks[pos].append(size)
            budget -= size
            remaining[pos] -= size
            if remaining[pos] <= 0:
                unfinished -= 1

        # Pass 2: give leftover block-aligned budget to the shortest pending request.
        while budget >= block_size:
            pending = [left for left in remaining if left > 0]
            if not pending:
                break
            target = remaining.index(min(pending))
            size = min(
                remaining[target],
                chunk_limits[unfinished],
                budget // block_size * block_size,
            )
            chunks[target][-1] += size
            budget -= size
            remaining[target] -= size
            if remaining[target] <= 0:
                unfinished -= 1

    for req, req_chunks in zip(requests, chunks):
        req.set("prefill_chunk_info", req_chunks)
def update_mm_requests_chunk_size(self, requests):
    """
    Split each multimodal request into prefill chunks and store them on the
    request as ``prefill_chunk_info``.

    Does nothing when chunked prefill is disabled or ``requests`` is empty.
    Chunk boundaries come from the ``get_mm_split_fuse`` custom GPU op.
    Each chunk is a dict holding the slices of ``input_ids``,
    ``token_type_ids``, ``image_type_ids``, ``grid_thw``, ``images`` and
    ``position_ids`` that belong to it. Empty image-related slices are
    replaced by ``None``.

    NOTE(review): assumes ``request.multimodal_inputs`` has the keys
    "images", "image_type_ids", "grid_thw", "input_ids", "token_type_ids"
    and "position_ids" -- confirm against the data processor.
    """
    if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
        return
    for request in requests:
        inputs = request.multimodal_inputs
        # Handle requests with no images or videos: substitute empty arrays so
        # the slicing below works uniformly.
        if inputs["images"] is None:
            inputs["image_type_ids"] = np.array([], dtype="int32")
            inputs["grid_thw"] = np.array([], dtype="int64")
            inputs["images"] = np.array([], dtype="uint8")
        input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64")
        image_type_ids = paddle.to_tensor(inputs["image_type_ids"], dtype="int32")
        image_mask = input_ids == self.data_processor.image_patch_id
        # image_token_sum[i] = number of image-patch tokens in input_ids[:i]
        # (prefix sum with a leading 0).
        image_token_sum = paddle.full(shape=[len(input_ids) + 1], fill_value=0, dtype="int32")
        image_token_sum[1:] = paddle.cumsum(image_mask.cast("int32"), dtype="int32")
        # Each grid_thw row is presumably (t, h, w). Rows with t == 1 are kept;
        # other rows are expanded into t // 2 rows of [2, h, w].
        # NOTE(review): an odd t > 1 loses one frame here -- confirm intended.
        grid_thw = []
        for one in inputs["grid_thw"]:
            if one[0] == 1:
                grid_thw.append(one)
            else:
                grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2))
        grid_thw = paddle.to_tensor(grid_thw, dtype="int64")
        from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse

        # Custom GPU op. Its two outputs are used below as the per-chunk number
        # of grid_thw rows (chunk_image_num) and the per-chunk token count
        # (chunk_seq_len). partial_chunked_tokens[1] is the single-request chunk
        # size. The meaning of the literal 0 / 2048 arguments is not visible
        # here -- TODO confirm against the op's definition.
        chunk_image_num, chunk_seq_len = get_mm_split_fuse(
            input_ids,
            image_type_ids,
            image_token_sum,
            grid_thw,
            self.data_processor.image_patch_id,
            len(grid_thw),
            0,
            len(input_ids),
            0,
            self.partial_chunked_tokens[1],
            2048,
        )
        grid_thw = grid_thw.numpy().reshape([-1, 3])
        num_chunks = len(chunk_image_num)
        chunks_info = []
        # Running start offsets into each per-modality array.
        input_ids_st, image_type_ids_st, grid_thw_st, patch_st = 0, 0, 0, 0
        for idx in range(num_chunks):
            chunk_input_ids = inputs["input_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
            chunk_token_type_ids = inputs["token_type_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
            # Sum of the first grid_thw column over this chunk's rows; this is
            # how many image_type_ids entries the chunk consumes.
            actual_image_num = np.sum(grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx], 0])
            chunk_image_type_ids = inputs["image_type_ids"][
                image_type_ids_st : image_type_ids_st + actual_image_num
            ]
            chunk_grid_thw = grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx]]
            # Sum over rows of the product of the three grid_thw columns; this is
            # how many rows of ``images`` the chunk consumes.
            chunk_patch_num = np.sum(np.prod(chunk_grid_thw, axis=1))
            chunk_images = inputs["images"][patch_st : patch_st + chunk_patch_num]
            chunk_position_ids = inputs["position_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
            chunks_info.append(
                {
                    "input_ids": chunk_input_ids,
                    "token_type_ids": chunk_token_type_ids,
                    "image_type_ids": (chunk_image_type_ids if chunk_image_type_ids.shape[0] else None),
                    "grid_thw": (chunk_grid_thw if chunk_grid_thw.shape[0] else None),
                    "images": (chunk_images if chunk_images.shape[0] else None),
                    "position_ids": chunk_position_ids,
                }
            )
            input_ids_st += chunk_seq_len[idx]
            image_type_ids_st += actual_image_num
            grid_thw_st += chunk_image_num[idx]
            patch_st += chunk_patch_num
        request.set("prefill_chunk_info", chunks_info)
def _schedule_request_to_worker(self):
    """
    Scheduling loop for the V0 KV-cache scheduler, run in a background thread.

    Polls the scheduler and inserts fetched tasks into the engine via
    ``insert_tasks``. Runs while ``self.running`` is truthy; if the attribute
    is unset, the loop still runs.

    Each iteration backs off briefly, without fetching, when any of these hold:
    - there is no free batch slot;
    - the worker queue still holds tasks;
    - a prefill is in progress (per ``exist_prefill_task_signal``) in the
      "mixed" role, or while the splitwise connector has pending splitwise tasks;
    - cache infos are pending on the worker queue;
    - the splitwise connector still tracks in-flight request ids.

    Exceptions are logged and the loop continues.
    """
    tracing.trace_set_thread_info("Scheduler Task to Work")
    current_id = 0
    while getattr(self, "running", True):
        try:
            if self.resource_manager.available_batch() == 0:
                time.sleep(0.001)
                continue
            if self.engine_worker_queue.exist_tasks():
                time.sleep(0.001)
                continue
            if hasattr(self, "exist_prefill_task_signal") and self.exist_prefill_task_signal.value[0] > 0:
                if (
                    self.cfg.scheduler_config.splitwise_role == "mixed"
                    or self.split_connector.has_splitwise_tasks()
                ):
                    time.sleep(0.005)
                    continue
            if self.engine_worker_queue.num_cache_infos() > 0:
                time.sleep(0.001)
                continue
            if len(self.split_connector.current_request_ids) > 0:
                time.sleep(0.001)
                continue
            num_prefill_batch = min(
                int(self.resource_manager.available_batch()),
                self.cfg.max_prefill_batch,
            )
            self.resource_manager.check_and_free_block_tables()
            tasks = self.scheduler.get_requests(
                available_blocks=self.resource_manager.available_block_num(),
                block_size=self.cfg.cache_config.block_size,
                reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens,
                batch=num_prefill_batch,
            )
            # Drop requests that were aborted while queued.
            tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set]
            for task in tasks:
                task.metrics.engine_get_req_time = time.time()
                trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
            if len(tasks) == 0:
                time.sleep(0.001)
                continue
            if self.cfg.scheduler_config.splitwise_role == "decode":
                # TODO: refine scheduler to remove this limitation
                # Decode will process and schedule the request sent by prefill to engine,
                # so the same request sent by the decode api server will be ignored
                continue
            self.llm_logger.debug(f"get tasks from scheduler: {tasks}")
            if self.cfg.scheduler_config.splitwise_role != "mixed":
                for task in tasks:
                    task.metrics.ask_decode_resource_start_time = time.time()
                self.split_connector.send_splitwise_tasks(tasks, current_id)
            insert_successful = self.insert_tasks(tasks, current_id)
            if insert_successful:
                current_id = current_id + 1
            else:
                continue
            # Metrics are only updated for batches that were actually inserted.
            main_process_metrics.num_requests_waiting.dec(len(tasks))
            main_process_metrics.num_requests_running.inc(len(tasks))
        except Exception as e:
            # NOTE(review): the loop retries immediately after an exception,
            # with no back-off. A persistent error turns into a tight
            # error-logging loop.
            err_msg = f"Error happened while insert task to engine: {e}, {traceback.format_exc()!s}."
            self.llm_logger.error(err_msg)
def _schedule_request_to_worker_v1(self):
"""
Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
"""
tracing.trace_set_thread_info("Scheduler Task to Work")
get_request_pool = ThreadPoolExecutor(max_workers=1)
is_fetching = False
def _fetch_request():
try:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
nonlocal is_fetching
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch,
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens
else:
max_num_batched_tokens = self.cfg.model_config.max_model_len
# In multi-mode scenarios, using available_block_num to pull requests to prevent heavy rescheduling
# in the frequency domain due to insufficient blocks
if self.cfg.model_config.enable_mm:
self.resource_manager.check_and_free_block_tables()
available_blocks = self.resource_manager.available_block_num()
else:
available_blocks = self.cfg.cache_config.max_block_num_per_seq
tasks = self.scheduler.get_requests(
available_blocks=available_blocks,
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=0, # self.cfg.cache_config.enc_dec_block_num
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)
tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set]
for task in tasks:
task.metrics.engine_get_req_time = time.time()
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
if self.cfg.scheduler_config.splitwise_role == "decode":
# TODO: refine scheduler to remove this limitation
# Decode will process and schedule the request sent by prefill to engine,
# so the same request sent by the decode api server will be ignored
is_fetching = False
return
if tasks:
self.llm_logger.debug(
f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}"
)
if self.cfg.scheduler_config.splitwise_role == "prefill":
for task in tasks:
# start async preprocess
self.resource_manager.apply_async_preprocess(task)
need_delete_tasks = []
if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.debug(
f"P has allocated resources and then ask D resource for request: {task.request_id}"
)
task.metrics.ask_decode_resource_start_time = time.time()
while True:
self.split_connector.send_splitwise_tasks([task], task.idx)
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(
f"D failed to allocate resource for request {task.request_id}, try again."
)
time.sleep(0.05)
else:
task.metrics.ask_decode_resource_finish_time = time.time()
break
self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}")
else:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.debug(
f"P has allocated resources and then ask D resource for req_id: {task.request_id}"
)
task.metrics.ask_decode_resource_start_time = time.time()
self.split_connector.send_splitwise_tasks([task], task.idx)
for task in tasks:
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
task.metrics.ask_decode_resource_finish_time = time.time()
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
# to send cache info to cache messager
if tasks:
need_check_req_ids = [task.request_id for task in tasks]
self.split_connector.send_cache_info_to_messager(tasks, 0)
# ensure cache tasks has sent to cache_messager
need_check_req_ids = [task.request_id for task in tasks]
finished_ids, delete_tasks_list = [], []
while need_check_req_ids:
finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req())
self.llm_logger.debug(
f"P has successfully sent cache infos to cache messager for requests: {finished_ids}"
)
if finished_ids:
for task in tasks:
result = self.resource_manager.waiting_async_process(task)
if result is None:
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=task.error_code,
error_msg=task.error_message,
)
]
)
need_check_req_ids.remove(task.request_id)
delete_tasks_list.append(task)
elif result is False:
if task.request_id in finished_ids:
need_check_req_ids.remove(task.request_id)
finished_ids.remove(task.request_id)
else:
time.sleep(0.001)
for tmp_task in delete_tasks_list:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
# Fetch requests and add them to the scheduling queue
if tasks:
for task in tasks:
task.metrics.add_req_to_resource_manager_time = time.time()
trace_print(
LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")
)
if self.cfg.scheduler_config.splitwise_role == "prefill":
self.resource_manager.add_request_in_p(tasks)
self.llm_logger.info(
f"P add requests into running queue: {[task.request_id for task in tasks]}"
)
else:
for task in tasks:
self.resource_manager.add_request(task)