generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Expand file tree
/
Copy pathvllm_serve.py
More file actions
974 lines (848 loc) · 45 KB
/
vllm_serve.py
File metadata and controls
974 lines (848 loc) · 45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import base64
import logging
import os
from collections.abc import Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from io import BytesIO
from itertools import chain
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
import torch
import torch.distributed.distributed_c10d as c10d
from packaging.version import Version
from transformers import is_torch_xpu_available, is_vision_available
from trl import TrlParser
from trl.generation.vllm_generation import extract_logprobs
from trl.import_utils import (
is_fastapi_available,
is_pydantic_available,
is_uvicorn_available,
is_vllm_ascend_available,
is_vllm_available,
)
if is_fastapi_available():
from fastapi import FastAPI
if is_pydantic_available():
from pydantic import BaseModel
if is_uvicorn_available():
import uvicorn
if is_vision_available():
from PIL import Image
if is_vllm_available():
import vllm
from vllm import LLM, SamplingParams
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group
from vllm.distributed.utils import StatelessProcessGroup
if Version(vllm.__version__) <= Version("0.11.0"):
from vllm.utils import get_open_port
else:
from vllm.utils.network_utils import get_open_port
if Version(vllm.__version__) <= Version("0.10.2"):
from vllm.sampling_params import GuidedDecodingParams
else:
from vllm.sampling_params import StructuredOutputsParams
if is_vllm_ascend_available():
from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator
logger = logging.getLogger(__name__)
# We use CUDA with multiprocessing, so we must use the 'spawn' start method. Otherwise, we will get the following
# error: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use
# the 'spawn' start method
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
class WeightSyncWorkerExtension:
"""
A vLLM worker extension that enables weight synchronization between a client and multiple server workers.
This worker uses a `StatelessProcessGroup` to establish communication and a `PyNcclCommunicator` or
`ProcessGroupXCCL` to handle efficient GPU-based communication using NCCL. The primary purpose of this class is to
receive updated model weights from a client process and distribute them to all worker processes participating in
model inference.
"""
# The following attributes are initialized when `init_communicator` method is called.
communicator = None # Communicator for weight updates
client_rank = None # Source rank for broadcasting updated weights
def init_communicator(self, host: str, port: int, world_size: int, client_device_uuid: str) -> None:
"""
Initializes the weight update communicator using a stateless process group.
This method creates a `StatelessProcessGroup` that allows external training processes to communicate with vLLM
workers without interfering with the global torch distributed group.
Args:
host (`str`):
Hostname or IP address of the master node.
port (`int`):
Port number to be used for communication.
world_size (`int`):
Total number of participating processes in the update group.
client_device_uuid (`str`):
UUID of the device of client main process. Used to assert that devices are different from vllm workers
devices.
"""
if self.communicator is not None:
raise RuntimeError("Weight update group already initialized. Call close_communicator first.")
# TODO: will remove after torch xpu 2.9 support uuid in get_device_properties
if torch.cuda.is_available() or (
is_torch_xpu_available() and hasattr(torch.xpu.get_device_properties(self.device), "uuid")
):
accelerator_module = torch.xpu if is_torch_xpu_available() else torch.cuda
if client_device_uuid == str(accelerator_module.get_device_properties(self.device).uuid):
raise RuntimeError(
f"Attempting to use the same CUDA device (UUID: {client_device_uuid}) for multiple distinct "
"roles/ranks within the same communicator. This setup is unsupported and will likely lead to program "
"hangs or incorrect behavior. Ensure that trainer is using different devices than vLLM server."
)
# Get the rank of the current worker in the global world group.
rank = get_world_group().rank
if is_torch_xpu_available():
store = torch.distributed.TCPStore(host_name=host, port=port, world_size=world_size, is_master=(rank == 0))
prefixed_store = c10d.PrefixStore("client2server", store)
xccl_options = c10d.ProcessGroupXCCL.Options()
pg = c10d.ProcessGroupXCCL(
store=prefixed_store,
rank=rank,
size=world_size,
options=xccl_options,
)
self.communicator = pg
else:
# Create a stateless process group to manage communication between training processes and vLLM workers.
# Initialize the NCCL-based communicator for weight synchronization.
pg = StatelessProcessGroup.create(host=host, port=port, rank=rank, world_size=world_size)
self.communicator = PyNcclCommunicator(pg, device=self.device)
# The client process that sends updated weights has the highest rank (world_size - 1).
self.client_rank = world_size - 1
def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> None:
"""
Receives updated weights from the client process and updates the named parameter in the model.
Args:
name (`str`):
Name of the weight tensor being updated.
dtype (`str`):
Data type of the weight tensor as a string (e.g., `"torch.float32"`).
shape (`Sequence[int]`):
Shape of the weight tensor.
"""
if self.communicator is None:
raise RuntimeError("Communicator not initialized. Call `init_communicator` first.")
dtype = getattr(torch, dtype.split(".")[-1])
# Allocate memory for the incoming weight tensor on the correct device.
weight = torch.empty(shape, dtype=dtype, device=self.device)
if is_torch_xpu_available():
# Use XCCL to broadcast the updated weights from the client (src) to all workers.
self.communicator.broadcast(weight, root=self.client_rank)
self.communicator.barrier()
else:
# Use NCCL to broadcast the updated weights from the client (src) to all workers.
self.communicator.broadcast(weight, src=self.client_rank)
self.communicator.group.barrier()
# Load the received weights into the model.
self.model_runner.model.load_weights(weights=[(name, weight)])
def batch_update_named_params(self, params: list[tuple[str, str, tuple[int, ...]]]) -> None:
"""
Receives and updates multiple named parameters in a single batch.
This avoids per-parameter HTTP round-trips. The client broadcasts each
parameter's tensor via NCCL in the same order as the params list.
Args:
params: List of (name, dtype, shape) tuples for each parameter.
"""
if self.communicator is None:
raise RuntimeError("Communicator not initialized. Call `init_communicator` first.")
weights_to_load = []
for name, dtype_str, shape in params:
dtype = getattr(torch, dtype_str.split(".")[-1])
weight = torch.empty(shape, dtype=dtype, device=self.device)
if is_torch_xpu_available():
self.communicator.broadcast(weight, root=self.client_rank)
else:
self.communicator.broadcast(weight, src=self.client_rank)
weights_to_load.append((name, weight))
# Single barrier after all broadcasts
if is_torch_xpu_available():
self.communicator.barrier()
else:
self.communicator.group.barrier()
# Load all weights at once
self.model_runner.model.load_weights(weights=weights_to_load)
def close_communicator(self) -> None:
"""
Closes the communicator when weight synchronization is no longer needed.
This method deletes the NCCL communicator to release associated resources.
"""
if self.communicator is not None:
del self.communicator
self.communicator = None # Ensure attribute is reset to None
self.client_rank = None # Ensure attribute is reset to None
@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
model (`str`):
Model name or path to load the model from.
revision (`str`, *optional*):
Revision to use for the model. If not specified, the default branch will be used.
tensor_parallel_size (`int`, *optional*, defaults to `1`):
Number of tensor parallel workers to use.
data_parallel_size (`int`, *optional*, defaults to `1`):
Number of data parallel workers to use.
host (`str`, *optional*, defaults to `"0.0.0.0"`):
Host address to run the server on.
port (`int`, *optional*, defaults to `8000`):
Port to run the server on.
gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
during initialization.
dtype (`str`, *optional*, defaults to `"auto"`):
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
based on the model configuration. Find the supported values in the vLLM documentation.
max_model_len (`int`, *optional*):
If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
context size, which might be much larger than the KV cache, leading to inefficiencies.
enable_prefix_caching (`bool`, *optional*):
Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support
this feature.
enforce_eager (`bool`, *optional*, defaults to `False`):
Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the
model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid.
vllm_model_impl (`str`, *optional*, defaults to `"vllm"`):
Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
implementation.
kv_cache_dtype (`str`, *optional*, defaults to `"auto"`):
Data type to use for KV cache. If set to `"auto"`, the dtype will default to the model data type.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether to trust remote code when loading models. Set to `True` to allow executing code from model
repositories. This is required for some custom models but introduces security risks.
log_level (`str`, *optional*, defaults to `"info"`):
Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`,
`"trace"`.
"""
model: str = field(
metadata={"help": "Model name or path to load the model from."},
)
revision: str | None = field(
default=None,
metadata={"help": "Revision to use for the model. If not specified, the default branch will be used."},
)
tensor_parallel_size: int = field(
default=1,
metadata={"help": "Number of tensor parallel workers to use."},
)
data_parallel_size: int = field(
default=1,
metadata={"help": "Number of data parallel workers to use."},
)
host: str = field(
default="0.0.0.0",
metadata={"help": "Host address to run the server on."},
)
port: int = field(
default=8000,
metadata={"help": "Port to run the server on."},
)
gpu_memory_utilization: float = field(
default=0.9,
metadata={
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
"out-of-memory (OOM) errors during initialization."
},
)
dtype: str = field(
default="auto",
metadata={
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
max_model_len: int | None = field(
default=None,
metadata={
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)
enable_prefix_caching: bool | None = field(
default=None,
metadata={
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
"hardware support this feature."
},
)
enforce_eager: bool | None = field(
default=False,
metadata={
"help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always "
"execute the model in eager mode. If `False` (default behavior), we will use CUDA graph and eager "
"execution in hybrid."
},
)
kv_cache_dtype: str = field(
default="auto",
metadata={
"help": "Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": "Whether to trust remote code when loading models. Set to True to allow executing code from model "
"repositories. This is required for some custom models but introduces security risks."
},
)
log_level: str = field(
default="info",
metadata={
"help": "Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', "
"'trace'."
},
)
vllm_model_impl: str = field(
default="vllm",
metadata={
"help": "Model implementation to use for vLLM. Must be one of `transformers` or `vllm`. `transformers`: "
"Use the `transformers` backend for model implementation. `vllm`: Use the `vllm` library for "
"model implementation."
},
)
def llm_worker(
script_args: ScriptArguments, data_parallel_rank: int, master_port: int, connection: Connection
) -> None:
# Set required environment variables for DP to work with vLLM
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
trust_remote_code=script_args.trust_remote_code,
model_impl=script_args.vllm_model_impl,
# Important so temperature scaling/logit tweaking affects the TIS log probs
logprobs_mode="processed_logprobs",
)
# Send ready signal to parent process
connection.send({"status": "ready"})
while True:
# Wait for commands from the parent process
try:
command = connection.recv()
except KeyboardInterrupt:
llm.collective_rpc(method="close_communicator")
break
# Handle commands
if command["type"] in ["call", "fire_and_forget"]:
method_name = command["method"]
args, kwargs = command.get("args", ()), command.get("kwargs", {})
method = getattr(llm, method_name)
result = method(*args, **kwargs)
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
break
def chunk_list(lst: list, n: int) -> list[list]:
"""
Split list `lst` into `n` evenly distributed sublists.
Example:
```python
>>> chunk_list([1, 2, 3, 4, 5, 6], 2)
[[1, 2, 3], [4, 5, 6]]
>>> chunk_list([1, 2, 3, 4, 5, 6], 4)
[[1, 2], [3, 4], [5], [6]]
>>> chunk_list([1, 2, 3, 4, 5, 6], 8)
[[1], [2], [3], [4], [5], [6], [], []]
```
"""
k, r = divmod(len(lst), n)
return [lst[i * k + min(i, r) : (i + 1) * k + min(i + 1, r)] for i in range(n)]
def main(script_args: ScriptArguments):
if not is_fastapi_available():
raise ImportError(
"FastAPI is required to run the vLLM serve script. Please install it using `pip install fastapi`."
)
if not is_pydantic_available():
raise ImportError(
"Pydantic is required to run the vLLM serve script. Please install it using `pip install pydantic`."
)
if not is_uvicorn_available():
raise ImportError(
"Uvicorn is required to run the vLLM serve script. Please install it using `pip install uvicorn`."
)
if not is_vllm_available():
raise ImportError("vLLM is required to run the vLLM serve script. Please install it using `pip install vllm`.")
# Spawn dp workers, and setup pipes for communication
master_port = get_open_port()
connections = []
processes = []
for data_parallel_rank in range(script_args.data_parallel_size):
parent_connection, child_connection = Pipe()
process = Process(target=llm_worker, args=(script_args, data_parallel_rank, master_port, child_connection))
process.start()
connections.append(parent_connection)
processes.append(process)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Wait for all workers to send "ready"
ready_connections = set()
while len(ready_connections) < script_args.data_parallel_size:
for connection in connections:
msg = connection.recv()
if isinstance(msg, dict) and msg.get("status") == "ready":
ready_connections.add(connection)
yield
# Wait for processes to terminate
for process in processes:
process.join(timeout=10) # Wait for 10 seconds for the process to terminate
if process.is_alive():
logger.warning(f"Process {process} is still alive after 10 seconds, attempting to terminate...")
process.terminate()
process.join() # ensure process termination after calling terminate()
app = FastAPI(lifespan=lifespan)
# Define the endpoints for the model server
@app.get("/health/")
async def health():
"""
Health check endpoint to verify that the server is running.
"""
return {"status": "ok"}
@app.get("/get_world_size/")
async def get_world_size():
"""
Retrieves the world size of the LLM engine, which is `tensor_parallel_size * data_parallel_size`.
Returns:
`dict`:
A dictionary containing the world size.
Example response:
```json
{"world_size": 8}
```
"""
return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size}
class GenerateRequest(BaseModel):
prompts: list[str]
images: list[str] | None = None
n: int = 1
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
logprobs: int | None = 0
truncate_prompt_tokens: int | None = None
structured_outputs_regex: str | None = None
generation_kwargs: dict = field(default_factory=dict)
class GenerateResponse(BaseModel):
prompt_ids: list[list[int]]
completion_ids: list[list[int]]
logprobs: list[list[list[float]]]
logprob_token_ids: list[list[list[int]]]
@app.post("/generate/", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
"""
Generates completions for the provided prompts.
Args:
request (`GenerateRequest`):
- `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions.
- `images` (list of `str`, *optional*, default to `None`): A list of base64 encoded images to process
along with prompts.
- `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt.
- `repetition_penalty` (`float`, *optional*, defaults to `1.0`): Repetition penalty to apply during
generation.
- `temperature` (`float`, *optional*, defaults to `1.0`): Temperature for sampling. Higher values lead
to more random outputs.
- `top_p` (`float`, *optional*, defaults to `1.0`): Top-p (nucleus) sampling parameter. It controls the
diversity of the generated text.
- `top_k` (`int`, *optional*, defaults to `-1`): Top-k sampling parameter. If set to `-1`, it disables
top-k sampling.
- `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
- `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
completion.
- `logprobs` (`int`, *optional*, defaults to `0`): Number of top logprobs to return per token. When 0,
only the sampled token's logprob is returned. When N>0, returns the top-N logprobs sorted by
descending probability.
- `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
truncation). If set to `None`, truncation is disabled.
- `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided,
the model will only generate tokens that match this regex pattern.
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
`SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains
keys that conflict with the other parameters, they will override them.
Returns:
`GenerateResponse`:
- `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt.
- `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
- `logprobs` (list of list of list of `float`): Per-token logprobs of shape (num_sequences, seq_len,
num_logprobs), sorted by descending probability.
- `logprob_token_ids` (list of list of list of `int`): Token IDs corresponding to each logprob, same
shape as `logprobs`.
Example request:
```json
{"prompts": ["Hello world", "What is AI?"]}
```
Example response:
```json
{
"prompt_ids": [[101, 102], [201, 202]],
"completion_ids": [[103, 104, 105], [203, 204, 205]],
"logprobs": [[[-0.1], [-0.2], [-0.3]], [[-0.4], [-0.5], [-0.6]]],
"logprob_token_ids": [[[103], [104], [105]], [[203], [204], [205]]]
}
```
"""
request.images = request.images or [None] * len(request.prompts)
prompts = []
for prompt, image in zip(request.prompts, request.images, strict=True):
row = {"prompt": prompt}
if image is not None:
row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
prompts.append(row)
generation_kwargs = {
"n": request.n,
"repetition_penalty": request.repetition_penalty,
"temperature": request.temperature,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"truncate_prompt_tokens": request.truncate_prompt_tokens,
"logprobs": request.logprobs,
}
generation_kwargs.update(request.generation_kwargs)
# Structured outputs, if enabled
if Version(vllm.__version__) <= Version("0.10.2"):
structured_outputs_key = "guided_decoding"
if request.structured_outputs_regex is not None:
if generation_kwargs.get("guided_decoding") is not None:
logger.warning(
"Both `structured_outputs_regex` and `generation_kwargs['guided_decoding']` are set; "
"`structured_outputs_regex` takes precedence."
)
structured_outputs = GuidedDecodingParams(regex=request.structured_outputs_regex)
else:
structured_outputs = generation_kwargs.get("guided_decoding")
else:
structured_outputs_key = "structured_outputs"
if request.structured_outputs_regex is not None:
if generation_kwargs.get("structured_outputs") is not None:
logger.warning(
"Both `structured_outputs_regex` and `generation_kwargs['structured_outputs']` are set; "
"`structured_outputs_regex` takes precedence."
)
structured_outputs = StructuredOutputsParams(regex=request.structured_outputs_regex)
elif isinstance(generation_kwargs.get("structured_outputs"), dict):
# If structured_outputs is passed as a dictionary in generation_kwargs, convert it to a
# StructuredOutputsParams object to ensure compatibility with vLLM's SamplingParams.
structured_outputs_dict = generation_kwargs.get("structured_outputs")
structured_outputs = StructuredOutputsParams(**structured_outputs_dict)
else:
structured_outputs = generation_kwargs.get("structured_outputs")
generation_kwargs[structured_outputs_key] = structured_outputs
sampling_params = SamplingParams(**generation_kwargs)
# Evenly distribute prompts across DP ranks
chunked_prompts = chunk_list(prompts, script_args.data_parallel_size)
# Send the prompts to each worker
for connection, prompts in zip(connections, chunked_prompts, strict=True):
# When the number of prompts is less than data_parallel_size, some workers will receive empty prompts.
# However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply
# with vLLM's requirement, and we later ignore the result.
if not prompts:
prompts = ["<placeholder>"]
kwargs = {"prompts": prompts, "sampling_params": sampling_params}
connection.send({"type": "call", "method": "generate", "kwargs": kwargs})
# Receive results
all_outputs = [connection.recv() for connection in connections]
# Handle empty prompts (see above)
all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts, strict=True) if prompts]
# Flatten and combine all results
all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list
prompt_ids = [output.prompt_token_ids for output in all_outputs]
completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
logprobs, logprob_token_ids = extract_logprobs(all_outputs)
return {
"prompt_ids": prompt_ids,
"completion_ids": completion_ids,
"logprobs": logprobs,
"logprob_token_ids": logprob_token_ids,
}
class ChatRequest(BaseModel):
messages: list[list[dict]]
n: int = 1
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
logprobs: int | None = 0
truncate_prompt_tokens: int | None = None
structured_outputs_regex: str | None = None
generation_kwargs: dict = field(default_factory=dict)
chat_template_kwargs: dict = field(default_factory=dict)
class ChatResponse(BaseModel):
prompt_ids: list[list[int]]
completion_ids: list[list[int]]
logprobs: list[list[list[float]]]
logprob_token_ids: list[list[list[int]]]
@app.post("/chat/", response_model=ChatResponse)
async def chat(request: ChatRequest):
"""
Generates completions for the provided chat messages.
Args:
request (`ChatRequest`):
- `messages` (list of `dict`): A list of messages (dicts with "role" and "content" keys) for the model
to generate completions.
- `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt.
- `repetition_penalty` (`float`, *optional*, defaults to `1.0`): Repetition penalty to apply during
generation.
- `temperature` (`float`, *optional*, defaults to `1.0`): Temperature for sampling. Higher values lead
to more random outputs.
- `top_p` (`float`, *optional*, defaults to `1.0`): Top-p (nucleus) sampling parameter. It controls the
diversity of the generated text.
- `top_k` (`int`, *optional*, defaults to `-1`): Top-k sampling parameter. If set to `-1`, it disables
top-k sampling.
- `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
- `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
completion.
- `logprobs` (`int`, *optional*, defaults to `0`): Number of top logprobs to return per token. When 0,
only the sampled token's logprob is returned. When N>0, returns the top-N logprobs sorted by
descending probability.
- `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
truncation). If set to `None`, truncation is disabled.
- `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided,
the model will only generate tokens that match this regex pattern.
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
`SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains
keys that conflict with the other parameters, they will override them.
- `chat_template_kwargs` (`dict`, *optional*): Additional keyword arguments to pass to the chat
template.
Returns:
`ChatResponse`:
- `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt.
- `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
- `logprobs` (list of list of list of `float`): Per-token logprobs of shape (num_sequences, seq_len,
num_logprobs), sorted by descending probability.
- `logprob_token_ids` (list of list of list of `int`): Token IDs corresponding to each logprob, same
shape as `logprobs`.
Example request:
```bash
curl -X POST 'http://0.0.0.0:8000/chat/' \
-H 'Content-Type: application/json' \
-d '{"messages": [[{ "role": "user", "content": "Hello!" }]]}'
```
Example response:
```json
{
"prompt_ids": [[151644, 872, 198, 9707, 0, 151645, 198, 151644, 77091, 198]],
"completion_ids": [[151667, 198, 32313, 11, 279]],
"logprobs": [[[-0.0003], [-3.58e-07], [-0.0902], [-6.39e-05], [-0.0387]]],
"logprob_token_ids": [[[151667], [198], [32313], [11], [279]]]
}
```
"""
# Convert PIL images to base64 strings
for message_list in request.messages:
for message in message_list:
if isinstance(message["content"], list):
for part in message["content"]:
if part["type"] == "image_pil":
part["image_pil"] = Image.open(BytesIO(base64.b64decode(part["image_pil"])))
generation_kwargs = {
"n": request.n,
"repetition_penalty": request.repetition_penalty,
"temperature": request.temperature,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"truncate_prompt_tokens": request.truncate_prompt_tokens,
"logprobs": request.logprobs,
}
generation_kwargs.update(request.generation_kwargs)
# Structured outputs, if enabled
if Version(vllm.__version__) <= Version("0.10.2"):
structured_outputs_key = "guided_decoding"
if request.structured_outputs_regex is not None:
if generation_kwargs.get("guided_decoding") is not None:
logger.warning(
"Both `structured_outputs_regex` and `generation_kwargs['guided_decoding']` are set; "
"`structured_outputs_regex` takes precedence."
)
structured_outputs = GuidedDecodingParams(regex=request.structured_outputs_regex)
else:
structured_outputs = generation_kwargs.get("guided_decoding")
else:
structured_outputs_key = "structured_outputs"
if request.structured_outputs_regex is not None:
if generation_kwargs.get("structured_outputs") is not None:
logger.warning(
"Both `structured_outputs_regex` and `generation_kwargs['structured_outputs']` are set; "
"`structured_outputs_regex` takes precedence."
)
structured_outputs = StructuredOutputsParams(regex=request.structured_outputs_regex)
elif isinstance(generation_kwargs.get("structured_outputs"), dict):
# If structured_outputs is passed as a dictionary in generation_kwargs, convert it to a
# StructuredOutputsParams object to ensure compatibility with vLLM's SamplingParams.
structured_outputs_dict = generation_kwargs.get("structured_outputs")
structured_outputs = StructuredOutputsParams(**structured_outputs_dict)
else:
structured_outputs = generation_kwargs.get("structured_outputs")
generation_kwargs[structured_outputs_key] = structured_outputs
sampling_params = SamplingParams(**generation_kwargs)
# Evenly distribute prompts across DP ranks
chunked_messages = chunk_list(request.messages, script_args.data_parallel_size)
# Send the messages to each worker
for connection, messages in zip(connections, chunked_messages, strict=True):
# When the number of messages is less than data_parallel_size, some workers will receive empty messages.
# However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply
# with vLLM's requirement, and we later ignore the result.
if not messages:
messages = [[{"role": "user", "content": "<placeholder>"}]]
kwargs = {
"messages": messages,
"sampling_params": sampling_params,
"chat_template_kwargs": request.chat_template_kwargs,
}
connection.send({"type": "call", "method": "chat", "kwargs": kwargs})
# Receive results
all_outputs = [connection.recv() for connection in connections]
# Handle empty prompts (see above)
all_outputs = [output for output, prompts in zip(all_outputs, chunked_messages, strict=True) if prompts]
# Flatten and combine all results
all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list
prompt_ids = [output.prompt_token_ids for output in all_outputs]
completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
logprobs, logprob_token_ids = extract_logprobs(all_outputs)
return {
"prompt_ids": prompt_ids,
"completion_ids": completion_ids,
"logprobs": logprobs,
"logprob_token_ids": logprob_token_ids,
}
class InitCommunicatorRequest(BaseModel):
host: str
port: int
world_size: int
client_device_uuid: str
@app.post("/init_communicator/")
async def init_communicator(request: InitCommunicatorRequest):
"""
Initializes the communicator for synchronizing model weights between a client and multiple server workers.
Args:
request (`InitCommunicatorRequest`):
- `host` (`str`): Hostname or IP address of the master node.
- `port` (`int`): Port number to be used for communication.
- `world_size` (`int`): Total number of participating processes in the group.
- `client_device_uuid` (`str`): UUID of the device of client main process. Used to assert that devices
are different from vLLM workers devices.
"""
world_size = script_args.tensor_parallel_size * script_args.data_parallel_size + 1
# The function init_communicator is called this way: init_communicator(host, port, world_size)
# So with collective_rpc we need to call it this way:
# llm.collective_rpc(method="init_communicator", args=(host, port, world_size))
kwargs = {
"method": "init_communicator",
"args": (request.host, request.port, world_size, request.client_device_uuid),
}
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(*(loop.run_in_executor(None, conn.send, msg) for conn in connections))
return {"message": "Request received, initializing communicator"}
class UpdateWeightsRequest(BaseModel):
name: str
dtype: str
shape: list[int]
@app.post("/update_named_param/")
async def update_named_param(request: UpdateWeightsRequest):
"""
Updates the model weights with the provided tensor.
Once this endpoint is called, the client process should broadcast the updated weights to all server workers.
Args:
request (`UpdateWeightsRequest`):
- `name` (`str`): Name of the weight tensor being updated.
- `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`).
- `shape` (list of `int`): Shape of the weight
"""
# The function update_named_param is called this way: update_named_param("name", "torch.float32", (10, 10))
# So with collective_rpc we need to call it this way:
# llm.collective_rpc("update_named_param", args=("name", "torch.float32", (10, 10)))
kwargs = {"method": "update_named_param", "args": (request.name, request.dtype, tuple(request.shape))}
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(*(loop.run_in_executor(None, conn.send, msg) for conn in connections))
return {"message": "Request received, updating named parameter"}
class BatchUpdateWeightsRequest(BaseModel):
params: list[dict] # List of {"name": str, "dtype": str, "shape": list[int]}
@app.post("/batch_update_named_params/")
async def batch_update_named_params(request: BatchUpdateWeightsRequest):
"""Batch update: sends all param metadata in one HTTP call, then NCCL broadcasts happen in sequence."""
params_list = [(p["name"], p["dtype"], tuple(p["shape"])) for p in request.params]
kwargs = {"method": "batch_update_named_params", "args": (params_list,)}
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(*(loop.run_in_executor(None, conn.send, msg) for conn in connections))
return {"message": f"Batch update started for {len(params_list)} parameters"}
@app.post("/reset_prefix_cache/")
async def reset_prefix_cache():
"""
Resets the prefix cache for the model.
"""
for connection in connections:
connection.send({"type": "call", "method": "reset_prefix_cache"})
# Wait for and collect all results
all_outputs = [connection.recv() for connection in connections]
success = all(output for output in all_outputs)
return {"message": "Request received, resetting prefix cache status: " + str(success)}
@app.post("/close_communicator/")
async def close_communicator():
"""
Closes the weight update group and cleans up associated resources.
"""
kwargs = {"method": "close_communicator"}
for connection in connections:
connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})
return {"message": "Request received, closing communicator"}
# Start the server
uvicorn.run(app, host=script_args.host, port=script_args.port, log_level=script_args.log_level)
def make_parser(subparsers: argparse._SubParsersAction | None = None):
if subparsers is not None:
parser = subparsers.add_parser("vllm-serve", help="Run the vLLM serve script", dataclass_types=ScriptArguments)
else:
parser = TrlParser(ScriptArguments)
return parser
if __name__ == "__main__":
parser = make_parser()
(script_args,) = parser.parse_args_and_config()
main(script_args)