forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvllm.py
More file actions
1799 lines (1602 loc) · 77.2 KB
/
vllm.py
File metadata and controls
1799 lines (1602 loc) · 77.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import getpass
import json
import os
import tempfile
import threading
import time
from contextlib import contextmanager
from dataclasses import is_dataclass
from datetime import datetime
from enum import IntEnum
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, TypeVar, get_args
import torch
from pydantic import ConfigDict, Field, model_validator
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
from vllm.utils.hashing import safe_hash
from .attention import AttentionConfig
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig
from .ec_transfer import ECTransferConfig
from .kernel import KernelConfig
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .offload import OffloadConfig
from .parallel import ParallelConfig
from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig
from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config, replace
from .weight_transfer import WeightTransferConfig
if TYPE_CHECKING:
from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.v1.kv_cache_interface import KVCacheConfig
else:
PretrainedConfig = Any
QuantizationConfig = Any
KVCacheConfig = Any
logger = init_logger(__name__)
class OptimizationLevel(IntEnum):
    """Integer-ordered presets that trade startup time for performance."""

    O0 = 0
    """O0: nothing is optimized. No compilation, no cudagraphs and no other
    optimization, so the engine starts up immediately."""

    O1 = 1
    """O1: quick optimizations. Dynamo+Inductor compilation together with
    Piecewise cudagraphs."""

    O2 = 2
    """O2: full optimizations. Everything in -O1, plus Full and Piecewise
    cudagraphs."""

    O3 = 3
    """O3: currently identical to -O2."""
# Accepted values for the `performance_mode` field of VllmConfig.
PerformanceMode = Literal["balanced", "interactivity", "throughput"]

# Placeholders referenced by the O2/O3 default tables further down this module.
# They are plain constants for now, so every optimization gated on them
# resolves to False.
IS_QUANTIZED = False
IS_DENSE = False
# The optimizations that depend on these properties are currently set to False
# in all cases. The intended model-dependent predicates are kept below for
# reference:
# if model_config is not None:
# IS_QUANTIZED = lambda c: c.model_config.is_quantized()
# IS_DENSE = lambda c: not c.model_config.is_model_moe()
# See https://github.com/vllm-project/vllm/issues/25689.
def enable_norm_fusion(cfg: "VllmConfig") -> bool:
    """Default for the norm+quant fusion pass.

    The pass is turned on when the `rms_norm` custom op or the `quant_fp8`
    custom op is enabled in the compilation config; when neither is, fusion
    is left to Inductor.

    Args:
        cfg: The root config whose `compilation_config` is consulted.

    Returns:
        Whether the fusion pass should be enabled by default.
    """
    compilation = cfg.compilation_config
    # `rms_norm` is checked first; `quant_fp8` is only queried if it is off.
    return compilation.is_custom_op_enabled(
        "rms_norm"
    ) or compilation.is_custom_op_enabled("quant_fp8")
def enable_act_fusion(cfg: "VllmConfig") -> bool:
    """Default for the activation+quant fusion pass.

    The pass is turned on when the `silu_and_mul` custom op or the
    `quant_fp8` custom op is enabled; otherwise Inductor handles fusion.
    It is also turned on for FP4 models, because FP4 quant is always a custom
    op and so Inductor cannot fuse it.

    Args:
        cfg: The root config; `model_config` may be None.

    Returns:
        Whether the fusion pass should be enabled by default.
    """
    compilation = cfg.compilation_config
    custom_op_active = compilation.is_custom_op_enabled(
        "silu_and_mul"
    ) or compilation.is_custom_op_enabled("quant_fp8")
    if custom_op_active:
        return custom_op_active

    # Neither custom op is active: fall back to the NVFP4 check, which needs a
    # model config to be present.
    model = cfg.model_config
    return model is not None and model.is_nvfp4_quantized()
def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
    """Default for the allreduce+RMS fusion pass.

    All of the following must hold:
      * tensor parallel size is greater than 1,
      * the current platform is CUDA and flashinfer is available,
      * the device capability is 100 or 90 (Blackwell/Hopper),
      * data parallel size and pipeline parallel size are both 1.

    Args:
        cfg: The root config whose `parallel_config` is consulted.

    Returns:
        Whether the fusion pass should be enabled by default.
    """
    from vllm.platforms import current_platform
    from vllm.utils.flashinfer import has_flashinfer

    parallel = cfg.parallel_config
    if not parallel.tensor_parallel_size > 1:
        return False
    if not current_platform.is_cuda():
        return False
    if not has_flashinfer():
        return False

    supported_device = current_platform.is_device_capability(
        100
    ) or current_platform.is_device_capability(90)
    if not supported_device:
        return False

    # tp-dp combination broken:
    # https://github.com/vllm-project/vllm/issues/34458
    if parallel.data_parallel_size != 1:
        return False
    # tp-pp combination broken:
    # https://github.com/vllm-project/vllm/issues/35426
    return parallel.pipeline_parallel_size == 1
def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
    """Default for the RoPE+KV-cache fusion pass.

    Requires ROCm AITER ops to be enabled, the `rotary_embedding` custom op
    to be active, and `use_inductor_graph_partition` to be set on the
    compilation config.

    Args:
        cfg: The root config whose `compilation_config` is consulted.

    Returns:
        Whether the fusion pass should be enabled by default.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    if not rocm_aiter_ops.is_enabled():
        return False

    compilation = cfg.compilation_config
    # The last operand is returned as-is, so the result mirrors whatever
    # `use_inductor_graph_partition` currently holds when the op is enabled.
    return (
        compilation.is_custom_op_enabled("rotary_embedding")
        and compilation.use_inductor_graph_partition
    )
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
    """Default for the norm+padding fusion pass.

    Enabled when AITER RMSNorm is on, AITER Triton GEMMs are *not* enabled,
    a model config is present and its hidden size is 2880 (i.e. gpt-oss);
    otherwise Inductor handles fusion.

    NOTE(review): the previous docstring read "using AITER RMSNorm and AITER
    Triton GEMMs", while the code requires Triton GEMMs to be disabled.
    Behavior is kept as implemented; confirm which one is intended.

    Args:
        cfg: The root config; `model_config` may be None.

    Returns:
        Whether the fusion pass should be enabled by default.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    if not rocm_aiter_ops.is_rmsnorm_enabled():
        return False
    if rocm_aiter_ops.is_triton_gemm_enabled():
        return False

    model = cfg.model_config
    return model is not None and model.get_hidden_size() == 2880
# Per-level default tables consumed by
# `VllmConfig._apply_optimization_level_defaults`.
#
# Each table mirrors the nesting of VllmConfig: a nested dict is applied to the
# sub-config of the same name, and a leaf is either a static value or a
# callable that receives the VllmConfig and returns the value. A leaf is only
# written when the corresponding field is still None.
#
# NOTE: key order is significant. Callable leaves are evaluated while the
# tables are being walked in dict order, so a callable sees the config as it
# stands at that point (e.g. `enable_rope_kvcache_fusion` reads
# `use_inductor_graph_partition`, which is listed after `pass_config`).

# O0: every fusion pass off, no cudagraphs, no flashinfer autotune.
OPTIMIZATION_LEVEL_00 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": False,
            "fuse_act_quant": False,
            "fuse_allreduce_rms": False,
            "fuse_attn_quant": False,
            "fuse_gemm_quant": False,
            "enable_sp": False,
            "fuse_gemm_comms": False,
            "fuse_act_padding": False,
            "fuse_rope_kvcache": False,
        },
        "cudagraph_mode": CUDAGraphMode.NONE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": False,
    },
}
# O1: config-dependent norm/act/padding/RoPE fusions, piecewise cudagraphs.
OPTIMIZATION_LEVEL_01 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": False,
            "fuse_attn_quant": False,
            "fuse_gemm_quant": False,
            "enable_sp": False,
            "fuse_gemm_comms": False,
            "fuse_act_padding": enable_norm_pad_fusion,
            "fuse_rope_kvcache": enable_rope_kvcache_fusion,
        },
        "cudagraph_mode": CUDAGraphMode.PIECEWISE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": True,
    },
}
# O2: O1 plus allreduce+RMS fusion and full-and-piecewise cudagraphs. The
# IS_QUANTIZED / IS_DENSE entries are module constants (currently False).
OPTIMIZATION_LEVEL_02 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": enable_allreduce_rms_fusion,
            "fuse_attn_quant": IS_QUANTIZED,
            "fuse_gemm_quant": False,
            "enable_sp": IS_DENSE,
            "fuse_gemm_comms": IS_DENSE,
            "fuse_act_padding": enable_norm_pad_fusion,
            "fuse_rope_kvcache": enable_rope_kvcache_fusion,
        },
        "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": True,
    },
}
# O3: currently the same entries as O2, kept as a separate table.
OPTIMIZATION_LEVEL_03 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": enable_allreduce_rms_fusion,
            "fuse_attn_quant": IS_QUANTIZED,
            "fuse_gemm_quant": False,
            "enable_sp": IS_DENSE,
            "fuse_gemm_comms": IS_DENSE,
            "fuse_act_padding": enable_norm_pad_fusion,
            "fuse_rope_kvcache": enable_rope_kvcache_fusion,
        },
        "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": True,
    },
}
# Lookup from optimization level to its default table.
OPTIMIZATION_LEVEL_TO_CONFIG = {
    OptimizationLevel.O0: OPTIMIZATION_LEVEL_00,
    OptimizationLevel.O1: OPTIMIZATION_LEVEL_01,
    OptimizationLevel.O2: OPTIMIZATION_LEVEL_02,
    OptimizationLevel.O3: OPTIMIZATION_LEVEL_03,
}
@config(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
# TODO: use default_factory once default constructing ModelConfig doesn't
# try to download a model
model_config: ModelConfig = Field(default=None)
"""Model configuration."""
cache_config: CacheConfig = Field(default_factory=CacheConfig)
"""Cache configuration."""
parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)
"""Parallel configuration."""
scheduler_config: SchedulerConfig = Field(
default_factory=SchedulerConfig.default_factory,
)
"""Scheduler configuration."""
device_config: DeviceConfig = Field(default_factory=DeviceConfig)
"""Device configuration."""
load_config: LoadConfig = Field(default_factory=LoadConfig)
"""Load configuration."""
offload_config: OffloadConfig = Field(default_factory=OffloadConfig)
"""Model weight offloading configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration."""
kernel_config: KernelConfig = Field(default_factory=KernelConfig)
"""Kernel configuration."""
lora_config: LoRAConfig | None = None
"""LoRA configuration."""
speculative_config: SpeculativeConfig | None = None
"""Speculative decoding configuration."""
structured_outputs_config: StructuredOutputsConfig = Field(
default_factory=StructuredOutputsConfig
)
"""Structured outputs configuration."""
observability_config: ObservabilityConfig = Field(
default_factory=ObservabilityConfig
)
"""Observability configuration."""
quant_config: QuantizationConfig | None = None
"""Quantization configuration."""
compilation_config: CompilationConfig = Field(default_factory=CompilationConfig)
"""`torch.compile` and cudagraph capture configuration for the model.
As a shorthand, one can append compilation arguments via
-cc.parameter=argument such as `-cc.mode=3` (same as `-cc='{"mode":3}'`).
You can specify the full compilation config like so:
`{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
"""
profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig)
"""Profiling configuration."""
kv_transfer_config: KVTransferConfig | None = None
"""The configurations for distributed KV cache transfer."""
kv_events_config: KVEventsConfig | None = None
"""The configurations for event publishing."""
ec_transfer_config: ECTransferConfig | None = None
"""The configurations for distributed EC cache transfer."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
additional_config: dict | SupportsHash = Field(default_factory=dict)
"""Additional config for specified platform. Different platforms may
support different configs. Make sure the configs are valid for the platform
you are using. Contents must be hashable."""
instance_id: str = ""
"""The ID of the vLLM instance."""
optimization_level: OptimizationLevel = OptimizationLevel.O2
"""The optimization level. These levels trade startup time cost for
performance, with -O0 having the best startup time and -O3 having the best
performance. -O2 is used by default. See OptimizationLevel for full
description."""
performance_mode: PerformanceMode = "balanced"
"""Performance mode for runtime behavior, 'balanced' is the default.
'interactivity' favors low end-to-end per-request latency at small batch
sizes (fine-grained CUDA graphs, latency-oriented kernels).
'throughput' favors aggregate tokens/sec at high concurrency (larger CUDA
graphs, more aggressive batching, throughput-oriented kernels)."""
weight_transfer_config: WeightTransferConfig | None = None
"""The configurations for weight transfer during RL training."""
def compute_hash(self) -> str:
    """
    Fingerprint the settings that shape the computation graph.

    WARNING: Whenever a new field is added to this config,
    ensure that it is included in the factors list if
    it affects the computation graph.

    The hash is meant to uniquely identify all the configs that affect
    the structure of the computation graph from input ids/embeddings to
    the final hidden states, excluding anything before input
    ids/embeddings and after the final hidden states.

    Returns:
        The first 10 hex characters of the digest computed over the
        vLLM version and the per-sub-config hashes.
    """
    from vllm import __version__

    def digest_or_placeholder(sub_config: Any) -> Any:
        # A falsy sub-config contributes the literal "None" so it still
        # occupies its slot in the factor list.
        return sub_config.compute_hash() if sub_config else "None"

    summary: list[Any] = [__version__]

    if self.model_config:
        summary.append(self.model_config.compute_hash())
        # The multimodal config only contributes when the encoder is compiled.
        if (
            self.compilation_config
            and getattr(self.compilation_config, "compile_mm_encoder", False)
            and self.model_config.multimodal_config
        ):
            summary.append(self.model_config.multimodal_config.compute_hash())
    else:
        summary.append("None")

    summary.extend(
        digest_or_placeholder(getattr(self, attr_name))
        for attr_name in (
            "cache_config",
            "parallel_config",
            "scheduler_config",
            "device_config",
            "load_config",
            "offload_config",
            "attention_config",
            "lora_config",
            "speculative_config",
        )
    )

    # Structured outputs add a factor only when set; no placeholder otherwise.
    if self.structured_outputs_config:
        summary.append(self.structured_outputs_config.compute_hash())
    summary.append(digest_or_placeholder(self.profiler_config))
    # Observability is hashed unconditionally.
    summary.append(self.observability_config.compute_hash())
    # quant_config is deliberately not added here: it should be captured by
    # model_config.quantization.

    summary.extend(
        digest_or_placeholder(getattr(self, attr_name))
        for attr_name in (
            "compilation_config",
            "kv_transfer_config",
            "ec_transfer_config",
        )
    )

    if not self.additional_config:
        summary.append("None")
    else:
        extra = self.additional_config
        if isinstance(extra, dict):
            # Plain dicts are serialized with sorted keys before hashing.
            extra_digest = safe_hash(
                json.dumps(extra, sort_keys=True).encode(),
                usedforsecurity=False,
            ).hexdigest()
        else:
            extra_digest = extra.compute_hash()
        summary.append(extra_digest)

    # The summary is nested inside an outer list before being stringified.
    full_digest = safe_hash(
        str([summary]).encode(), usedforsecurity=False
    ).hexdigest()
    return full_digest[:10]
@property
def num_speculative_tokens(self) -> int:
    """Configured number of speculative tokens.

    Returns:
        `speculative_config.num_speculative_tokens` when a speculative
        config exists and that value is not None; 0 otherwise.
    """
    spec = self.speculative_config
    if spec is None:
        return 0
    token_count = spec.num_speculative_tokens
    return 0 if token_count is None else token_count
@property
def needs_dp_coordinator(self) -> bool:
    """
    Tell whether the DPCoordinator process is needed.

    With data parallel size > 1, the coordinator is needed when:
    1. the model is MoE (or no model config is set): to handle wave
       coordination (even in external LB mode, since wave coordination
       runs in the coordinator), or
    2. the model is non-MoE and external LB is off (internal/hybrid LB
       mode): to collect and publish queue stats for load balancing
       across DP ranks.

    Returns:
        True if DPCoordinator process is needed, False otherwise.
    """
    parallel = self.parallel_config
    if parallel.data_parallel_size <= 1:
        return False

    model = self.model_config
    # For non-MoE models, the coordinator is only needed in internal/hybrid
    # LB mode (for stats collection).
    return (
        model is None
        or model.is_moe
        or not parallel.data_parallel_external_lb
    )
def enable_trace_function_call_for_thread(self) -> None:
    """
    Turn on function tracing for the calling thread.

    Does nothing unless the `VLLM_TRACE_FUNCTION` environment variable is
    enabled. Otherwise the trace is written to a log file under
    `<tmp>/<user>/vllm/vllm-instance-<instance_id>/`, whose name embeds the
    process id, the thread id and the current timestamp.
    """
    if not envs.VLLM_TRACE_FUNCTION:
        return

    # add username to tmp_dir to avoid permission issues
    user_tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())

    raw_name = (
        f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
        f"_thread_{threading.get_ident()}_at_{datetime.now()}.log"
    )
    # The timestamp contains a space; replace it to keep the name one token.
    log_name = raw_name.replace(" ", "_")

    log_path = os.path.join(
        user_tmp_dir,
        "vllm",
        f"vllm-instance-{self.instance_id}",
        log_name,
    )
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    enable_trace_function_call(log_path)
@staticmethod
def _get_quantization_config(
    model_config: ModelConfig, load_config: LoadConfig
) -> QuantizationConfig | None:
    """Build and validate the quantization config for a model.

    Args:
        model_config: Model configuration; its `quantization`, `dtype` and
            `model` attributes are read.
        load_config: Load configuration forwarded to `get_quant_config`.

    Returns:
        The quantization config, or None when `model_config.quantization`
        is None.

    Raises:
        ValueError: If the device capability is below the method's minimum,
            or if the model dtype is not among the supported activation
            dtypes.
    """
    from vllm.platforms import current_platform

    if model_config.quantization is None:
        return None

    from vllm.model_executor.model_loader.weight_utils import get_quant_config

    quant_config = get_quant_config(model_config, load_config)

    # The capability check is skipped when the platform reports none.
    device_capability = current_platform.get_device_capability()
    if device_capability is not None:
        capability = device_capability.to_int()
        required_capability = quant_config.get_min_capability()
        if capability < required_capability:
            raise ValueError(
                f"The quantization method {model_config.quantization} "
                "is not supported for the current GPU. Minimum "
                f"capability: {required_capability}. "
                f"Current capability: {capability}."
            )

    supported_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in supported_dtypes:
        raise ValueError(
            f"{model_config.dtype} is not supported for quantization "
            f"method {model_config.quantization}. Supported dtypes: "
            f"{supported_dtypes}"
        )

    quant_config.maybe_update_config(model_config.model)
    return quant_config
@staticmethod
def get_quantization_config(
    model_config: ModelConfig, load_config: LoadConfig
) -> QuantizationConfig | None:
    """Get the quantization config without touching the caller's ModelConfig.

    Args:
        model_config: Model configuration; a deep copy is handed to the
            private implementation.
        load_config: Load configuration, passed through as-is.

    Returns:
        Whatever `_get_quantization_config` returns for the copied config.
    """
    # For some reason, the _ version of this modifies the model_config
    # object, so it is given a deep copy to avoid this problem.
    isolated_model_config = copy.deepcopy(model_config)
    return VllmConfig._get_quantization_config(isolated_model_config, load_config)
def with_hf_config(
    self,
    hf_config: PretrainedConfig,
    architectures: list[str] | None = None,
) -> "VllmConfig":
    """Return a copy of this config whose model config uses `hf_config`.

    Args:
        hf_config: The HF config to install on a deep copy of the current
            model config.
        architectures: When given, `hf_config` is deep-copied and its
            `architectures` attribute is set to this list.

    Returns:
        A new VllmConfig produced by `replace` with the updated model config.
    """
    if architectures is not None:
        hf_config = copy.deepcopy(hf_config)
        hf_config.architectures = architectures

    new_model_config = copy.deepcopy(self.model_config)

    # In Transformers v5, tie_word_embeddings belongs to the config of the class
    # that can see both layers to be tied. For example:
    #
    # SomeVLModel:
    #     self.language_model = SomeLanguageModel()
    #     self.vision_model = SomeVisionModel()
    #
    # SomeVLModelForMultimodalLM:
    #     self.model = SomeVLModel()
    #     self.lm_head = nn.Linear()
    #
    # Therefore, tie_word_embeddings is defined in SomeVLModelForMultimodalLM's
    # config and is not present in SomeVLModel's config. In vLLM, the lm_head
    # belongs to the language_model, so we must ensure that tie_word_embeddings
    # is set in the language_model's config.
    must_propagate_tie_flag = (
        new_model_config.is_multimodal_model
        and hasattr(new_model_config.hf_config, "tie_word_embeddings")
        and not hasattr(hf_config.get_text_config(), "tie_word_embeddings")
    )
    if must_propagate_tie_flag:
        inherited_flag = new_model_config.hf_config.tie_word_embeddings
        # NOTE(review): when `architectures` is None, `hf_config` is the
        # caller's object, so this writes onto the caller's text config.
        hf_config.get_text_config().tie_word_embeddings = inherited_flag

    new_model_config.hf_config = hf_config
    # Recomputed after the new HF config has been installed.
    new_model_config.model_arch_config = new_model_config.get_model_arch_config()

    return replace(self, model_config=new_model_config)
def _set_config_default(self, config_obj: Any, key: str, value: Any) -> None:
"""Set config attribute to default if not already set by user.
Args:
config_obj: Configuration object to update.
key: Attribute name.
value: Default value (static or callable).
"""
if getattr(config_obj, key) is None:
# Some config values are known before initialization and are
# hard coded.
# Other values depend on the user given configuration, so they are
# implemented with lambda functions and decided at run time.
setattr(config_obj, key, value(self) if callable(value) else value)
def _apply_optimization_level_defaults(self, defaults: dict[str, Any]) -> None:
"""Apply optimization level defaults using self as root.
Recursively applies values from defaults into nested config objects.
Only fields present in defaults are overwritten.
If the user configuration does not specify a value for a default field
and if the default field is still None after all user selections are
applied, then default values will be applied to the field. User specified
fields will not be overridden by the default.
Args:
defaults: Dictionary of default values to apply.
"""
def apply_recursive(config_obj: Any, config_defaults: dict[str, Any]) -> None:
"""Recursively apply defaults to config_obj, using self as root."""
for key, value in config_defaults.items():
if not hasattr(config_obj, key):
continue
current = getattr(config_obj, key)
if isinstance(value, dict) and is_dataclass(current):
apply_recursive(current, value)
else:
self._set_config_default(config_obj, key, value)
apply_recursive(self, defaults)
def _post_init_kv_transfer_config(self) -> None:
"""Update KVTransferConfig based on top-level configs in VllmConfig.
Right now, this function reads the offloading settings from
CacheConfig and configures the KVTransferConfig accordingly.
"""
# KV offloading is only activated when kv_offloading_size is set.
if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
return
kv_offloading_backend = self.cache_config.kv_offloading_backend
# If no KVTransferConfig is provided, create a default one.
if self.kv_transfer_config is None:
self.kv_transfer_config = KVTransferConfig()
num_kv_ranks = (
self.parallel_config.tensor_parallel_size
* self.parallel_config.pipeline_parallel_size
)
if kv_offloading_backend == "native":
self.kv_transfer_config.kv_connector = "OffloadingConnector"
self.kv_transfer_config.kv_connector_extra_config.update(
{"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
)
elif kv_offloading_backend == "lmcache":
self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
kv_gb_per_rank = kv_offloading_size / num_kv_ranks
self.kv_transfer_config.kv_connector_extra_config = {
"lmcache.local_cpu": True,
"lmcache.max_local_cpu_size": kv_gb_per_rank,
}
# This is the same for all backends
self.kv_transfer_config.kv_role = "kv_both"
def __post_init__(self):
"""Verify configs are valid & consistent with each other."""
# To give each torch profile run a unique instance name.
self.instance_id = f"{time.time_ns()}"
if self.performance_mode != "balanced":
logger.info_once(
"Performance mode set to '%s'.", self.performance_mode, scope="local"
)
self.try_verify_and_update_config()
if self.model_config is not None:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.model_config.verify_dual_chunk_attention_config(self.load_config)
self.parallel_config.is_moe_model = self.model_config.is_moe
if self.lora_config is not None:
self.lora_config.verify_with_model_config(self.model_config)
if self.quant_config is None and self.model_config is not None:
self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config
)
executor_backend = self.parallel_config.distributed_executor_backend
executor_supports_async_sched = executor_backend in (
"mp",
"uni",
"external_launcher",
)
if self.scheduler_config.async_scheduling:
# Async scheduling explicitly enabled, hard fail any incompatibilities.
# Currently, async scheduling only support eagle speculative
# decoding.
if self.speculative_config is not None:
if (
self.speculative_config.method not in get_args(EagleModelTypes)
and self.speculative_config.method not in get_args(NgramGPUTypes)
and self.speculative_config.method != "draft_model"
):
raise ValueError(
"Currently, async scheduling is only supported "
"with EAGLE/MTP/Draft Model/NGram GPU kind of "
"speculative decoding"
)
if self.speculative_config.disable_padded_drafter_batch:
raise ValueError(
"Async scheduling is not compatible with "
"disable_padded_drafter_batch=True."
)
if not executor_supports_async_sched:
raise ValueError(
"Currently, async scheduling only supports `mp`, `uni`, or "
"`external_launcher` distributed executor backend, but you chose "
f"`{executor_backend}`."
)
elif self.scheduler_config.async_scheduling is None:
# Enable async scheduling unless there is an incompatible option.
if (
self.speculative_config is not None
and self.speculative_config.method not in get_args(EagleModelTypes)
and self.speculative_config.method not in get_args(NgramGPUTypes)
):
logger.warning_once(
"Async scheduling not supported with %s-based "
"speculative decoding and will be disabled.",
self.speculative_config.method,
scope="local",
)
self.scheduler_config.async_scheduling = False
elif (
self.speculative_config is not None
and self.speculative_config.disable_padded_drafter_batch
):
logger.warning_once(
"Async scheduling is not compatible with "
"disable_padded_drafter_batch=True and will be disabled.",
scope="local",
)
self.scheduler_config.async_scheduling = False
elif not executor_supports_async_sched:
logger.warning_once(
"Async scheduling will be disabled because it is not supported "
"with the `%s` distributed executor backend (only `mp`, `uni`, and "
"`external_launcher` are supported).",
executor_backend,
scope="local",
)
self.scheduler_config.async_scheduling = False
else:
self.scheduler_config.async_scheduling = True
logger.info_once(
"Asynchronous scheduling is %s.",
"enabled" if self.scheduler_config.async_scheduling else "disabled",
)
if self.parallel_config.disable_nccl_for_dp_synchronization is None:
if self.scheduler_config.async_scheduling:
if self.parallel_config.data_parallel_size > 1 and (
self.model_config is None or self.model_config.is_moe
):
logger.info_once(
"Disabling NCCL for DP synchronization "
"when using async scheduling.",
scope="local",
)
self.parallel_config.disable_nccl_for_dp_synchronization = True
else:
self.parallel_config.disable_nccl_for_dp_synchronization = False
from vllm.platforms import current_platform
if (
self.model_config is not None
and self.scheduler_config.enable_chunked_prefill
and self.model_config.dtype == torch.float32
and current_platform.get_device_capability() == (7, 5)
):
logger.warning_once(
"Turing devices tensor cores do not support float32 matmul. "
"To workaround this limitation, vLLM will set 'ieee' input "
"precision for chunked prefill triton kernels."
)
if self.model_config is not None and self.model_config.enforce_eager:
logger.warning(
"Enforce eager set, disabling torch.compile and CUDAGraphs. "
"This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none"
)
self.compilation_config.mode = CompilationMode.NONE
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
if self.compilation_config.backend == "eager" or (
self.compilation_config.mode is not None
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
):
logger.warning(
"Inductor compilation was disabled by user settings, "
"optimizations settings that are only active during "
"inductor compilation will be ignored."
)
def has_blocked_weights():
if self.quant_config is not None:
if hasattr(self.quant_config, "weight_block_size"):
return self.quant_config.weight_block_size is not None
elif hasattr(self.quant_config, "has_blocked_weights"):
return self.quant_config.has_blocked_weights()
return False
# Enable quant_fp8 CUDA ops (TODO disable in follow up)
# On H100 the CUDA kernel is faster than
# native implementation
# https://github.com/vllm-project/vllm/issues/25094
if has_blocked_weights():
custom_ops = self.compilation_config.custom_ops
if "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8")
current_platform.apply_config_platform_defaults(self)
if self.compilation_config.mode is None:
if self.optimization_level > OptimizationLevel.O0:
self.compilation_config.mode = CompilationMode.VLLM_COMPILE
else:
self.compilation_config.mode = CompilationMode.NONE
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
if (
self.compilation_config.backend == "inductor"
and self.compilation_config.mode != CompilationMode.NONE
):
self.compilation_config.custom_ops.append("none")
else:
self.compilation_config.custom_ops.append("all")
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
self._apply_optimization_level_defaults(default_config)
if self.kernel_config.enable_flashinfer_autotune is None:
raise ValueError(
"KernelConfig.enable_flashinfer_autotune must be set after applying "
"optimization level defaults."
)
if (
self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
):
logger.info(
"Cudagraph mode %s is not compatible with compilation mode %s."
"Overriding to NONE.",
self.compilation_config.cudagraph_mode,
self.compilation_config.mode,
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if self.compilation_config.pass_config.fuse_gemm_comms:
self.compilation_config.pass_config.enable_sp = True
if self.compilation_config.pass_config.enable_sp:
if self.parallel_config.tensor_parallel_size == 1:
logger.warning("Sequence Parallelism requires TP>1, disabling")
self.compilation_config.pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False
else:
# Compute SP threshold early; disable if None (model too
# small for SP to be beneficial).
pass_config = self.compilation_config.pass_config
if pass_config.sp_min_token_num is None:
from vllm.compilation.passes.fusion.sequence_parallelism import (
get_sequence_parallelism_threshold,
)
tp_size = self.parallel_config.tensor_parallel_size
hidden_size = self.model_config.get_hidden_size()
element_size = self.model_config.dtype.itemsize
pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
hidden_size, tp_size, element_size
)
if pass_config.sp_min_token_num is None:
logger.warning(
"Model hidden_size too small for the SP "
"threshold heuristic, disabling. To force SP, "
"set pass_config.sp_min_token_num manually."
)
self.compilation_config.pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False
from vllm.utils.torch_utils import HAS_OPAQUE_TYPE
if HAS_OPAQUE_TYPE:
# On torch >= 2.11 the hoisted OpaqueObject approach supersedes
# fast_moe_cold_start, so force it off.
self.compilation_config.fast_moe_cold_start = False
elif self.compilation_config.fast_moe_cold_start is None:
# resolve default behavior: try to be as safe as possible
# this config is unsafe if any spec decoding draft model has a MOE.
# We'll conservatively turn it off if we see spec decoding.
self.compilation_config.fast_moe_cold_start = (
self.speculative_config is None
)
self._set_max_num_scheduled_tokens()
if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support
if model_config := self.model_config:
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and model_config.pooler_config is not None
):
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif (
model_config.is_encoder_decoder
and self.compilation_config.cudagraph_mode
not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
):
logger.info_once(
"Encoder-decoder models do not support %s. "
"Overriding cudagraph_mode to FULL_DECODE_ONLY.",
self.compilation_config.cudagraph_mode.name,
)
self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY
)
# Check if KV connector requires PIECEWISE mode for CUDA graphs
if (
self.kv_transfer_config is not None
and self.kv_transfer_config.is_kv_transfer_instance
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
# Lazy import to avoid circular dependencies
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory,
)
connector_cls = KVConnectorFactory.get_connector_class(
self.kv_transfer_config
)
if connector_cls.requires_piecewise_for_cudagraph(
self.kv_transfer_config.kv_connector_extra_config
):
logger.warning_once(
"KV connector %s requires PIECEWISE CUDA graph mode "
"due to layerwise async operations that cannot be "
"captured in CUDA graphs. "
"Overriding cudagraph_mode from %s to PIECEWISE.",
connector_cls.__name__,
self.compilation_config.cudagraph_mode.name,
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
logger.info("Cudagraph is disabled under eager mode")
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# override related settings when enforce eager
self.compilation_config.max_cudagraph_capture_size = 0
self.compilation_config.cudagraph_capture_sizes = []
else:
self.compilation_config.cudagraph_num_of_warmups = 1
self._set_cudagraph_sizes()
else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
if self.cache_config.kv_sharing_fast_prefill:
if (
self.speculative_config is not None
and self.speculative_config.use_eagle()
):
raise ValueError(
"Fast prefill optimization for KV sharing is not "
"compatible with EAGLE as EAGLE requires correct logits "
"for all tokens while fast prefill gives incorrect logits "
"for prompt tokens."
)
logger.warning_once(
"--kv-sharing-fast-prefill requires changes on model side for "
"correctness and to realize prefill savings."
)
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
self._set_compile_ranges()
if (
self.model_config
and self.model_config.architecture == "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "