-
Notifications
You must be signed in to change notification settings - Fork 723
Expand file tree
/
Copy pathworker_process.py
More file actions
1173 lines (1035 loc) · 45.4 KB
/
worker_process.py
File metadata and controls
1173 lines (1035 loc) · 45.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
import json
import os
import time
from typing import Tuple
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from fastdeploy import envs
from fastdeploy.config import (
CacheConfig,
DeviceConfig,
EarlyStopConfig,
EPLBConfig,
ErnieArchitectures,
FDConfig,
GraphOptimizationConfig,
LoadConfig,
ModelConfig,
ParallelConfig,
PlasAttentionConfig,
RoutingReplayConfig,
SpeculativeConfig,
StructuredOutputsConfig,
)
from fastdeploy.engine.request import RequestType
from fastdeploy.eplb.async_expert_loader import (
MODEL_MAIN_NAME,
REARRANGE_EXPERT_MAGIC_NUM,
create_mmap,
load_tensor_from_shm_mem,
)
from fastdeploy.eplb.experts_manager import RedundantExpertManager
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import (
ExistTaskStatus,
IPCLock,
IPCSignal,
ModelWeightsStatus,
RearrangeExpertStatus,
)
from fastdeploy.model_executor.layers.quantization import parse_quant_config
from fastdeploy.model_executor.utils import v1_loader_support
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import get_logger, optional_type
from fastdeploy.worker.worker_base import WorkerBase
# Module-wide logger for the worker process, named "worker_process"; the second
# argument presumably selects the log file name — confirm in fastdeploy.utils.get_logger.
logger = get_logger("worker_process", "worker_process.log")
def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
    """
    Construct the worker implementation that matches the current hardware platform.

    Platforms are probed in a fixed order and the first match wins. The DCU check runs
    before the CUDA check (presumably because a DCU platform may also report
    ``is_cuda()`` — confirm before reordering). Each worker module is imported lazily,
    so only the selected platform's worker module is loaded.

    Args:
        fd_config (FDConfig): Full FastDeploy configuration, forwarded to the worker.
        local_rank (int): Forwarded to the worker's ``local_rank`` argument.
        rank (int): Forwarded to the worker's ``rank`` argument.

    Returns:
        WorkerBase: The platform-specific worker instance.

    Raises:
        NotImplementedError: If logprob output is requested on a platform other than
            CUDA or XPU, or if no worker implementation exists for the current platform.
    """
    if fd_config.model_config.enable_logprob and not current_platform.is_cuda() and not current_platform.is_xpu():
        raise NotImplementedError("Only CUDA and XPU platforms support logprob.")

    if current_platform.is_dcu():
        from fastdeploy.worker.dcu_worker import DcuWorker

        return DcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)

    if current_platform.is_cuda():
        from fastdeploy.worker.gpu_worker import GpuWorker

        return GpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)

    if current_platform.is_xpu():
        from fastdeploy.worker.xpu_worker import XpuWorker

        return XpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)

    if current_platform.is_iluvatar():
        from fastdeploy.worker.iluvatar_worker import IluvatarWorker

        return IluvatarWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)

    if current_platform.is_gcu():
        from fastdeploy.worker.gcu_worker import GcuWorker

        return GcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)

    if current_platform.is_maca():
        from fastdeploy.worker.metax_worker import MetaxWorker

        return MetaxWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)

    if current_platform.is_intel_hpu():
        from fastdeploy.worker.hpu_worker import HpuWorker

        return HpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)

    # Without this, an unsupported platform would silently return None and the caller
    # would only fail later with an opaque AttributeError on first use of the worker.
    raise NotImplementedError("No worker implementation is available for the current platform.")
def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
    """
    Initialize Paddle Fleet for collective tensor-parallel execution.

    All ranks reported by ``dist.get_world_size()`` are placed in a single
    model-parallel group (mp_degree = world size; dp/pp/sharding degrees = 1).

    Args:
        seed (int): Seed used for tensor-parallel parameter initialization.

    Returns:
        Tuple[int, int]: ``(world_size, local_rank)``. If the world size is not
        positive, Fleet is not initialized and the local rank is 0.
    """
    world_size = dist.get_world_size()
    strategy = fleet.DistributedStrategy()

    if world_size <= 0:
        # Nothing to initialize; act as rank 0.
        return world_size, 0

    strategy.hybrid_configs = dict(
        dp_degree=1,
        mp_degree=world_size,
        pp_degree=1,
        sharding_degree=1,
    )
    # Seed used for parameter initialization under tensor parallelism.
    strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
    fleet.init(is_collective=True, strategy=strategy)

    return world_size, fleet.worker_index()
def update_fd_config_for_mm(fd_config: FDConfig) -> None:
    """
    Copy parallel settings and dtype into the model config for Ernie multimodal models.

    Only applies when ``model_config.enable_mm`` is set AND the architectures contain an
    Ernie architecture. In that case:
      * the model config's tensor-parallel size and rank are copied from the parallel config;
      * the vision sub-config's dtype is set to the model dtype.
    Otherwise the config is left untouched. Mutates ``fd_config`` in place.
    """
    model_cfg = fd_config.model_config
    archs = model_cfg.architectures
    if not (model_cfg.enable_mm and ErnieArchitectures.contains_ernie_arch(archs)):
        return

    parallel_cfg = fd_config.parallel_config
    model_cfg.tensor_model_parallel_size = parallel_cfg.tensor_parallel_size
    model_cfg.tensor_parallel_rank = parallel_cfg.tensor_parallel_rank
    model_cfg.vision_config.dtype = model_cfg.dtype
class PaddleDisWorkerProc:
    """
    Paddle Distributed wrapper for fastdeploy.worker.Worker,
    for handling single-node multi-GPU tensor parallel.

    The wrapper internally executes an event loop that continuously executes requests
    in the task queue. Control flow between the engine and the workers is transmitted
    through IPC signals and the engine/worker task queue.

    event_loop_normal() relies on state created by start_task_queue_service()
    (self.task_queue) and init_health_status() (the IPC signals), so those must run first.
    """

    def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0) -> None:
        """
        Store the configuration and construct the platform-specific worker.

        The task queue and IPC signals are NOT created here; see
        start_task_queue_service() and init_health_status().

        Args:
            fd_config (FDConfig): Full FastDeploy configuration. Its parallel, cache,
                scheduler and EPLB sub-configs are cached as attributes.
            ranks (int): Number of distributed ranks (presumably the world size returned
                by init_distributed_environment).
            local_rank (int): Rank of this process.
        """
        self.ranks = ranks
        self.local_rank = local_rank
        self.fd_config = fd_config
        self.parallel_config = fd_config.parallel_config
        self.cache_config = fd_config.cache_config
        self.scheduler_config = fd_config.scheduler_config
        self.eplb_config = fd_config.eplb_config

        # TODO(gongshaotian): Use worker factory to get worker
        # NOTE(review): the world size (self.ranks) is passed as `rank`; confirm that this
        # is what the worker expects.
        self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks)
        # IPC signal arrays are indexed modulo this per-node chip count.
        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8

    def init_health_status(self) -> None:
        """
        Attach to the engine's IPC health/status signals and announce this worker.

        Every signal is opened with ``create=False``, so it must already have been
        created by another process (presumably the engine).

        Worker Status:
            launched_expert_service_signal: only with expert parallel, DP > 1 and no
                multi API server; blocks until this DP group's slot becomes non-zero.
            worker_ready_signal: this worker sets its slot to 1.
            worker_healthy_live_signal: this worker stamps its slot with the current time.
            model_weights_status: single int32 status used for dynamic weight loading.
            kv_cache_status: single int32 status mirrored from the weight status.
            exist_task_signal: single int32 flag telling TP ranks new tasks are available.
            exist_swapped_task_signal: single int32 flag.
            exist_prefill_task_signal: single int32 flag, written after every model step.
            gpu_cache_lock: file-based lock shared with the CPU cache transfer process.
        """
        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
        if (
            self.parallel_config.enable_expert_parallel
            and self.parallel_config.data_parallel_size > 1
            and not envs.FD_ENABLE_MULTI_API_SERVER
        ):
            launched_expert_service_signal_data = np.zeros(
                shape=[self.parallel_config.data_parallel_size // self.fd_config.nnode], dtype=np.int32
            )
            self.launched_expert_service_signal = IPCSignal(
                name="launched_expert_service_signal",
                array=launched_expert_service_signal_data,
                dtype=np.int32,
                suffix=self.parallel_config.engine_pid,
                create=False,
            )
            # Wait until the expert service for this DP group reports it has launched.
            while (
                self.launched_expert_service_signal.value[
                    self.parallel_config.local_data_parallel_id % self.max_chips_per_node
                ]
                == 0
            ):
                # Yield the CPU while polling instead of spinning a core at 100%.
                time.sleep(0.001)

        # init worker_ready_signal
        array_size = min(
            self.max_chips_per_node,
            self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size,
        )
        workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
        self.worker_ready_signal = IPCSignal(
            name="worker_ready_signal",
            array=workers_ready,
            dtype=np.int32,
            suffix=self.parallel_config.engine_pid,
            create=False,
        )
        self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1

        # init worker_healthy_live_signal
        workers_alive = np.zeros(shape=[min(array_size, self.parallel_config.tensor_parallel_size)], dtype=np.int32)
        self.worker_healthy_live_signal = IPCSignal(
            name="worker_healthy_live_signal",
            array=workers_alive,
            dtype=np.int32,
            suffix=self.parallel_config.engine_worker_queue_port,
            create=False,
        )
        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
        self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())

        # init model_weights_status
        workers_model_weights = np.zeros(shape=[1], dtype=np.int32)
        self.model_weights_status = IPCSignal(
            name="model_weights_status",
            array=workers_model_weights,
            dtype=np.int32,
            suffix=self.parallel_config.engine_worker_queue_port,
            create=False,
        )

        # init kv_cache_status
        kv_cache_status_data = np.zeros(shape=[1], dtype=np.int32)
        self.kv_cache_status = IPCSignal(
            name="kv_cache_status",
            array=kv_cache_status_data,
            dtype=np.int32,
            suffix=self.parallel_config.engine_worker_queue_port,
            create=False,
        )

        # init exist_task_signal
        workers_exist_task = np.zeros([1], dtype=np.int32)
        self.exist_task_signal = IPCSignal(
            name="exist_task_signal",
            array=workers_exist_task,
            dtype=np.int32,
            suffix=self.parallel_config.engine_worker_queue_port,
            create=False,
        )

        # init exist_swapped_task_signal
        workers_swapped_task = np.zeros(shape=[1], dtype=np.int32)
        self.exist_swapped_task_signal = IPCSignal(
            name="exist_swapped_task_signal",
            array=workers_swapped_task,
            dtype=np.int32,
            suffix=self.parallel_config.engine_worker_queue_port,
            create=False,
        )

        # init exist_prefill_task_signal
        exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
        self.exist_prefill_task_signal = IPCSignal(
            name="exist_prefill_task_signal",
            array=exist_prefill_task_signal_data,
            dtype=np.int32,
            suffix=self.parallel_config.engine_worker_queue_port,
            create=False,
        )

        # gpu_cache_lock: file-based lock for mutual exclusion between worker
        # and CPU transfer when accessing GPU KV cache.
        self.gpu_cache_lock = IPCLock(
            name="gpu_cache_lock",
            suffix=self.parallel_config.engine_worker_queue_port,
            create=False,
        )

    def update_weights_from_tensor(self, mmap_infos):
        """
        Apply a rearranged set of expert weights published through shared memory (EPLB).

        Steps:
          1. Poll every 0.1 s until ``experts_manager.tensor_infos`` is set.
          2. Load the state dicts from shared memory.
          3. Refresh the model's expert rank table from the experts manager.
          4. Load the new weights into the model.
          5. Reset ``tensor_infos`` to None.

        Args:
            mmap_infos: Mapping produced by ``create_mmap``; the ``MODEL_MAIN_NAME``
                entry is used.
        """
        while self.experts_manager.tensor_infos is None:
            time.sleep(0.1)

        state_dicts = load_tensor_from_shm_mem(self.experts_manager.tensor_infos, mmap_infos[MODEL_MAIN_NAME], logger)
        rank_expert_list, logical_to_physical_map, expert_count = self.experts_manager.get_ep_rank_to_expert_id_list()
        self.worker.get_model().redundant_table_manger.update_expert_rank_table(
            rank_expert_list, logical_to_physical_map, expert_count
        )
        # TO BE FIXED
        self.worker.get_model().update_state_dict(state_dicts)
        self.experts_manager.tensor_infos = None

    def _broadcast_model_weights_signal(self, src: int, group) -> int:
        """Broadcast ``self.model_weights_signal[0]`` from rank ``src`` and return the received value."""
        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
        value = model_weights_signal_tensor.numpy()[0]
        return int(value)

    def _tp_barrier_wait(self):
        """
        Synchronize the tensor-parallel ranks.

        XPU uses the task queue's worker-process barrier; all other platforms use a
        Paddle collective barrier on the TP group.
        """
        if current_platform.is_xpu():
            self.task_queue.worker_process_tp_barrier.wait()
        else:
            paddle.distributed.barrier(self.parallel_config.tp_group)

    def _init_eplb_signal(self):
        """
        Set up EPLB (expert load balancing) state and IPC signals; no-op if EPLB is disabled.

        Attributes and signals created:
          * TP rank 0 only: ``signal_update_weight_from_tensor`` and
            ``rearrange_experts_status`` (suffixed per DP group).
          * All ranks: the per-TP-rank ``local_experts_token_stats`` (layers x experts)
            and ``signal_clear_experts_token_stats`` signals.
          * All ranks: the shared-memory map used to receive rearranged weights.
        """
        if not self.eplb_config.enable_eplb:
            return

        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
        self.last_dump_expert_workload_ts = 0
        self.experts_manager = RedundantExpertManager(
            rank=self.local_rank,
            ep_size=self.ranks,
            fd_config=self.fd_config,
            ipc_signal_suffix=self.parallel_config.engine_worker_queue_port,
        )

        dp_ipc_signal_suffix = (
            f"{self.parallel_config.engine_worker_queue_port}_dp{self.parallel_config.local_data_parallel_id}"
        )
        if local_rank == 0:  # master rank0
            signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
            self.signal_update_weight_from_tensor_array = IPCSignal(
                name="signal_update_weight_from_tensor",
                array=signal_update_weight_from_tensor,
                dtype=np.int32,
                suffix=dp_ipc_signal_suffix,
                create=False,
            )

            rearrange_experts_status = np.zeros([1], dtype=np.int32)
            self.rearrange_experts_signal = IPCSignal(
                name="rearrange_experts_status",
                array=rearrange_experts_status,
                dtype=np.int32,
                suffix=dp_ipc_signal_suffix,
                create=False,
            )

        tp_ipc_signal_suffix = f"{dp_ipc_signal_suffix}_tp{local_rank}"
        experts_token_stats = np.zeros(
            (self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.moe_num_experts),
            dtype=np.int32,
        )
        self.local_experts_token_stats_array = IPCSignal(
            name="local_experts_token_stats",
            array=experts_token_stats,
            dtype=np.int32,
            suffix=tp_ipc_signal_suffix,
            create=False,
        )

        clear_experts_token_stats = np.zeros([1], dtype=np.int32)
        self.signal_clear_experts_token_stats = IPCSignal(
            name="signal_clear_experts_token_stats",
            array=clear_experts_token_stats,
            dtype=np.int32,
            suffix=tp_ipc_signal_suffix,
            create=False,
        )

        self.mmap_infos = create_mmap(
            [MODEL_MAIN_NAME],
            self.local_rank,
            self.ranks,
            shm_uuid=self.parallel_config.engine_worker_queue_port,
            eplb_config=self.eplb_config,
            logger=logger,
        )

    def _run_eplb(self, tp_rank):
        """
        Run one EPLB iteration; called every event-loop step, no-op if EPLB is disabled.

        Steps:
          1. Once per ``redundant_expert_dump_workload_interval`` seconds, export the
             model's per-expert token statistics to shared memory. They are cleared
             first if the clear signal was raised.
          2. Rank 0 broadcasts whether a weight update was requested; every rank takes
             part in the broadcast.
          3. If an update was requested, all ranks reload expert weights and barrier,
             and TP rank 0 marks the rearrangement DONE.
        """
        if not self.eplb_config.enable_eplb:
            return

        rearrange_time = time.time()
        # Get expert load
        if self.local_experts_token_stats_array.value is not None and (
            int(rearrange_time) - self.last_dump_expert_workload_ts
            > self.eplb_config.redundant_expert_dump_workload_interval
        ):
            self.last_dump_expert_workload_ts = int(rearrange_time)
            clear_stat = False
            if self.signal_clear_experts_token_stats.value[0] == 1:
                clear_stat = True
                self.signal_clear_experts_token_stats.value[0] = 0
            (
                new_stats_array,
                _,
                _,
                _,
            ) = self.worker.get_model().redundant_table_manger.get_expert_tokens_stats(clear_stat=clear_stat)
            self.local_experts_token_stats_array.value[:] = new_stats_array[:]
        elif self.local_experts_token_stats_array.value is None:
            logger.warning("redundant_expert: local_experts_token_stats not init")

        # All DP synchronously update weights
        broadcast_value = 0
        # Short-circuit matters: the update signal only exists on TP rank 0.
        if tp_rank == 0 and self.signal_update_weight_from_tensor_array.value[0] == 1:
            logger.info("redundant_expert: update_weight_from_tensor broadcast signal")
            self.signal_update_weight_from_tensor_array.value[0] = 0
            broadcast_value = REARRANGE_EXPERT_MAGIC_NUM
        # Collective call: every rank must reach this broadcast each iteration.
        data = paddle.to_tensor([broadcast_value])
        paddle.distributed.broadcast(data, 0)
        if data[0] == REARRANGE_EXPERT_MAGIC_NUM:
            self.update_weights_from_tensor(self.mmap_infos)
            logger.info(
                f"redundant_expert: update_weight_from_tensor success, cost {(time.time() - rearrange_time)*1000}ms"
            )
            paddle.distributed.barrier()
            if tp_rank == 0:
                self.rearrange_experts_signal.value[0] = RearrangeExpertStatus.DONE.value
                logger.info("redundant_expert: done")

    def _acquire_kvcache_lock(self, tp_rank):
        """Acquire the GPU KV cache lock for the worker process.

        Uses a file-based lock (fcntl.flock) to ensure mutual exclusion
        between the worker and the CPU transfer process during model
        execution. Only rank 0 acquires the lock to avoid deadlock among
        tensor-parallel workers. No-op unless FD_USE_KVCACHE_LOCK is set.

        Args:
            tp_rank: Tensor parallel rank of the current worker. Only rank 0
                acquires the lock.
        """
        if not envs.FD_USE_KVCACHE_LOCK:
            return
        if tp_rank == 0:
            self.gpu_cache_lock.acquire()

    def _release_kvcache_lock(self, tp_rank):
        """Release the GPU KV cache lock held by the worker process.

        No-op unless FD_USE_KVCACHE_LOCK is set.

        Args:
            tp_rank: Tensor parallel rank of the current worker. Only rank 0
                releases the lock.
        """
        if not envs.FD_USE_KVCACHE_LOCK:
            return
        if tp_rank == 0:
            self.gpu_cache_lock.release()

    def event_loop_normal(self) -> None:
        """Main event loop for Paddle Distributed Workers.

        Each iteration:
          1. Runs the EPLB hook (no-op unless EPLB is enabled).
          2. With dynamic weight loading, reads the weight status and broadcasts rank
             0's value to all ranks.
          3. Refreshes this rank's liveness timestamp.
          4. On TP rank 0, flags newly queued tasks; all TP ranks then synchronize.
          5. If the weight status is not NORMAL, performs the weight update/clear and
             skips to the next iteration.
          6. Fetches and preprocesses new tasks when flagged.
          7. Executes one model step, or (non-EP) sleeps 1 ms when there is nothing to
             run.

        Loops forever (there is no break/return). Requires start_task_queue_service()
        and init_health_status() to have been called.

        TODO(gongshaotian): support remote calling of functions that control worker.
        """
        # init eplb signal
        self._init_eplb_signal()
        tp_size = self.parallel_config.tensor_parallel_size
        # Number of nodes spanned by the TP group, assuming 8 chips per node.
        # (Original note: currently, only single node is supported.)
        self.nnode = int((tp_size + 7) // 8)
        max_occupied_batch_index = 0
        tp_rank = self.local_rank % tp_size

        # TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
        self.model_weights_signal = np.zeros([1], dtype=np.int32)
        while True:
            # run eplb
            self._run_eplb(tp_rank)

            if self.fd_config.load_config.dynamic_load_weight:
                self.model_weights_signal[0] = int(self.model_weights_status.value[0])
                if self.ranks > 1:
                    # Make every rank act on rank 0's view of the weight status.
                    self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)

            self.insert_step = False
            req_dicts = None
            self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time())

            # The first worker detects whether there are tasks in the task queue
            if tp_rank == 0:
                if self.task_queue.num_tasks() > 0:
                    if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
                        self.fd_config.model_config.enable_mm and self.worker.exist_prefill()
                    ):
                        if self.nnode > 1 and tp_size > self.max_chips_per_node:
                            self.task_queue.read_finish_flag.set(1)
                        else:
                            self.exist_task_signal.value[0] = ExistTaskStatus.EXIST

            if tp_size > 1:
                # Synchronize the signal for other workers
                self._tp_barrier_wait()

            if self.fd_config.load_config.dynamic_load_weight:
                if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
                    if self.ranks > 1:
                        paddle.distributed.barrier()
                    logger.info(
                        f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
                    )
                    from fastdeploy.rl.dynamic_weight_manager import (
                        DynamicWeightManager,
                    )

                    self.model_weights_status.value[0] = self.model_weights_signal[0]
                    self.kv_cache_status.value[0] = self.model_weights_signal[0]
                    DynamicWeightManager.check_model_weights_status(
                        self.model_weights_status,
                        self.kv_cache_status if self.fd_config.cache_config.num_cpu_blocks > 0 else None,
                        # model_weights_signal
                        self.worker.model_runner,
                        self.parallel_config.engine_worker_queue_port,
                        self.parallel_config.shutdown_comm_group_if_worker_idle,
                    )
                    logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
                    # Drop tasks queued during the update/clear.
                    self.task_queue.clear_data()

                    if self.model_weights_signal[0] == ModelWeightsStatus.UPDATING:
                        logger.info(
                            f"Rank: {self.local_rank} has updated parameters. {self.model_weights_status.value[0]}"
                        )
                        self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
                    elif self.model_weights_signal[0] == ModelWeightsStatus.CLEARING:
                        logger.info(
                            f"Rank: {self.local_rank} has cleared parameters. {self.model_weights_status.value[0]}"
                        )
                        # If the communication group is not shut down after clearing the
                        # weights, block all inference processes together in the loop
                        # below, otherwise the signals may get out of sync. They are woken
                        # up by the next weight update.
                        if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
                            if self.ranks > 1:  # All ranks go to sleep together and listen for the next update signal
                                paddle.distributed.barrier()
                            while self.model_weights_signal[0] != ModelWeightsStatus.UPDATING:
                                self.model_weights_signal[0] = self.model_weights_status.value[0]
                                if self.ranks > 1:
                                    self.model_weights_signal[0] = self._broadcast_model_weights_signal(
                                        src=0, group=None
                                    )
                                time.sleep(1)
                            self.model_weights_status.value[0] = (
                                ModelWeightsStatus.UPDATING
                            )  # All ranks have woken up in sync; start the weight update flow
                    continue

            if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                logger.info(f"Rank: {self.local_rank} Detected new requests.")
                self.insert_step = True

                tasks, read_finish = self.task_queue.get_tasks()
                if read_finish:
                    # Ensure that every worker get the task
                    self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
                    self.task_queue.read_finish_flag.set(0)

                req_dicts = []
                for req_dict, bsz in tasks:
                    max_occupied_batch_index = int(bsz)
                    req_dicts.extend(req_dict)

                # Count prefill requests in current batch
                num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
                num_scheduled_requests = len(req_dicts)
                scheduled_request_ids = [req.request_id for req in req_dicts]
                logger.info(
                    f"Rank: {self.local_rank}, num_prefill_requests: {num_prefill_requests}, "
                    f"max_occupied_batch_index: {max_occupied_batch_index}, "
                    f"num_scheduled_requests: {num_scheduled_requests}, "
                    f"scheduled_request_ids: {scheduled_request_ids}"
                )

                # Process prefill inputs
                self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)

            if (not self.parallel_config.use_ep) and (not self.worker.model_runner.not_need_stop()):
                # Nothing to run: keep TP ranks in lockstep, back off briefly, and poll again.
                if self.ranks > 1:
                    self._tp_barrier_wait()

                time.sleep(0.001)
                continue

            # Execute model to generate token. The generated token will be written to the buffer.
            # These generated tokens can be obtained through get_output op.
            start_execute_time = time.time()
            self._acquire_kvcache_lock(tp_rank)
            try:
                self.worker.execute_model(req_dicts, max_occupied_batch_index)
            finally:
                # Always release, even if execute_model raises, so the CPU transfer
                # process is never left blocked on the GPU cache lock.
                self._release_kvcache_lock(tp_rank)
            self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
            logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")

    def initialize_kv_cache(self) -> None:
        """Determine how many GPU KV-cache blocks to allocate, then allocate them.

        When ``parallel_config.do_profile`` is set, this method:
          1. Queries the worker for the available memory.
          2. Divides it by the per-block size from ``worker.cal_theortical_kvcache()``.
          3. Caps the result at 40000.
          4. When ``ranks > 1``, takes the minimum across ranks (all_reduce MIN).
          5. Publishes the result through the ``get_profile_block_num`` IPC signal on
             ranks where ``local_rank % max_chips_per_node == 0``.

        Otherwise ``cache_config.total_block_num`` is used as-is.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Raises:
            ValueError: If profiling yields zero or fewer blocks.
        """
        if self.fd_config.parallel_config.do_profile:
            # 1. Get available memory(bytes)
            available_kv_cache_memory = self.worker.determine_available_memory()
            logger.info(f"------- available_kv_cache_memory:{available_kv_cache_memory / 1024**3} GB --------")

            # 2. Calculate the appropriate number of blocks
            model_block_memory_used = self.worker.cal_theortical_kvcache()
            num_blocks_local = int(available_kv_cache_memory // model_block_memory_used)
            # NOTE(liuzichang): Too many block will lead to illegal memory access
            # We will develop dynamic limits in future.
            if num_blocks_local > 40000:
                logger.info(f"------- Reset num_blocks_local {num_blocks_local} to 40000")
            num_blocks_local = min(40000, num_blocks_local)
            logger.info(f"------- model_block_memory_used:{model_block_memory_used / 1024**3} GB --------")
            logger.info(f"------- num_blocks_local:{num_blocks_local} --------")

            if num_blocks_local <= 0:
                raise ValueError(
                    f"The total number of blocks must be greater than zero but got {num_blocks_local}. "
                    "Please increase gpu_memory_utilization "
                    "or decrease max_num_batched_tokens (max model length)."
                )

            if self.ranks > 1:
                # All ranks must agree on the block count; use the smallest.
                num_blocks_local = paddle.full(shape=[1], fill_value=num_blocks_local, dtype="int32")
                dist.all_reduce(num_blocks_local, op=dist.ReduceOp.MIN)
                num_blocks_local = num_blocks_local.item()

            if self.local_rank % self.max_chips_per_node == 0:
                # 3. Send IPCSignal
                get_profile_block_num = np.zeros(shape=[1], dtype=np.int32)
                self.get_profile_block_num_signal = IPCSignal(
                    name="get_profile_block_num",
                    array=get_profile_block_num,
                    dtype=np.int32,
                    suffix=self.parallel_config.engine_pid,
                    create=False,
                )
                self.get_profile_block_num_signal.value[0] = num_blocks_local
        else:
            num_blocks_local = self.fd_config.cache_config.total_block_num

        logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
        # 4. init kv_cache with accurate num_blocks
        self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)

    def graph_optimize_and_warm_up_model(self) -> None:
        """
        Run the worker's graph optimization and warm-up.

        Afterwards, when the V1 KV-cache scheduler is disabled and this is a splitwise
        prefill instance, reset the ``splitwise_complete_prefilled_step_{local_rank}``
        IPC signal (suffixed with the device id) to -1.
        """
        self.worker.graph_optimize_and_warm_up_model()
        # reset cache_messager prefilled_step signal
        if not envs.ENABLE_V1_KVCACHE_SCHEDULER and self.scheduler_config.splitwise_role == "prefill":
            gpu_id = self.worker.model_runner.device_id
            prefilled_step_name = f"splitwise_complete_prefilled_step_{self.local_rank}"
            prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
            step_shm_value = IPCSignal(
                name=prefilled_step_name, array=prefilled_step_idx_data, dtype=np.int32, suffix=gpu_id, create=False
            )
            step_shm_value.value[0] = -1

    def init_device(self) -> None:
        """Initialize device and Construct model runner"""
        self.worker.init_device()

    def start_task_queue_service(self):
        """
        Connect this worker to the engine's task queue as a client.

        Uses ``(pod_ip, engine_worker_queue_port)`` unless FD_ENGINE_TASK_QUEUE_WITH_SHM
        is set, in which case it uses a ``.sock`` path under /dev/shm keyed by the port.
        The TP size and TP rank are passed as the client count and client id.
        """
        # Initialize task queue
        if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
            task_address = (
                self.parallel_config.pod_ip,
                self.parallel_config.engine_worker_queue_port,
            )
        else:
            task_address = f"/dev/shm/fd_task_queue_{self.parallel_config.engine_worker_queue_port}.sock"
        logger.info(f"connect task queue address {task_address}")
        self.task_queue = TaskQueue(
            address=task_address,
            is_server=False,
            num_client=self.parallel_config.tensor_parallel_size,
            client_id=self.parallel_config.tensor_parallel_rank,
            local_data_parallel_id=self.parallel_config.local_data_parallel_id,
        )

    def load_model(self) -> None:
        """
        Load weights and create the model through the worker.

        When ``ranks > 1``, waits on a barrier for all ranks to finish loading, then sets
        the ``loaded_model_signal`` IPC signal to 1.
        """
        self.worker.load_model()
        loaded_model_signal_data = np.zeros(shape=[1], dtype=np.int32)
        self.loaded_model_signal = IPCSignal(
            name="loaded_model_signal",
            array=loaded_model_signal_data,
            dtype=np.int32,
            suffix=self.parallel_config.engine_pid,
            create=False,
        )
        if self.ranks > 1:
            paddle.distributed.barrier()
        self.loaded_model_signal.value[0] = 1
def parse_args():
    """
    Parse the worker process's command-line arguments.

    JSON-valued options (e.g. ``--speculative_config``, ``--quantization``,
    ``--eplb_config``) are decoded with ``json.loads``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser("FastDeploy LLM Inference")
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="./output",
        help="model dir",
    )
    parser.add_argument("-mbs", "--max_num_seqs", type=int, default=34, help="max batch size")
    parser.add_argument("--num_gpu_blocks_override", type=int, default=None)
    parser.add_argument("--block_size", type=int, default=64)
    parser.add_argument("--pod_ip", type=str, default="127.0.0.1")
    parser.add_argument("--engine_worker_queue_port", type=str, default="9923")
    parser.add_argument("--max_model_len", type=int, default=3072, help="max model len")
    parser.add_argument("--device_ids", type=str, default="0", help="cuda visible devices")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="input dtype")
    parser.add_argument("--enc_dec_block_num", type=int, default=1, help="encoder's decoder num")
    parser.add_argument(
        "--kv_cache_ratio",
        type=float,
        default=0.7,
        help="kv cache ratio for input",
    )
    parser.add_argument("--first_token_id", type=int, default=1, help="first token id")
    parser.add_argument(
        "--gpu_memory_utilization",
        type=float,
        default=0.9,
        help="gpu memory utilization",
    )
    parser.add_argument("--engine_pid", type=int, default=None, help="Process ID of engine")
    parser.add_argument("--do_profile", action="store_true", help="do profile or not")
    parser.add_argument("--pad_token_id", type=int, default=-1, help="pad token id")
    parser.add_argument("--eos_tokens_lens", type=int, default=2, help="eos token lens")
    parser.add_argument(
        "--enable_chunked_prefill",
        action="store_true",
        help="enable chunked prefill",
    )
    parser.add_argument(
        "--use_internode_ll_two_stage",
        action="store_true",
        help="enable internode_ll_two_stage",
    )
    parser.add_argument(
        "--speculative_config",
        type=json.loads,
        default=None,
        help="Configuration of SpeculativeConfig.",
    )
    parser.add_argument(
        "--max_num_batched_tokens",
        type=int,
        default=2048,
        help="max num batched tokens",
    )
    parser.add_argument(
        "--enable_prefix_caching",
        action="store_true",
        help="enable prefix cache",
    )
    parser.add_argument(
        "--disable_custom_all_reduce",
        action="store_true",
        help="disable custom all-reduce",
    )
    parser.add_argument(
        "--disable_sequence_parallel_moe",
        action="store_true",
        help="disable sequence parallel moe",
    )
    parser.add_argument("--splitwise_role", type=str, default="mixed", help="splitwise role")
    parser.add_argument(
        "--tensor_parallel_size",
        type=int,
        default=1,
        help="tensor parallel size",
    )
    parser.add_argument(
        "--expert_parallel_size",
        type=int,
        default=1,
        help="expert parallel size",
    )
    parser.add_argument(
        "--data_parallel_size",
        type=int,
        default=1,
        help="data parallel size",
    )
    parser.add_argument(
        "--enable_expert_parallel",
        action="store_true",
        help="enable expert parallel",
    )
    parser.add_argument(
        "--enable_chunked_moe",
        action="store_true",
        help="enable chunked moe",
    )
    parser.add_argument(
        "--chunked_moe_size",
        type=int,
        default=256,
        help="chunk size of moe input",
    )
    parser.add_argument("--ori_vocab_size", type=int, default=None)
    parser.add_argument("--think_end_id", type=int, default=-1)
    parser.add_argument("--image_patch_id", type=int, default=-1)
    parser.add_argument("--line_break_id", type=int, default=-1)
    parser.add_argument("--think_truncate_prompt_ids", type=json.loads, default=[])
    parser.add_argument(
        "--quantization",
        type=json.loads,
        default=None,
        help="Quantization name for the model, currently support "
        "'wint4', 'wint8', "
        "default is None. The priority of this configuration "
        "is lower than that of the config file. "
        "More complex quantization methods need to be configured via the config file.",
    )
    parser.add_argument(
        "--graph_optimization_config",
        type=json.loads,
        default=None,
        help="Configuration of Graph optimization backend.",
    )
    parser.add_argument(
        "--plas_attention_config",
        type=json.loads,
        default=None,
        help="Configuration of plas attention.",
    )
    parser.add_argument(
        "--guided_decoding_backend",
        type=str,
        default="off",
        help="guided decoding backend",
    )
    parser.add_argument(
        "--disable_any_whitespace",
        action="store_true",
        help="Disable any whitespace for guided decoding.",
    )
    parser.add_argument(
        "--dynamic_load_weight",
        action="store_true",
        help="Enable dynamic weight loading strategy",
    )
    # NOTE(review): the help text only describes 'ipc' and 'ipc_snapshot'; 'meta' and
    # 'normal' are accepted choices but undocumented here.
    parser.add_argument(
        "--load_strategy",
        type=str,
        choices=["ipc", "ipc_snapshot", "meta", "normal"],
        default="ipc_snapshot",
        help="Weight loading method when dynamic loading is enabled: "
        "'ipc': real-time IPC streaming with automatic resharding, "
        "'ipc_snapshot': load from disk snapshot of IPC weights.",
    )
    parser.add_argument(
        "--enable_logprob",
        action="store_true",
        help="Enable output of token-level log probabilities.",
    )
    parser.add_argument(
        "--max_logprobs",
        type=int,
        default=20,
        help="Maximum number of log probabilities.",
    )
    parser.add_argument(
        "--logprobs_mode",
        type=str,
        default="raw_logprobs",
        help="Indicates the content returned in the logprobs.",
    )
    parser.add_argument(
        "--reasoning_parser",
        type=str,
        default=None,
        help="Flag specifies the reasoning parser to use for extracting reasoning content from the model output",
    )
    parser.add_argument(
        "--early_stop_config",
        type=json.loads,
        default=None,
        help="Configuration of early stop.",
    )
    parser.add_argument(
        "--load_choices",
        type=str,
        default="default_v1",
        help="The format of the model weights to load. default/default_v1.",
    )
    parser.add_argument(
        "--ips",
        type=str,
        default=None,
        help="The ips of multinode deployment.",
    )
    parser.add_argument(
        "--lm_head_fp32",
        action="store_true",
        help="Flag to specify dtype of lm_head as FP32",
    )
    parser.add_argument(
        "--max_encoder_cache",
        type=int,
        help="Maximum encoder cache tokens (use 0 to disable).",
    )
    parser.add_argument(
        "--cache-transfer-protocol",
        type=str,
        default="ipc",
        help="support protocol list, comma separated, default is ipc",
    )
    parser.add_argument(
        "--runner",
        type=str,
        default="auto",
        help="The type of model runner to use. Each FD instance only supports one model runner, even if the same model can be used for multiple types.",
    )
    parser.add_argument(
        "--convert",
        type=str,
        default="auto",
        help="Convert the model using adapters. The most common use case is to adapt a text generation model to be used for pooling tasks.",
    )
    parser.add_argument(
        "--override-pooler-config",
        type=optional_type(json.loads),
        default=None,
        help="Override configuration for the pooler.",
    )
    parser.add_argument(
        "--logits-processors",
        type=str,
        nargs="+",
        default=[],
        help="FQCNs (Fully Qualified Class Names) of logits processors supported by the service.",
    )
    parser.add_argument(
        "--eplb_config",
        type=json.loads,
        default=None,
        help="EPLB Configuration.",
    )
    parser.add_argument(
        "--routing_replay_config",
        type=json.loads,
        default=None,
        help="Configuration of Rollout Routing Replay.",
    )
    parser.add_argument(
        "--shutdown_comm_group_if_worker_idle",
        action="store_true",
        help="Shutdown comm group if worker idle.",
    )
    parser.add_argument(
        "--enable_entropy",
        action="store_true",
        help="Enable output of token-level entropy.",
    )
    parser.add_argument(
        "--num_cpu_blocks",
        type=int,
        default=0,
        help="Number of cpu blocks.",
    )
    parser.add_argument(
        "--mm_max_tokens_per_item",
        type=json.loads,
        default=None,
        help="Maximum tokens per item in mm input.",
    )

    args = parser.parse_args()
    return args