-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Expand file tree
/
Copy pathdynamic_engine.py
More file actions
2058 lines (1769 loc) · 90.7 KB
/
dynamic_engine.py
File metadata and controls
2058 lines (1769 loc) · 90.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import asyncio
import concurrent.futures
import logging
import multiprocessing
import socket
import struct
import time
import warnings
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto
from itertools import repeat
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.cuda.nvtx import range_pop, range_push
from megatron.core.inference.config import KVCacheManagementMode
from megatron.core.inference.contexts.dynamic_context import (
DynamicInferenceContext,
MaxSequenceLengthOverflowError,
TokenOverflowError,
)
from megatron.core.inference.data_parallel_inference_coordinator import (
DataParallelInferenceCoordinator,
)
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.headers import Headers, UnknownHeaderError
from megatron.core.inference.inference_request import (
DynamicInferenceEvent,
DynamicInferenceEventType,
DynamicInferenceRequest,
DynamicInferenceRequestRecord,
Status,
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
from megatron.core.inference.utils import (
Counter,
await_process_call,
set_inference_cuda_graphed_iteration_for_ep_inference,
unset_inference_cuda_graphed_iteration_for_ep_inference,
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.cuda_graphs import delete_cuda_graphs
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
from megatron.core.utils import (
deprecate_args,
experimental_api,
get_asyncio_loop,
get_pg_rank,
get_pg_size,
get_pg_src_rank,
internal_api,
trace_async_exceptions,
)
from .async_zmq_communicator import AsyncZMQCommunicator
# Optional third-party dependencies. Each HAVE_* flag records whether the
# corresponding package could be imported so that callers can guard its use.
# Only ImportError is caught: a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and hide unrelated failures.
try:
    from tqdm import tqdm

    HAVE_TQDM = True
except ImportError:
    HAVE_TQDM = False

try:
    import zmq

    HAVE_ZMQ = True
except ImportError:
    HAVE_ZMQ = False

try:
    import msgpack

    HAVE_MSGPACK = True
except ImportError:
    HAVE_MSGPACK = False

try:
    import wandb

    HAVE_WANDB = True
except ImportError:
    HAVE_WANDB = False
    wandb = None

try:
    import psutil

    HAVE_PSUTIL = True
except ImportError:
    HAVE_PSUTIL = False

# Argument names handed to `deprecate_args` on `DynamicInferenceEngine.__init__`.
DEPRECATED_ARGS = [
    "enable_cuda_graph",
    "random_seed",
    "track_paused_request_events",
    "enable_chunked_prefill",
    "inference_logging_step_interval",
    "pg_collection",
]
class EngineState(Enum):
    """Lifecycle states of the inference engine.

    Each stable state is reached through a matching transitional ``*ING`` state
    that is held while the engine waits for the other ranks.
    """

    # Normal operation: requests are being processed.
    RUNNING = auto()

    # Pause: PAUSE received, waiting for EP consensus + world barrier ...
    PAUSING = auto()
    # ... then globally confirmed idle.
    PAUSED = auto()
    # UNPAUSE received; waiting for world barrier.
    UNPAUSING = auto()

    # Suspend: SUSPEND received, offloading GPU, waiting for world barrier ...
    SUSPENDING = auto()
    # ... then GPU offloaded and all ranks confirmed.
    SUSPENDED = auto()

    # Resume: RESUME received, onloading GPU, waiting for world barrier ...
    RESUMING = auto()
    # ... then GPU onloaded and all ranks confirmed; cleared on next SUSPEND.
    RESUMED = auto()

    # Stop: STOP received, futures cancelled, waiting for world barrier ...
    STOPPING = auto()
    # ... then all ranks confirmed and teardown complete.
    STOPPED = auto()
class EngineSuspendedError(Exception):
    """Raised when the engine is suspended and therefore not performing steps."""
def format_mem_bytes(mem_bytes):
    """Convert a byte count to a human-readable string in tb, gb, mb, kb, or bytes.

    The unit is selected from the magnitude of `mem_bytes`, so negative values
    (e.g. a memory delta that shrank) are scaled like positive ones and keep
    their sign.

    Args:
        mem_bytes (int | float): Number of bytes; may be negative.

    Returns:
        (str) The value with one decimal place and a unit suffix, or
        "<n> bytes" when the magnitude is below one byte (including zero).
    """
    magnitude = abs(mem_bytes)
    for power, suffix in ((4, "tb"), (3, "gb"), (2, "mb"), (1, "kb"), (0, "bytes")):
        suffix_bytes = 1024**power
        if magnitude >= suffix_bytes:
            return "%.1f %s" % (mem_bytes / suffix_bytes, suffix)
    return "%d bytes" % mem_bytes
@dataclass(kw_only=True)
class RequestEntry:
    """Value type stored in the engine's `self.requests` dict, keyed by request id.

    Both fields must be passed by keyword.
    """

    # History of the request; the engine reads the latest request via `record[-1]`.
    record: DynamicInferenceRequestRecord
    # Future associated with the request; the engine resolves it with `record`.
    future: asyncio.Future
# pylint: disable=line-too-long
@experimental_api
class DynamicInferenceEngine(AbstractEngine):
"""The dynamic inference engine.
This engine allows requests of varying length to be dynamically added and
removed in each inference step. In contrast to the static engine that has a
set batch size and sequence length during the forward pass, each request in
the dynamic engine can have different *current* prompt and output length at
any given step, and the processing is restricted only by a max number of total
tokens across all requests.
Args:
controller (TextGenerationController): A text generation
controller that will be used to define how to preprocess prompts, generate
outputs and detokenize the output tokens.
context (DynamicInferenceContext): Context for managing in-flight
batching and a dynamic block-level KV cache (similar to paged attention).
"""
# Map stable states to their corresponding asyncio events.
_STATE_EVENTS = (
EngineState.RUNNING,
EngineState.PAUSED,
EngineState.SUSPENDED,
EngineState.RESUMED,
EngineState.STOPPED,
)
@deprecate_args(
    *DEPRECATED_ARGS,
    message="Argument `{name}` has been deprecated. Only pass `controller` and `context`",
)
def __init__(self, controller: TextGenerationController, context: DynamicInferenceContext):
    """Build the engine from a text generation controller and a dynamic context.

    Engine options are read from `context.config`; the cuda graph settings come
    from the wrapped model's config. The constructor resets all request state,
    registers the stop-word callback on the controller, optionally configures
    wandb step tracking, and finally creates the cuda graphs.

    Args:
        controller (TextGenerationController): Controller wrapping the model to run.
        context (DynamicInferenceContext): Context that carries the inference config.
    """
    assert isinstance(
        controller, TextGenerationController
    ), f"controller must be a TextGenerationController, got {type(controller)}"
    assert isinstance(
        context, DynamicInferenceContext
    ), f"context must be a DynamicInferenceContext, got {type(context)}"

    model_config = controller.inference_wrapped_model.model.config
    inference_config = context.config

    # Prefer explicitly configured process groups; otherwise use the MPU ones.
    if inference_config.pg_collection is not None:
        self.pg_collection = inference_config.pg_collection
    else:
        self.pg_collection = ProcessGroupCollection.use_mpu_process_groups()

    # Initialization options.
    self.controller = controller
    self.context = context
    self.track_paused_request_events = inference_config.track_paused_request_events
    self.track_generated_token_events = inference_config.track_generated_token_events
    self.enable_chunked_prefill = inference_config.enable_chunked_prefill
    self.metrics_writer = inference_config.metrics_writer
    self.logging_step_interval = inference_config.logging_step_interval
    self.unified_memory_level = inference_config.unified_memory_level
    self.materialize_only_last_token_logits = (
        inference_config.materialize_only_last_token_logits
    )
    self.cuda_graph_impl = model_config.cuda_graph_impl
    self.cuda_graph_scope = model_config.cuda_graph_scope

    # Initialize engine.
    self.reset()

    # Set callback for getting stop word finished request IDs
    self.controller.set_stop_word_finished_ids_callback(
        self._get_and_clear_stop_word_finished_ids
    )

    # Always define the offset so the attribute exists even when wandb logging
    # is disabled; it is overwritten below from the existing run history.
    self.inference_step_offset = 0

    # Configure wandb to use separate step counter for inference metrics (only once)
    if self.logging_step_interval > 0 and self.metrics_writer is not None:
        logging.info(
            f"\033[1;93m[INFERENCE]\033[0m "
            f"\033[1;95mLogging inference metrics to wandb (rank {self.rank})\033[0m"
        )
        if HAVE_WANDB and self.metrics_writer.__name__ == "wandb":
            # Make all inference/* metrics use inference_step as their x-axis
            # This allows inference and training to have independent step counters.
            # Use the same writer object that was validated by the guard above.
            self.metrics_writer.define_metric(
                "inference/*", step_metric="inference/inference_step"
            )
            # Initialize inference step offset by querying existing run history
            if wandb.run is not None:
                api_run = wandb.Api().run(
                    f"{wandb.run.entity}/{wandb.run.project}/{wandb.run.id}"
                )
                max_step = 0
                for row in api_run.scan_history(keys=["inference/inference_step"]):
                    val = row.get("inference/inference_step")
                    if isinstance(val, (int, float)) and int(val) > max_step:
                        max_step = int(val)
                self.inference_step_offset = max_step

    # Create cuda graphs.
    self.create_cuda_graphs()
def reset(self) -> None:
    """Drop every request and restore the engine to its initial RUNNING state."""
    self.context.reset()

    # --- Request bookkeeping ---
    self.request_counter = Counter()
    self.finished_request_count = 0
    self.evicted_request_count = 0
    self.requests: Dict[int, RequestEntry] = {}
    self.waiting_request_ids = deque()
    self.failed_request_ids = []
    # Requests that should stop due to stop words (detected in post_process_requests).
    self.stop_word_finished_request_ids: set[int] = set()
    # Requests currently being finished due to stop words (to skip the extra token).
    self.stop_word_being_finished_ids: set[int] = set()

    # --- Timing and logging ---
    self.rank = torch.distributed.get_rank()
    self.step_start_event = torch.cuda.Event(enable_timing=True)
    self.step_end_event = torch.cuda.Event(enable_timing=True)
    self.capture_stats = None

    # --- Runtime state ---
    # Reuse a previously stored loop if this is not the first reset.
    previous_loop = getattr(self, "_loop", None)
    self._loop = get_asyncio_loop(previous_loop)
    self._cond = asyncio.Condition()
    # One event per stable state; only the RUNNING event starts out set.
    self._state_events = {
        stable_state: asyncio.Event() for stable_state in self._STATE_EVENTS
    }
    self.state = EngineState.RUNNING
    self._state_events[EngineState.RUNNING].set()
    self._pending_signals = deque()
    self.resume_request_ids = None

    # --- Prefix caching coordination ---
    self._prefix_coordination_waits = 0

    # --- Coordinator ---
    self.use_coordinator = False
async def wait_until(self, state: EngineState):
"""Wait until the engine reaches the given state.
Only stable states (RUNNING, PAUSED, SUSPENDED, RESUMED,
STOPPED) are supported. Transient states (PAUSING, SUSPENDING,
RESUMING, STOPPING) are not directly waitable.
"""
event = self._state_events.get(state)
if event is None:
raise ValueError(f"Cannot wait for transient state {state}")
await event.wait()
def create_cuda_graphs(self, reset_context: bool = True):
    """Create cuda graphs.

    This method iterates the dynamic context's `cuda_graph_batch_dimensions_list`
    and runs one warmup forward pass per entry to record and capture cuda graphs.
    Elapsed time and the change in allocated/reserved GPU memory are logged and
    stored in `self.capture_stats`. Nothing happens unless `self.cuda_graph_impl`
    is "local".

    Args:
        reset_context (bool): Currently unused and kept only so existing callers
            keep working; the context is reset after each warmup pass regardless.
    """
    if self.cuda_graph_impl != "local":
        return

    if (
        CudaGraphScope.full_iteration in self.cuda_graph_scope
        and CudaGraphScope.full_iteration_inference not in self.cuda_graph_scope
    ):
        warnings.warn(
            "\n\n*** WARNING: 'full_iteration' CUDA graph scope used during inference! "
            "This will not create inference CUDA graphs. Use '--cuda-graph-scope=full_iteration_inference' instead. ***\n"
        )

    context = self.context
    controller = self.controller
    batch_dimensions = context.cuda_graph_batch_dimensions_list

    time_start = time.time()
    mem_stats_start = torch.cuda.memory_stats()

    logging.info("> dynamic_engine.py: building cuda graphs for ")
    for graph in batch_dimensions:
        logging.info(graph)

    # Enable inference dispatcher for EP during graph capture.
    # The model config is fetched once here and reused inside the warmup loop.
    model_config = controller.inference_wrapped_model.model.config
    is_inference_optimized_ep = (
        model_config.transformer_impl == "inference_optimized"
        and model_config.expert_model_parallel_size > 1
    )
    if is_inference_optimized_ep:
        unwrapped_model = controller.inference_wrapped_model.model
        set_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model)

    tbar = enumerate(batch_dimensions)
    if HAVE_TQDM:
        tbar = tqdm(tbar, total=len(batch_dimensions))
    for tbar_idx, cuda_graph_batch_dimension in tbar:
        input_ids, position_ids = controller._dynamic_step_context_init(
            construct_graph_dimensions=cuda_graph_batch_dimension
        )

        # Progress.
        tbar_str = f"cuda graph warmup - {cuda_graph_batch_dimension}"
        if HAVE_TQDM:
            tbar.set_description(tbar_str)
        else:
            logging.info(f"{tbar_idx}/{len(batch_dimensions)}. {tbar_str}")

        # Enable routing recording during warmup if routing replay is enabled.
        # This ensures the record_indices copy operation is captured in the CUDA graph.
        if model_config.moe_enable_routing_replay:
            RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)

        # Forward pass -> logits.
        controller._dynamic_step_forward_logits(input_ids, position_ids)

        context.reset()

    # Disable inference dispatcher after graph capture
    if is_inference_optimized_ep:
        unset_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model)

    # Memory usage.
    time_end = time.time()
    mem_stats_end = torch.cuda.memory_stats()
    capture_stats = {
        "time": time_end - time_start,
        "allocated_bytes": (
            mem_stats_end["allocated_bytes.all.current"]
            - mem_stats_start["allocated_bytes.all.current"]
        ),
        "reserved_bytes": (
            mem_stats_end["reserved_bytes.all.current"]
            - mem_stats_start["reserved_bytes.all.current"]
        ),
    }
    logging.info(
        "> built cuda graph(s) in %.2f sec, with total memory usage: "
        "allocated %s, reserved %s.",
        capture_stats["time"],
        format_mem_bytes(capture_stats["allocated_bytes"]),
        format_mem_bytes(capture_stats["reserved_bytes"]),
    )

    self.capture_stats = capture_stats
@internal_api
async def start_listening_to_data_parallel_coordinator(
    self,
    inference_coordinator_port: int | None = None,
    launch_inference_coordinator: bool = True,
    *,
    coordinator_schedule_output_path: str | None = None,
    loop: Optional[asyncio.AbstractEventLoop] = None,
):
    """Initializes ZMQ communication to connect the engine with an inference coordinator.

    This asynchronous method sets up the distributed communication infrastructure
    that allows this inference engine to act as a worker under a central
    `InferenceCoordinator`. It configures different ZMQ socket patterns
    based on the rank's role within the distributed topology.

    Note that this method must be called on all ranks, as it uses blocking torch broadcasts.

    The setup involves two primary roles within each data-parallel group:
    1.  **MP Coordinator (TP_rank=0, PP_rank=0)**: This rank connects directly
        to the central coordinator via a ZMQ `DEALER` socket. It receives
        requests and uses a ZMQ `PUB` (publisher) socket to broadcast them
        to all other ranks within its model-parallel (MP) group.
    2.  **MP Workers (all other ranks)**: These ranks use ZMQ `SUB` (subscriber)
        sockets to listen for requests broadcast by their local MP Coordinator.

    This architecture uses TCP sockets for both inter-node and intra-node broadcasts
    within an MP group.

    Finally, after setting up the communication channels and ensuring all ranks
    are synchronized, this method starts the main engine processing loop
    (`self.run_engine_with_coordinator`) as a background asyncio task.

    Args:
        inference_coordinator_port (int | None): The network port where the central
            `InferenceCoordinator` is or will be listening.
            If None, a random available port will be selected.
            If not None, the coordinator will attempt to bind to this port, but should it
            not succeed (e.g., if the port is already in use), it may bind to a different port.
            The actual port used is returned by this method.
        launch_inference_coordinator (bool, optional): If True, the data-parallel
            coordinator rank (DP rank 0 with TP rank 0 and PP rank 0) will spawn and
            manage the `InferenceCoordinator` process. Defaults to True.
        coordinator_schedule_output_path (str | None): Forwarded unchanged to the
            coordinator process entrypoint. Defaults to None.
        loop (Optional[asyncio.AbstractEventLoop]): Event loop on which the engine
            loop task is created; resolved through `get_asyncio_loop`.

    Returns:
        inference_coordinator_address (str): The network address of the central
            `InferenceCoordinator`, which may not have the same port as what the user requested
            with `inference_coordinator_port`.
    """

    assert HAVE_ZMQ, (
        "please install the pyzmq library to use InferenceCoordinator\n" "pip install pyzmq"
    )
    assert HAVE_MSGPACK, (
        "please install the messagepack library to use InferenceCoordinator\n"
        "pip install msgpack"
    )

    self.zmq_context = zmq.Context.instance()
    self.zmq_sockets = []  # keep track of all sockets created by this engine

    # Get world info.
    dp_group = self.pg_collection.dp
    dp_src = get_pg_src_rank(dp_group)
    dp_size = get_pg_size(self.pg_collection.dp)
    dp_rank = get_pg_rank(self.pg_collection.dp)

    mp_group = self.pg_collection.mp
    mp_src = get_pg_src_rank(mp_group)
    tp_rank = get_pg_rank(self.pg_collection.tp)
    pp_rank = get_pg_rank(self.pg_collection.pp)

    self.is_mp_coordinator = tp_rank == 0 and pp_rank == 0
    self.is_dp_coordinator = (dp_rank == 0) and self.is_mp_coordinator

    # Despite the variable name, this is the host name used to build TCP endpoints.
    local_ip = socket.gethostname()

    # Spawn a DP coordinator process and get the connection info.
    if launch_inference_coordinator and self.is_dp_coordinator:
        spawn_context = multiprocessing.get_context('spawn')
        deterministic_mode = torch.are_deterministic_algorithms_enabled()
        dp_pipe, dp_process_pipe = spawn_context.Pipe()
        coordinator_ready_event = spawn_context.Event()
        self.inference_coordinator_process = spawn_context.Process(
            target=DataParallelInferenceCoordinator.entrypoint,
            args=(
                dp_process_pipe,
                coordinator_ready_event,
                dp_size,
                self.controller.tokenizer,
                inference_coordinator_port,
                deterministic_mode,
                self.context.block_size_tokens,
                self.context.enable_prefix_caching,
                self.context.prefix_caching_coordinator_policy,
                coordinator_schedule_output_path,
            ),
        )
        self.inference_coordinator_process.start()
        # The coordinator reports the address it actually bound to over the pipe.
        await await_process_call(dp_pipe.poll, self.inference_coordinator_process)
        dp_addr = dp_pipe.recv()
        dp_pipe.close()

        # Check if the port number is not inference_coordinator_port
        actual_port = int(dp_addr.rsplit(":", 1)[-1])
        if inference_coordinator_port is not None and actual_port != inference_coordinator_port:
            logging.warning(
                f"Requested InferenceCoordinator port {inference_coordinator_port} "
                f"but got port {actual_port} instead. This happens if the request port "
                f"is already in use."
            )
    elif not launch_inference_coordinator:
        dp_addr = f"tcp://{local_ip}:{inference_coordinator_port}"
    else:
        # Filled in by the broadcast from the DP source rank below.
        dp_addr = None

    # Find available ports for MP and bind to them.
    if self.is_mp_coordinator:
        mp_req_sock = self.zmq_context.socket(zmq.PUB)
        mp_req_sock.bind_to_random_port(f"tcp://{local_ip}")
        mp_req_addr = mp_req_sock.getsockopt_string(zmq.LAST_ENDPOINT)

        mp_len_sock = self.zmq_context.socket(zmq.PUB)
        mp_len_sock.bind_to_random_port(f"tcp://{local_ip}")
        mp_len_addr = mp_len_sock.getsockopt_string(zmq.LAST_ENDPOINT)
    else:
        mp_req_addr = None
        mp_len_addr = None

    # Broadcast addresses to respective ranks.
    bcast = [dp_addr]
    torch.distributed.broadcast_object_list(bcast, src=dp_src, group=dp_group)
    [dp_addr] = bcast
    bcast = [mp_req_addr, mp_len_addr]
    torch.distributed.broadcast_object_list(bcast, src=mp_src, group=mp_group)
    [mp_req_addr, mp_len_addr] = bcast

    identity = f'mp-coord-{dp_rank}'
    if self.is_mp_coordinator:
        # 1. Create dealer sockets where tp_rank = 0 and pp_rank = 0
        #    These will receive requests from an InferenceCoordinator.
        self.socket_for_receiving_requests = self.zmq_context.socket(zmq.DEALER)

        self.socket_for_receiving_requests.setsockopt(zmq.IDENTITY, identity.encode('utf-8'))
        self.socket_for_receiving_requests.connect(dp_addr)

        # send empty string. this is used to register with the coordinator.
        self.socket_for_receiving_requests.send(b"")

        # 2. Create a publisher socket. This is used to publish or broadcast
        #    requests within the model parallel group
        self.model_parallel_publisher_socket = mp_req_sock

        # 3. Create another publisher socket to broadcast the number of messages to receive.
        self.model_parallel_num_msgs_publisher_socket = mp_len_sock
        self.zmq_sockets += [
            self.socket_for_receiving_requests,
            self.model_parallel_num_msgs_publisher_socket,
            self.model_parallel_publisher_socket,
        ]

    # All MP ranks subscribe to the two publisher sockets
    self.model_parallel_subscriber_socket = self.zmq_context.socket(zmq.SUB)
    self.model_parallel_subscriber_socket.connect(mp_req_addr)
    self.model_parallel_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")

    self.model_parallel_num_msgs_subscriber_socket = self.zmq_context.socket(zmq.SUB)
    self.model_parallel_num_msgs_subscriber_socket.connect(mp_len_addr)
    self.model_parallel_num_msgs_subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")

    self.zmq_sockets += [
        self.model_parallel_subscriber_socket,
        self.model_parallel_num_msgs_subscriber_socket,
    ]

    torch.distributed.barrier(mp_group)

    # initialize zmq-based EP communicator
    self.ep_rank = get_pg_rank(self.pg_collection.ep)
    self.ep_world_size = get_pg_size(self.pg_collection.ep)
    if self.ep_world_size > 1:
        self.expert_parallel_zmq_communicator = AsyncZMQCommunicator(
            self.zmq_context, process_group=self.pg_collection.ep
        )

    # initialize zmq-based world communicator for consensus barriers
    total_world_size = torch.distributed.get_world_size()
    if total_world_size > 1:
        self.world_zmq_communicator = AsyncZMQCommunicator(self.zmq_context, process_group=None)

    if launch_inference_coordinator and self.is_dp_coordinator:
        await await_process_call(
            coordinator_ready_event.wait, self.inference_coordinator_process
        )
        logging.info("Inference co-ordinator is ready to receive requests!")
        logging.info(f"Data parallel coordinator can be found at {dp_addr}")

    # Finally run the engine infinite loop.
    loop = get_asyncio_loop(loop)
    self.engine_loop_task = loop.create_task(self.run_engine_with_coordinator(loop=loop))

    return dp_addr
@staticmethod
@contextmanager
def suspend_resume_ctx(key: str, *, unified_memory_level: int) -> None:
    """Context manager for suspending and resuming the engine.

    This context manager records the time and memory usage when suspending
    and resuming the context, and logs a summary on exit (also when the body
    raises). TODO(@lmcafee): add argument to optionally return nullcontext,
    to avoid overhead.

    `staticmethod` is the outermost decorator so the context manager can be
    obtained from either the class or an instance.

    Args:
        key (str): Key that identifies caller (e.g., 'suspend' or 'resume').
        unified_memory_level (int): Value reported in the log message.

    Yields:
        None.
    """
    # Take the baseline and open the NVTX range before entering `try`, so the
    # `finally` block never runs with undefined locals or an unmatched pop.
    start_mem = torch.cuda.memory_stats()
    start_time = time.time()
    range_push(f"{key}-inference-context")
    try:
        torch.cuda.synchronize()
        yield
    finally:
        range_pop()
        end_time = time.time()
        end_mem = torch.cuda.memory_stats()

        start_mem_alloc = start_mem["allocated_bytes.all.current"]
        end_mem_alloc = end_mem["allocated_bytes.all.current"]
        start_mem_res = start_mem["reserved_bytes.all.current"]
        end_mem_res = end_mem["reserved_bytes.all.current"]

        rank_str = torch.distributed.get_rank()
        dir_str = "deallocating" if end_mem_alloc <= start_mem_alloc else "allocating"
        relative_time_str = f"{end_time - start_time:.3f} sec"
        relative_mem_str = f"{abs(start_mem_alloc - end_mem_alloc) / 1024**3:.1f} gb"

        # Host memory is only reported when psutil is available.
        if HAVE_PSUTIL:
            process = psutil.Process()
            mem_info = process.memory_info()
            cpu_mem_str = f"{mem_info.rss / 1024**3:.1f} gb"
        else:
            cpu_mem_str = "--"

        total_mem_str = ", ".join(
            (
                f"cpu: {cpu_mem_str}",
                f"gpu: alloc {end_mem_alloc / 1024**3:.1f} gb",
                f"res {end_mem_res / 1024**3:.1f} gb",
            )
        )
        logging.info(
            f"[rank {rank_str}] dynamic engine {key}, "
            f"unified {unified_memory_level}, "
            f"{dir_str} "
            f"{relative_mem_str} in {relative_time_str} ... "
            f"abs mem usage: {total_mem_str}"
        )
def suspend(self):
    """Suspend engine by deallocating context's GPU state.

    Does nothing when the engine is already SUSPENDED or SUSPENDING. Otherwise
    the context's inference state buffers are deallocated, cuda graphs are
    deleted unless the KV cache mode is PERSIST or the KV memory pointers are
    static, and the ids of the requests to re-add on `resume()` are stored in
    `self.resume_request_ids`.
    """

    # Skip if already suspended or in the process of suspending.
    if self.state in (EngineState.SUSPENDED, EngineState.SUSPENDING):
        return

    # Deallocate context tensors.
    with self.__class__.suspend_resume_ctx(
        "suspended", unified_memory_level=self.unified_memory_level
    ):
        self.context.deallocate_inference_state_buffers()

    if (
        self.context.kv_cache_management_mode != KVCacheManagementMode.PERSIST
        and not self.context.static_kv_memory_pointers
    ):
        delete_cuda_graphs()

    # Build the list of requests to re-add on resume.
    # All waiting requests are always included; active requests are included
    # only if they are marked for recompute (their KV cache will be gone).
    waiting_request_ids = list(self.waiting_request_ids)
    if self.context.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE:
        # Walk `self.requests` (insertion-ordered) rather than a set difference,
        # so active requests are re-queued in the order they entered the engine.
        waiting_lookup = set(waiting_request_ids)
        recompute_active_ids = [
            request_id for request_id in self.requests if request_id not in waiting_lookup
        ]
    else:
        recompute_active_ids = []
    self.resume_request_ids = [*recompute_active_ids, *waiting_request_ids]
    self.waiting_request_ids.clear()

    # Checkpoint active requests that are marked for recompute.
    for request_id in recompute_active_ids:
        self.requests[request_id].record.checkpoint()

    # If we are not using the inference coordinator, we need to manually handle state.
    if not self.use_coordinator:
        self.state = EngineState.SUSPENDED
def resume(self):
    """Bring a suspended engine back by reallocating the context's GPU state.

    Does nothing unless the engine is SUSPENDED or SUSPENDING. The inference state
    buffers are reinitialized, cuda graphs are rebuilt when they were not
    persisted, and the requests saved by `suspend()` are added back.
    """
    suspended_states = (EngineState.SUSPENDED, EngineState.SUSPENDING)
    if self.state not in suspended_states:
        return

    with type(self).suspend_resume_ctx(
        "resumed", unified_memory_level=self.unified_memory_level
    ):
        # Phase 1: allocate context tensors.
        alloc_start = time.time()
        torch.cuda.synchronize()
        self.context.reinitialize_inference_state_buffers()
        torch.cuda.synchronize()
        alloc_time = time.time() - alloc_start

        # Phase 2: rebuild cuda graphs unless they survived the suspend.
        capture_start = time.time()
        if (
            self.context.kv_cache_management_mode != KVCacheManagementMode.PERSIST
            and not self.context.static_kv_memory_pointers
        ):
            self.create_cuda_graphs()
        capture_time = time.time() - capture_start

        # Phase 3: re-add the requests saved during suspend.
        add_start = time.time()
        torch.cuda.synchronize()
        for saved_request_id in self.resume_request_ids:
            self._add_request(self.get_request(saved_request_id))
        torch.cuda.synchronize()
        add_time = time.time() - add_start

    # Print inner timing (must be outside context manager above for correct formatting).
    timing_parts = (
        f"inner timing: alloc {alloc_time:.3f}",
        f"add {add_time:.3f}",
        f"capture {capture_time:.3f}.",
    )
    logging.info(" > " + ", ".join(timing_parts))

    # Without the inference coordinator the engine manages its own state.
    if not self.use_coordinator:
        self.state = EngineState.RUNNING

    # Wake the condition variable that run_engine() waits on.
    self._loop.call_soon_threadsafe(
        asyncio.create_task, self._notify_cond_for_new_request()
    )
@trace_async_exceptions
async def _notify_cond_for_new_request(self):
    """Wake every coroutine currently waiting on the engine's condition variable."""
    condition = self._cond
    # Explicit acquire/release; the lock must be held to call notify_all().
    await condition.acquire()
    try:
        condition.notify_all()
    finally:
        condition.release()
def has_unfinished_requests(self) -> bool:
"""Test if context contains unfinished requests."""
return self.context.has_unfinished_requests() or len(self.waiting_request_ids) > 0
def get_request(self, request_id: int) -> DynamicInferenceRequest:
"""Get most recent request from a request record.
Args:
request_id (int): Request id.
Returns:
(DynamicInferenceRequest) The most recent request in the record.
"""
return self.requests[request_id].record[-1]
def _handle_failed_request(self, request_id: int):
"""Handle a failed request by sending the reply immediately.
The request is added to failed_request_ids so that the next bookkeeping pass can return it.
"""
request_entry = self.requests[request_id]
request = request_entry.record[-1]
if self.rank == 0:
warnings.warn(f"Request {request_id} failed to be added to the engine due to errors.")
request.add_event_fail()
self.failed_request_ids.append(request_id)
# Send the reply immediately, because it may never get a chance to be sent again.
if self.use_coordinator and self.is_mp_coordinator:
payload = msgpack.packb(
[Headers.ENGINE_REPLY.value, [entry.record.serialize()]], use_bin_type=True
)
self.socket_for_receiving_requests.send(payload)
elif not self.use_coordinator:
if request.prompt is None:
request.prompt = self.controller.tokenizer.detokenize(
request.prompt_tokens.tolist()
)
request.generated_text = self.controller.tokenizer.detokenize(request.generated_tokens)
entry.future.set_result(entry.record)
def _add_request(
    self, request: DynamicInferenceRequest
) -> asyncio.Future[DynamicInferenceRequest]:
    """Register a request with the engine and queue it for processing.

    The request's sampling parameters are normalized in place
    (`num_tokens_total` is converted to `num_tokens_to_generate`, and missing
    generation length / termination id are filled in). Requests whose lengths
    do not fit the context are marked `Status.FAILED` and routed to
    `_handle_failed_request` instead of the waiting queue.

    Args:
        request (DynamicInferenceRequest): The request to add.

    Returns:
        The future stored in the request's entry in `self.requests`.

    Raises:
        AssertionError: If the sampling parameters are inconsistent (both
            `num_tokens_to_generate` and `num_tokens_total` set,
            `top_n_logprobs` without `return_log_probs`, or prompt log probs
            requested while only last-token logits are materialized).
    """
    request_id = request.request_id

    # Add request to self.requests. If the engine has previously been
    # suspended, then the request may already exist.
    if request_id not in self.requests:
        self.requests[request_id] = RequestEntry(
            record=DynamicInferenceRequestRecord.from_request(request),
            future=self._loop.create_future(),
        )

    request.add_event_add_engine()  # Record when request enters engine

    if request.status is None:
        request.status = Status.ACTIVE_AND_GENERATING_TOKENS

    # At most one of the two length specifications may be supplied.
    assert (
        request.sampling_params.num_tokens_to_generate is None
        or request.sampling_params.num_tokens_total is None
    )
    if request.sampling_params.top_n_logprobs > 0:
        assert (
            request.sampling_params.return_log_probs
        ), "top_n_logprobs requires sampling_params.return_log_probs to be True"
    if (
        request.sampling_params.return_log_probs
        and not request.sampling_params.skip_prompt_log_probs
    ):
        assert not self.materialize_only_last_token_logits, (
            "Prompt log probs cannot be calculated if only last token logits are materialized. "
            "Set materialize_only_last_token_logits to False in DynamicInferenceContext "
            "or skip_prompt_log_probs to True in SamplingParams."
        )

    # Express a total-length budget as a number of tokens to generate.
    if request.sampling_params.num_tokens_total is not None:
        request.sampling_params.num_tokens_to_generate = (
            request.sampling_params.num_tokens_total - len(request.prompt_tokens)
        )
        request.sampling_params.num_tokens_total = None
    # With no budget at all, generate up to the context's maximum sequence length.
    if request.sampling_params.num_tokens_to_generate is None:
        request.sampling_params.num_tokens_to_generate = self.context.max_sequence_length - len(
            request.prompt_tokens
        )

    if request.sampling_params.termination_id is None:
        try:
            eod = self.controller.tokenizer.eod
        except AttributeError:
            if self.rank == 0:
                warnings.warn(
                    "Termination ID not specified, and tokenizer does not define eod. "
                    "Defaulting to not using termination id."
                )
            eod = -1
        request.sampling_params.termination_id = eod

    # Reject requests that overflow the maximum sequence length or that ended
    # up with a negative generation budget.
    if (
        len(request.prompt_tokens) + request.sampling_params.num_tokens_to_generate
        > self.context.max_sequence_length
    ) or (request.sampling_params.num_tokens_to_generate < 0):
        logging.error(
            f"{request_id=} Invalid number of tokens to generate. Prompt len: {len(request.prompt_tokens)}, tokens to generate: {request.sampling_params.num_tokens_to_generate}, max seq len: {self.context.max_sequence_length}."
        )
        request.status = Status.FAILED
        request.add_event_error_nontransient(MaxSequenceLengthOverflowError(request_id))

    # Without chunked prefill, the prompt must not exceed context.max_tokens.
    if len(request.prompt_tokens) > self.context.max_tokens and not self.enable_chunked_prefill:
        logging.error(
            f"{request_id=} Prompt is longer than context.max_tokens. Prompt tokens: {len(request.prompt_tokens)}, context.max_tokens: {self.context.max_tokens}, chunked_prefill: {self.enable_chunked_prefill}"
        )
        request.status = Status.FAILED
        request.add_event_error_nontransient(TokenOverflowError(request_id))

    # Tokenize stop words if provided
    if request.sampling_params.stop_words:
        stop_word_ids = [
            self.controller.tokenize_prompt(stop_word, add_BOS=False)
            for stop_word in request.sampling_params.stop_words
        ]
        request.stop_word_ids = stop_word_ids

    if request.status != Status.FAILED:
        self.waiting_request_ids.append(request_id)
    else:
        self._handle_failed_request(request_id)

    return self.requests[request_id].future
def add_request(
    self,
    request_id: int,
    prompt: Union[str, List[int], Tensor],
    sampling_params: Optional[SamplingParams] = None,
) -> asyncio.Future[DynamicInferenceRequest]:
    """Build a request from a prompt and hand it to the engine.

    Args:
        request_id (int): Unique ID of request.
        prompt (Union[str, List[int], Tensor]): Prompt given as text, as a
            list of token IDs, or as an int64 tensor of token IDs located on
            the current CUDA device.
        sampling_params (Optional[SamplingParams]): Sampling parameters for the request.

    Return:
        Returns an asyncio `Future[DynamicInferenceRequest]` for the user to wait on.

    Raises:
        Exception: If `prompt` is not a str, list, or torch.Tensor.
    """
    # Reject unsupported prompt types before doing any work.
    if not isinstance(prompt, (str, list, torch.Tensor)):
        raise Exception("specialize for <%s>." % type(prompt).__name__)

    prompt_text = None
    if isinstance(prompt, str):
        # Text prompt: keep the raw string and tokenize it. The fallback call
        # supports legacy single-argument tokenize_prompt mocks.
        prompt_text = prompt
        try:
            token_ids = self.controller.tokenize_prompt(prompt, sampling_params.add_BOS)
        except TypeError:
            token_ids = self.controller.tokenize_prompt(prompt)
        token_tensor = torch.tensor(
            token_ids, dtype=torch.int64, device=torch.cuda.current_device()
        )
    elif isinstance(prompt, list):
        # Token-ID list: move it onto the current CUDA device as int64.
        token_tensor = torch.tensor(
            prompt, dtype=torch.int64, device=torch.cuda.current_device()
        )
    else:
        # Pre-tokenized tensor: only verify dtype and placement.
        assert prompt.dtype == torch.int64, prompt.dtype
        assert prompt.device == torch.device(
            f"cuda:{torch.cuda.current_device()}"
        ), prompt.device
        token_tensor = prompt

    new_request = DynamicInferenceRequest(
        request_id=request_id,
        prompt=prompt_text,
        prompt_tokens=token_tensor,
        sampling_params=sampling_params,
        block_size_tokens=self.context.block_size_tokens,
        enable_prefix_caching=self.context.enable_prefix_caching,
    )
    return self._add_request(new_request)
def post_process_requests(
self,
request_ids: torch.Tensor,
finished_request_ids: torch.Tensor,
evict_request_ids: torch.Tensor,
step_time: float,
sample: torch.Tensor,
log_probs: torch.Tensor,
top_n_logprobs: Optional[Dict[int, List[Tuple[torch.Tensor, torch.Tensor]]]] = None,
routing_indices_per_request: Optional[Dict[int, torch.Tensor]] = None,
) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest]]:
"""
Handles post-processing for requests after a step.
Args:
request_ids (torch.Tensor): A list of request_ids
finished_request_ids (torch.Tensor): A list of finished request ids
evict_request_ids (torch.Tensor): A list of evicted request ids.
step_time (float): The latency of the last step
sample: (torch.Tensor): The newly generated tokens for each request
log_probs: (List): Log probs for each request
top_n_logprobs: (Dict): Top-n log probs for each request. Maps request_idx to
list of (top_n_logprobs, top_n_indices) tuples.
routing_indices_per_request: (Dict[int, Tensor]): MoE routing indices
pre-mapped by request_id. Each value is a tensor of shape
[num_tokens_this_step, num_layers, topk].
Returns:
A list of active requests and completed requests as `DynamicInferenceRequest` objects
"""
active_request_ids: list[int] = []
finished_request_ids = set(finished_request_ids.tolist())
finished_request_records: list[DynamicInferenceRequestRecord] = []
self.finished_request_count += len(finished_request_ids)
if evict_request_ids is not None:
self.evicted_request_count += evict_request_ids.numel()
log_probs_iter = log_probs if log_probs else repeat(None)
block_allocator = self.context.block_allocator
# Pre-compute step-level block stats (before the per-request loop)
if self.track_generated_token_events:
blocks_allocated = block_allocator.total_count - block_allocator.total_avail
if block_allocator.enable_prefix_caching:
blocks_hashed_active = int((block_allocator.block_ref_counts > 0).sum().item())
blocks_ref_count = block_allocator.block_ref_counts.sum().item()
else:
blocks_hashed_active = blocks_allocated