-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Expand file tree
/
Copy pathllm_args.py
More file actions
3782 lines (3203 loc) · 151 KB
/
llm_args.py
File metadata and controls
3782 lines (3203 loc) · 151 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import ast
import functools
import json
import math
import os
import types
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum, EnumMeta
from pathlib import Path
from typing import (TYPE_CHECKING, Annotated, Any, ClassVar, Dict, List,
Literal, Optional, Set, Tuple, Type, TypeAlias, TypeVar,
Union, get_args, get_origin)
import torch
import yaml
from pydantic import AliasChoices, BaseModel, ConfigDict
from pydantic import Field as PydanticField
from pydantic import (NonNegativeFloat, NonNegativeInt, PositiveInt,
PrivateAttr, field_validator, model_validator)
from strenum import StrEnum
from transformers import PreTrainedTokenizerBase
try:
from ray.util.placement_group import PlacementGroup
except ImportError:
PlacementGroup = None
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
from .._utils import _str_to_torch_dtype_dict, mpi_rank, prefer_pinned
# yapf: disable
# isort: off
from ..bindings.executor import (BatchingType as _BatchingType,
CacheTransceiverBackendType as _CacheTransceiverBackendType,
CacheTransceiverConfig as _CacheTransceiverConfig,
CapacitySchedulerPolicy as _CapacitySchedulerPolicy,
ContextChunkingPolicy as _ContextChunkingPolicy,
DecodingConfig,
DecodingMode,
DynamicBatchConfig as _DynamicBatchConfig,
EagleConfig as _EagleConfig,
ExecutorConfig as _ExecutorConfig,
ExtendedRuntimePerfKnobConfig as _ExtendedRuntimePerfKnobConfig,
KvCacheConfig as _KvCacheConfig,
LookaheadDecodingConfig as _LookaheadDecodingConfig,
PeftCacheConfig as _PeftCacheConfig,
SchedulerConfig as _SchedulerConfig) # isort: skip
# isort: on
# yapf: enable
from ..builder import BuildConfig, EngineConfig
from ..logger import logger
from ..mapping import CpType, Mapping
from ..models.automodel import AutoConfig
from ..models.modeling_utils import (PretrainedConfig, QuantAlgo, QuantConfig,
SpeculativeDecodingMode)
from ..sampling_params import BatchedLogitsProcessor
from .build_cache import BuildCacheConfig
from .tokenizer import TokenizerBase, tokenizer_factory
from .utils import (StrictBaseModel, generate_api_docs_as_docstring,
get_type_repr)
# PEP 484 requires a TypeVar's string name to match the variable it is bound
# to; it was previously "T", which confuses type checkers and reprs.
TypeBaseModel = TypeVar("TypeBaseModel", bound=BaseModel)
if TYPE_CHECKING:
    from tensorrt_llm._torch.virtual_memory import \
        RestoreMode as _VirtualMemoryRestoreMode
else:
    # At runtime the real enum is not imported; plain Enum serves as a
    # stand-in so annotations referencing _VirtualMemoryRestoreMode resolve.
    _VirtualMemoryRestoreMode = Enum
def Field(default: Any = ...,
          *,
          status: Optional[Literal["prototype", "beta", "deprecated"]] = None,
          **kwargs: Any) -> Any:
    """Custom Field wrapper that adds status to json_schema_extra.

    Args:
        default: The default value for the field
        status: Optional status indicator that gets added to json_schema_extra.

            - None: Stable.
            - "beta": Recommended for use per the latest documentation.
            - "prototype": Not yet stable and subject to breaking changes; intended for experimentation only.

        **kwargs: All other arguments passed to the original Pydantic Field

    Returns:
        A Pydantic FieldInfo object with the status added to json_schema_extra if provided
    """
    if status is not None:
        json_schema_extra = kwargs.get('json_schema_extra', {})
        if isinstance(json_schema_extra, dict):
            # Build a new dict instead of mutating the caller's in place:
            # the caller may share one json_schema_extra dict across fields.
            json_schema_extra = {**json_schema_extra, 'status': status}
        else:
            # If json_schema_extra is not a dict (e.g. a callable), replace it
            # with a new dict carrying only the status.
            json_schema_extra = {'status': status}
        kwargs['json_schema_extra'] = json_schema_extra
    return PydanticField(default, **kwargs)
class CudaGraphConfig(StrictBaseModel):
    """
    Configuration for CUDA graphs.
    """
    # List of batch sizes to create CUDA graphs for.
    batch_sizes: Optional[List[int]] = Field(
        default=None,
        description="List of batch sizes to create CUDA graphs for.")
    # 0 is a sentinel meaning "unset"; the validator below derives it from
    # batch_sizes or falls back to a default of 128.
    max_batch_size: NonNegativeInt = Field(
        default=0, description="Maximum batch size for CUDA graphs.")
    enable_padding: bool = Field(
        default=False,
        description=
        "If true, batches are rounded up to the nearest cuda_graph_batch_size. This is usually a net win for performance."
    )

    @model_validator(mode='after')
    def validate_cuda_graph_config(self) -> 'CudaGraphConfig':
        """Validate CUDA graph configuration.

        Ensures that:
        1. If batch_sizes is provided, max_batch_size is derived as max(batch_sizes).
           If max_batch_size was already set it must be compatible (equal to max(batch_sizes));
           otherwise an error is raised.
        2. If only max_batch_size is provided, batch_sizes is generated from it.
        3. If neither is provided, a default max_batch_size of 128 is used.
        """
        if self.batch_sizes:
            self.batch_sizes = sorted(self.batch_sizes)
            # After sorting, max(batch_sizes) == batch_sizes[-1].
            derived_max = max(self.batch_sizes)
            if self.max_batch_size == 0:
                self.max_batch_size = derived_max
            elif self.max_batch_size != derived_max:
                raise ValueError(
                    "CudaGraphConfig.max_batch_size is incompatible with "
                    "CudaGraphConfig.batch_sizes. When both are provided, "
                    "max_batch_size must equal max(batch_sizes).\n"
                    f"CudaGraphConfig.batch_sizes: {self.batch_sizes}, "
                    f"max(batch_sizes): {derived_max}, "
                    f"CudaGraphConfig.max_batch_size: {self.max_batch_size}")
        else:
            max_batch_size = self.max_batch_size or 128
            generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
                max_batch_size, self.enable_padding)
            self.batch_sizes = generated_sizes
            self.max_batch_size = max_batch_size
        return self

    @staticmethod
    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
                                         enable_padding: bool) -> List[int]:
        """Generate a list of batch sizes for CUDA graphs.

        Args:
            max_batch_size: Maximum batch size to generate up to
            enable_padding: Whether padding is enabled, which affects the batch size distribution

        Returns:
            List of batch sizes to create CUDA graphs for
        """
        if enable_padding:
            # With padding, batches are rounded up, so coarser buckets
            # (multiples of 8 up to 128) suffice.
            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
        else:
            batch_sizes = list(range(1, 32)) + [32, 64, 128]
        # Add powers of 2 from 256 up to max_batch_size.
        batch_sizes += [
            2**i for i in range(8, math.ceil(math.log(max_batch_size, 2)))
        ]
        # Filter and sort batch sizes
        batch_sizes = sorted(
            [size for size in batch_sizes if size <= max_batch_size])
        # Add max_batch_size if not already included
        if max_batch_size != batch_sizes[-1]:
            batch_sizes.append(max_batch_size)
        return batch_sizes

    @staticmethod
    def _merge_schedule_keys(batch_sizes: List[int],
                             schedule: dict[int, int]) -> List[int]:
        """Merge draft_len_schedule keys into batch_sizes so that each
        schedule threshold has a corresponding CUDA graph.
        e.g. draft_len_schedule={100:4, 200:3, 300:2} adds 100, 200, 300
        into batch_sizes.

        Args:
            batch_sizes: Sorted list of existing CUDA graph batch sizes.
            schedule: draft_len_schedule mapping batch-size thresholds to
                draft lengths.

        Returns:
            Sorted, deduplicated list of batch sizes.
        """
        # Thresholds above the largest graph size would never be hit; drop them.
        max_bs = batch_sizes[-1]
        extra = sorted(bs for bs in schedule if bs <= max_bs)
        if not extra:
            return batch_sizes
        # Standard two-pointer merge of the two sorted lists, skipping values
        # present in both so cross-list duplicates are not emitted.
        merged = []
        i, j = 0, 0
        while i < len(batch_sizes) and j < len(extra):
            if batch_sizes[i] < extra[j]:
                merged.append(batch_sizes[i])
                i += 1
            elif batch_sizes[i] > extra[j]:
                merged.append(extra[j])
                j += 1
            else:
                merged.append(batch_sizes[i])
                i += 1
                j += 1
        merged.extend(batch_sizes[i:])
        merged.extend(extra[j:])
        return merged
class GuidedDecodingConfig(StrictBaseModel):
    """Configuration for guided decoding (grammar-constrained generation)."""

    class GuidedDecodingBackend(Enum):
        # Supported grammar engines for constraining token generation.
        XGRAMMAR = 0
        LLGUIDANCE = 1

    backend: GuidedDecodingBackend = Field(
        default=GuidedDecodingBackend.XGRAMMAR,
        description="The backend for guided decoding config.")
    # NOTE(review): encoded_vocab/tokenizer_str presumably let the backend be
    # built without the live tokenizer object — confirm against callers.
    encoded_vocab: Optional[List[str]] = Field(
        default=None,
        description="The encoded vocab for guided decoding config.")
    tokenizer_str: Optional[str] = Field(
        default=None,
        description="The tokenizer string for guided decoding config.")
    stop_token_ids: Optional[List[int]] = Field(
        default=None,
        description="The stop token ids for guided decoding config.")
class BaseSparseAttentionConfig(StrictBaseModel):
    """
    Configuration for sparse attention.

    Base class for algorithm-specific configs; subclasses narrow `algorithm`
    to a Literal tag and override the hook methods below as needed.
    """
    # Discriminator naming the sparse-attention algorithm (e.g. "rocket", "dsa").
    algorithm: str
    seq_len_threshold: Optional[int] = Field(
        default=None,
        description=
        "The sequence length threshold for separating short and long sequences."
    )

    def supports_backend(self, backend: str) -> bool:
        """
        Override if the sparse attention algorithm does not support
        a subset of the possible backends.
        """
        return True

    def get_indices_block_size(self) -> int:
        # Block granularity of sparse indices; subclasses may return a larger
        # unit (e.g. RocketKV returns its page_size).
        return 1

    def needs_separate_short_long_cuda_graphs(self) -> bool:
        """
        Determines whether to capture a dedicated CUDA graph for batches consisting entirely of short sequences.
        If True, capture distinct graphs for short-only batches and general cases (e.g., long or mixed batches).
        If False, capture a single unified CUDA graph for all sequences regardless of length.
        The seq_len_threshold parameter defines the cutoff boundary between short and long sequences.
        """
        return False
class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
    """
    Configuration for RocketKV sparse attention.
    """
    algorithm: Literal["rocket"] = "rocket"
    window_size: Optional[int] = Field(
        default=32, description="The window size for RocketKV.")
    kernel_size: Optional[int] = Field(
        default=63, description="The kernel size for RocketKV.")
    topr: Optional[Union[int, float]] = Field(default=128, description="Top-r")
    topk: Optional[int] = Field(default=64, description="Top-k")
    prompt_budget: Optional[int] = Field(default=2048,
                                         description="Prompt budget")
    page_size: Optional[int] = Field(default=4, description="Page size")
    # Declared as a Literal so the allowed dtypes are actually validated by
    # Pydantic. The previous `choices=` kwarg is not a Pydantic Field
    # parameter, so the constraint was never enforced.
    kt_cache_dtype: Optional[Literal['bfloat16', 'float8_e5m2']] = Field(
        default='float8_e5m2',
        description="KT cache dtype",
    )

    def supports_backend(self, backend: str) -> bool:
        # RocketKV is only implemented for the PyTorch backend.
        return backend == "pytorch"

    def get_indices_block_size(self) -> int:
        # Sparse indices are tracked at page granularity.
        return self.page_size
class DeepSeekSparseAttentionConfig(BaseSparseAttentionConfig):
    """
    Configuration for DeepSeek Sparse Attention.
    """
    algorithm: Literal["dsa"] = "dsa"
    index_n_heads: Optional[int] = Field(
        default=None, description="The number of heads for the indexer.")
    index_head_dim: Optional[int] = Field(
        default=None, description="The dimension of the indexer heads.")
    index_topk: Optional[int] = Field(default=None,
                                      description="The topk for the indexer.")
    indexer_max_chunk_size: Optional[int] = Field(
        default=None, description="The maximum chunk size for the indexer.")
    skip_indexer_for_short_seqs: bool = Field(
        default=True,
        description=
        "Whether to skip the MQA and Top-K in the indexer for short sequences.")
    q_split_threshold: int = Field(
        default=8192,
        description=
        "If number of packed tokens in prefill chunk exceeds this threshold, \
q tokens will be evenly distributed across ranks for indexer computation. \
If negative, q split will always be disabled.")

    def supports_backend(self, backend: str) -> bool:
        # DSA is only implemented for the PyTorch backend.
        return backend == "pytorch"

    def needs_separate_short_long_cuda_graphs(self) -> bool:
        """
        Whether to capture separate CUDA graphs for short and long sequences.
        Use seq_len_threshold to determine the threshold for separating short and long sequences.
        """
        # NOTE(review): this query mutates state — it overwrites
        # seq_len_threshold with index_topk as a side effect, so any
        # user-provided seq_len_threshold is discarded here. Confirm intended.
        self.seq_len_threshold = self.index_topk
        return self.skip_indexer_for_short_seqs
class SkipSoftmaxAttentionConfig(BaseSparseAttentionConfig):
    """
    Configuration for skip softmax attention.
    """
    algorithm: Literal["skip_softmax"] = "skip_softmax"
    # Either one scalar applied to both phases, or a dict with per-phase
    # entries keyed by 'prefill' / 'decode'.
    threshold_scale_factor: Optional[Union[float, Dict[str, float]]] = Field(
        default=None,
        description="The threshold scale factor for skip softmax attention.")

    def supports_backend(self, backend: str) -> bool:
        # Skip-softmax attention is only implemented for the PyTorch backend.
        return backend == "pytorch"

    def _phase_scale_factor(self, phase: str) -> Optional[float]:
        """Resolve the threshold scale factor for the given phase.

        A dict value carries per-phase factors (missing key -> None); a
        scalar (or None) applies uniformly to both phases.
        """
        if isinstance(self.threshold_scale_factor, dict):
            return self.threshold_scale_factor.get(phase, None)
        return self.threshold_scale_factor

    @property
    def threshold_scale_factor_prefill(self) -> Optional[float]:
        return self._phase_scale_factor('prefill')

    @property
    def threshold_scale_factor_decode(self) -> Optional[float]:
        return self._phase_scale_factor('decode')
class MoeLoadBalancerConfig(StrictBaseModel):
    """
    Pydantic configuration model for the Mixture of Experts (MoE) load balancer.

    This model holds configuration data (`num_slots`, etc.) as well as
    runtime state (`_ep_rank`, `_ep_size`) which must be set via the
    `setup()` method before use.
    """
    # Total number of expert slots across all EP ranks; must be set and
    # divisible by ep_size (both checked in setup()).
    num_slots: Optional[int] = None
    # Per-layer initial slot assignments, keyed by layer index; each list must
    # have exactly num_slots entries (checked in get_layer_initial_global_assignments).
    initial_global_assignments: Optional[Dict[int, List[int]]] = Field(
        default=None,
        repr=False  # Exclude this large dict from model representation
    )
    layer_updates_per_iter: int = 0

    _ep_rank: Optional[int] = PrivateAttr(default=None)
    _ep_size: Optional[int] = PrivateAttr(default=None)

    # --- Methods ---
    def setup(self, ep_rank: int, ep_size: int) -> None:
        """
        Initializes the runtime state of the configuration.
        This must be called before accessing properties like `num_local_slots`.
        """
        self._ep_rank = ep_rank
        self._ep_size = ep_size
        # This assertion was in the original and is critical.
        if self.num_slots is None:
            raise ValueError("`num_slots` cannot be None when calling setup().")
        if self.num_slots % ep_size != 0:
            raise ValueError(
                f"`num_slots` ({self.num_slots}) must be divisible by `ep_size` ({ep_size})."
            )

    # --- Computed Properties ---
    # These properties depend on the runtime state set by setup()
    @property
    def ep_rank(self) -> int:
        """Public accessor for the private expert parallel rank."""
        if self._ep_rank is None:
            raise AttributeError("ep_rank is not set. Call setup() first.")
        return self._ep_rank

    @property
    def ep_size(self) -> int:
        """Public accessor for the private expert parallel size."""
        if self._ep_size is None:
            raise AttributeError("ep_size is not set. Call setup() first.")
        return self._ep_size

    @property
    def num_local_slots(self) -> int:
        """Calculates the number of slots local to this rank."""
        if self.num_slots is None or self._ep_size is None:
            raise ValueError(
                "Cannot calculate `num_local_slots`. "
                "`num_slots` must be set and setup() must be called.")
        return self.num_slots // self._ep_size

    @property
    def slot_start(self) -> int:
        """Calculates the starting global slot index for this rank."""
        if self._ep_rank is None:
            raise ValueError(
                "Cannot calculate `slot_start`. Call setup() first.")
        return self._ep_rank * self.num_local_slots

    @property
    def slot_end(self) -> int:
        """Calculates the ending global slot index (exclusive) for this rank."""
        return self.slot_start + self.num_local_slots

    def get_layer_initial_global_assignments(
            self, layer_idx: int) -> Optional[List[int]]:
        """
        Retrieves the initial global assignments for a specific layer.

        Returns None when no initial assignments were configured; raises
        KeyError for an unknown layer and ValueError on a length mismatch.
        """
        if self.initial_global_assignments is None:
            return None
        if layer_idx not in self.initial_global_assignments:
            raise KeyError(
                f"layer_idx {layer_idx} not found in `initial_global_assignments`."
            )
        assignments = self.initial_global_assignments[layer_idx]
        if self.num_slots is None:
            raise ValueError(
                "`num_slots` is not set, cannot verify assignment length.")
        if len(assignments) != self.num_slots:
            raise ValueError(
                f"Assignment length ({len(assignments)}) for layer {layer_idx} "
                f"does not match `num_slots` ({self.num_slots}).")
        return assignments
class MoeConfig(StrictBaseModel):
    """
    Configuration for MoE.
    """
    backend: Literal[
        "AUTO", "CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM", "VANILLA",
        "TRITON"] = Field(
            default='AUTO',
            description="MoE backend to use. "
            "AUTO selects default backend based on model. It currently doesn\'t always give the best choice for all scenarios. The capabilities of auto selection will be improved in future releases."
        )
    max_num_tokens: Optional[int] = Field(
        default=None,
        description=
        "If set, at most max_num_tokens tokens will be sent to torch.ops.trtllm.fused_moe at the same time. If the number of tokens exceeds max_num_tokens, the input tensors will be split into chunks and a for loop will be used."
    )
    # Typed loosely as Union[object, str]; per json_schema_extra the accepted
    # forms are a MoeLoadBalancerConfig, a dict, or a str.
    load_balancer: Optional[Union[object, str]] = Field(
        default=None,
        description="Configuration for MoE load balancing.",
        json_schema_extra={"type": "Union[MoeLoadBalancerConfig, dict, str]"})
    disable_finalize_fusion: bool = Field(
        default=False,
        description=
        "Disable FC2+finalize kernel fusion in CUTLASS MoE backend. Setting this to True recovers deterministic numerical behavior with top-k > 2."
    )
    use_low_precision_moe_combine: bool = Field(
        default=False,
        description=
        "Use low precision combine in MoE operations (only for NVFP4 quantization). When enabled, uses lower precision for combining expert outputs to improve performance."
    )
# Kernel backends selectable for NVFP4 GEMM execution.
Nvfp4Backend = Literal['cutlass', 'cublaslt', 'cutedsl', 'cuda_core']


class Nvfp4GemmConfig(StrictBaseModel):
    """
    Configuration for NVFP4 GEMM backend selection.
    """
    allowed_backends: List[Nvfp4Backend] = Field(
        default_factory=lambda: ['cutlass', 'cublaslt', 'cuda_core'],
        min_length=1,
        description="List of backends to consider for auto-selection. "
        "Default excludes 'cutedsl' for faster build time. "
        "Add 'cutedsl' for extreme performance at the cost of longer server launch time."
    )
class AttentionDpConfig(StrictBaseModel):
    """
    Configuration for attention DP.
    """
    enable_balance: bool = Field(default=False,
                                 description="Whether to enable balance.")
    timeout_iters: int = Field(
        default=50, description="The number of iterations to timeout.")
    batching_wait_iters: int = Field(
        default=10,
        description="The number of iterations to wait for batching.")

    @model_validator(mode='after')
    def validate_attention_dp_config(self) -> 'AttentionDpConfig':
        # The non-negativity checks only apply when balancing is enabled;
        # with enable_balance=False the iteration fields are unused.
        if self.enable_balance:
            if self.batching_wait_iters < 0:
                raise ValueError(
                    "attention_dp_config.batching_wait_iters must be greater or equal to 0 when enable_balance is true"
                )
            if self.timeout_iters < 0:
                raise ValueError(
                    "attention_dp_config.timeout_iters must be greater or equal to 0 when enable_balance is true"
                )
        return self
class CpConfig(StrictBaseModel):
    """
    Configuration for context parallelism.
    """
    # TODO: given that multiple fields here are only used with specific cp_types, consider
    # making this a Pydantic discriminated union.
    cp_type: CpType = Field(default=CpType.ULYSSES,
                            description="Context parallel type.")
    tokens_per_block: Optional[int] = Field(
        default=None,
        description="Number of tokens per block. Used in HELIX parallelism.")
    use_nccl_for_alltoall: Optional[bool] = Field(
        default=None,
        description=
        "Whether to use NCCL for alltoall communication. Used in HELIX parallelism. Defaults to True."
    )
    fifo_version: Optional[int] = Field(
        default=None,
        description=
        "FIFO version for alltoall communication. Used in HELIX parallelism. Defaults to 2."
    )
    cp_anchor_size: Optional[int] = Field(
        default=None, description="Anchor size for STAR attention.")
    block_size: Optional[int] = Field(
        default=None, description="Block size for STAR attention.")

    @field_validator("cp_type", mode="before")
    @classmethod
    def validate_cp_type(cls, v):
        """Normalize cp_type string to uppercase."""
        # Uppercasing lets users pass e.g. "ulysses"; the enum lookup that
        # follows this pre-validator expects uppercase member names.
        if v is None:
            return None
        if isinstance(v, str):
            return v.upper()
        return v
class _ParallelConfig(StrictBaseModel):
    """The model distribution configs for LLM."""
    tp_size: int = 1
    pp_size: int = 1
    cp_size: int = 1
    gpus_per_node: int = 8
    # Set default for MoE fields to -1 to trigger auto-calculation in Mapping
    moe_cluster_size: int = -1
    moe_tp_size: int = -1
    moe_ep_size: int = -1
    cp_config: Optional[CpConfig] = Field(default=None)
    pp_partition: Optional[List[int]] = Field(default=None)
    enable_attention_dp: bool = False
    enable_lm_head_tp_in_adp: bool = False
    # Explicit device ids; when unset, `devices` defaults to range(world_size).
    _devices: Optional[List[int]] = PrivateAttr(default=None)

    @property
    def devices(self) -> List[int]:
        if self._devices is None:
            return list(range(self.world_size))
        return self._devices

    @devices.setter
    def devices(self, devices: List[int]):
        if len(devices) != self.world_size:
            raise ValueError(
                f"devices {devices} should have the same length as world_size {self.world_size}"
            )
        self._devices = devices

    @property
    def world_size(self) -> int:
        # Derived quantity: fully determined by the parallelism degrees.
        return self.tp_size * self.pp_size * self.cp_size

    @property
    def world_size_per_node(self) -> int:
        world_size = self.world_size
        total_nodes = math.ceil(world_size / self.gpus_per_node)
        # NOTE(review): integer division is only exact when world_size divides
        # evenly across nodes; for uneven layouts this underestimates the
        # per-node count (see the original author's TODO below).
        return world_size // total_nodes  #TODO is this right?

    @world_size.setter
    def world_size(self, world_size: int):
        # NOTE(review): this setter only validates consistency with the
        # configured degrees; it intentionally stores nothing because
        # world_size is a derived property.
        if world_size != self.tp_size * self.pp_size * self.cp_size:
            raise ValueError(
                f"world_size {world_size} should be equal to tp_size * pp_size * cp_size {self.tp_size * self.pp_size * self.cp_size} "
            )

    @property
    def is_multi_gpu(self) -> bool:
        return self.world_size > 1

    def to_mapping(self) -> Mapping:
        """Build a runtime ``Mapping`` for the current MPI rank from this config."""
        return Mapping(
            world_size=self.world_size,
            rank=mpi_rank(),
            gpus_per_node=self.gpus_per_node,
            tp_size=self.tp_size,
            pp_size=self.pp_size,
            pp_partition=self.pp_partition,
            cp_size=self.cp_size,
            # TODO: Mapping still uses cp_config as a dict; migrate to CpConfig
            cp_config=self.cp_config.model_dump(
                exclude_none=True) if self.cp_config else {},
            enable_attention_dp=self.enable_attention_dp,
            enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp,
            moe_cluster_size=self.moe_cluster_size,
            moe_tp_size=self.moe_tp_size,
            moe_ep_size=self.moe_ep_size)
class CalibConfig(StrictBaseModel):
    """
    Calibration configuration.
    """
    device: Literal['cuda',
                    'cpu'] = Field(default='cuda',
                                   description="The device to run calibration.")
    calib_dataset: str = Field(
        default='cnn_dailymail',
        description="The name or local path of calibration dataset.")
    calib_batches: int = Field(
        default=512,
        description="The number of batches that the calibration runs.")
    calib_batch_size: int = Field(
        default=1, description="The batch size that the calibration runs.")
    calib_max_seq_length: int = Field(
        default=512,
        description="The maximum sequence length that the calibration runs.")
    random_seed: int = Field(
        default=1234, description="The random seed used for calibration.")
    tokenizer_max_seq_length: int = Field(
        default=2048,
        description=
        "The maximum sequence length to initialize tokenizer for calibration.")
class _ModelFormatKind(Enum):
    """Internal tag for the on-disk format of an input model."""
    HF = 0  # HuggingFace-style checkpoint
    TLLM_CKPT = 1  # TensorRT-LLM checkpoint
    TLLM_ENGINE = 2  # Prebuilt TensorRT-LLM engine
class DecodingBaseConfig(StrictBaseModel):
    """Common configuration base for speculative-decoding algorithms.

    Subclasses set ``decoding_type`` (used by ``spec_dec_mode``) and add
    algorithm-specific fields.
    """
    max_draft_len: Optional[NonNegativeInt] = Field(
        default=None, description="The maximum number of draft tokens.")
    max_total_draft_tokens: Optional[int] = Field(
        default=None,
        description=
        "The number of draft tokens in the draft tokens tree. If it's a linear tree, each draft layer will "
        "only generate one draft token. In this case, max_draft_len == max_total_draft_tokens. If it's a static or "
        "dynamic tree, each draft layer may generate more than one draft token. In this case, "
        "max_total_draft_tokens >= max_draft_len.")
    speculative_model: Optional[Union[str, Path]] = Field(
        default=None,
        validation_alias=AliasChoices("speculative_model",
                                      "speculative_model_dir"),
        description=
        "The speculative (draft) model. Accepts either (1) a HuggingFace Hub model ID (e.g. 'yuhuili/EAGLE3-LLaMA3.1-Instruct-8B'), "
        "which will be automatically downloaded, or (2) a local filesystem path to a downloaded model directory."
    )
    # Fix: the description previously said "Mutually exclusive with
    # max_concurrency" (i.e. with itself); the mutually exclusive field is
    # draft_len_schedule, as enforced by the model validator below.
    max_concurrency: Optional[PositiveInt] = Field(
        default=None,
        description=
        "When specified (>0), speculation will be disabled at batch sizes above this value. Otherwise, "
        "speculation will always be on. PyTorch backend only. "
        "Mutually exclusive with draft_len_schedule since draft_len_schedule implicitly supports max concurrency control."
    )
    draft_len_schedule: Optional[dict[int, int]] = Field(
        default=None,
        description=
        "Developer interface: dynamically adjust draft length based on active batch size in runtime."
        "Maps batch size to draft lengths."
        "For example: draft_len_schedule = {4:4, 8:2, 32:1}"
        " - Batch sizes 1-4: use draft_len=4"
        " - Batch sizes 5-8: use draft_len=2"
        " - Batch sizes 9-32: use draft_len=1"
        " - Batch sizes 33+: use draft_len=0 (implicit, speculation disabled). "
        "Mutually exclusive with max_concurrency since draft_len_schedule implicitly support max concurrency control."
    )
    load_format: Optional[str] = Field(
        default=None, description="The load format of the speculative model.")
    acceptance_window: Optional[NonNegativeInt] = Field(
        default=None,
        description=
        "The rolling average window size (N) for acceptance length across completed requests. "
        "If not set or set to 0, the feature is disabled. PyTorch backend only."
    )
    acceptance_length_threshold: Optional[NonNegativeFloat] = Field(
        default=None,
        description=
        "The threshold for average acceptance length; speculation will be disabled permanently once the "
        "rolling average over the last N completed requests (N = acceptance_window) drops below this value. "
        "PyTorch backend only.")
    allow_advanced_sampling: bool = Field(
        default=False,
        status="prototype",
        description=
        "If true, allows non-greedy sampling when speculation is used. Only applicable "
        "to 1-model code paths; non-greedy sampling is always enabled on 2-model paths."
    )

    # If set, drafting is allowed to use chain drafter.
    _allow_chain_drafter: bool = PrivateAttr(True)
    # If set, drafting uses greedy sampling, irrespective of sampling parameters.
    _allow_greedy_draft_tokens: bool = PrivateAttr(True)
    # Internal: record decoding_type alias used during parsing (for warnings).
    _decoding_type_alias: Optional[str] = PrivateAttr(default=None)
    # If set, drafting will use separate KV cache in one-model speculative decoding.
    _allow_separate_draft_kv_cache: bool = PrivateAttr(True)
    # Internal: true when draft_len_schedule was auto-translated from max_concurrency.
    _translated_from_max_concurrency: bool = PrivateAttr(False)

    @field_validator('draft_len_schedule')
    @classmethod
    def validate_draft_len_schedule_and_sort(cls, v, info):
        """Validate and sort draft_len_schedule by batch size thresholds."""
        if v is not None:
            # Validate values
            for batch_size, draft_len in v.items():
                if batch_size < 1:
                    raise ValueError(
                        f"draft_len_schedule: batch size threshold must be >= 1, got {batch_size}"
                    )
                if draft_len < 0:
                    raise ValueError(
                        f"draft_len_schedule: draft length must be >= 0, got {draft_len}"
                    )
            # Enforce smallest schedule key maps to max_draft_len for consistency.
            smallest_batch_size = min(v.keys())
            max_draft_len = info.data.get('max_draft_len')
            if max_draft_len is not None and v[
                    smallest_batch_size] != max_draft_len:
                raise ValueError(
                    f"draft_len_schedule[{smallest_batch_size}] must equal max_draft_len "
                    f"because it is the smallest batch-size key. "
                    f"Got schedule[{smallest_batch_size}]={v[smallest_batch_size]}, "
                    f"but max_draft_len={max_draft_len}.")
            # Enforce all draft lengths <= max_draft_len
            if max_draft_len is not None:
                for batch_size, draft_len in v.items():
                    if draft_len > max_draft_len:
                        raise ValueError(
                            f"draft_len_schedule: all draft lengths must be <= max_draft_len. "
                            f"Got draft_len={draft_len} for batch_size={batch_size}, "
                            f"but max_draft_len={max_draft_len}.")
            # Return sorted dict (by batch size thresholds)
            # This ensures efficient lookup
            return dict(sorted(v.items(), key=lambda x: x[0]))
        return v

    @model_validator(mode='after')
    # 1. Validate that max_concurrency and draft_len_schedule are mutually exclusive.
    # 2. If max_concurrency is set, translate it to the corresponding draft_len_schedule.
    def validate_max_concurrency_and_draft_len_schedule_mutually_exclusive(
            self) -> "DecodingBaseConfig":
        if self.max_concurrency is not None and self.draft_len_schedule is not None:
            # Avoid ValueError during nested re-validation when only max_concurrency is set and draft_len_schedule is translated from max_concurrency
            if self._translated_from_max_concurrency:
                return self
            raise ValueError(
                "max_concurrency and draft_len_schedule are mutually exclusive. "
                "Use max_concurrency for a simple speculation cutoff, or "
                "draft_len_schedule for dynamic draft-length control.")
        if self.max_concurrency is None:
            return self
        # Translation requires a known max_draft_len and a mode that supports
        # dynamic draft lengths; otherwise keep max_concurrency as-is.
        if (self.max_draft_len is None
                or not self.spec_dec_mode.support_dynamic_draft_len()):
            return self
        self.draft_len_schedule = {
            int(self.max_concurrency): int(self.max_draft_len)
        }
        self._translated_from_max_concurrency = True
        return self

    def supports_backend(self, backend: str) -> bool:
        """
        Override if the speculation algorithm does not support
        a subset of the possible backends.
        """
        return True

    @property
    def spec_dec_mode(self):
        # spec_dec_mode has more functionality than the raw decoding_mode string.
        # Use an alias for the import here to avoid name collisions with the one for the
        # TRT backend.
        from tensorrt_llm._torch.speculative.interface import \
            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
        return TorchSpeculativeDecodingMode.from_string(
            self.decoding_type.upper())

    @functools.cached_property
    def is_linear_tree(self) -> bool:
        # Linear tree <=> every draft layer produces exactly one token, so the
        # two counts coincide.
        return self.max_draft_len == self.max_total_draft_tokens

    @property
    def tokens_per_gen_step(self) -> int:
        """Total tokens per gen request in one spec dec iteration (including golden token)."""
        return 1 + self.max_total_draft_tokens

    def num_capture_layers(self) -> int:
        # Number of hidden-state capture layers required; 0 unless a subclass
        # overrides it.
        return 0
class KvCacheConnectorConfig(StrictBaseModel):
    """
    Configuration for the KV Cache Connector.

    All three fields are required: the module import path plus the names of
    the scheduler and worker classes defined inside that module.
    """
    # Dotted import path; resolved at runtime via importlib.import_module.
    connector_module: str = Field(
        ...,
        description=
        "The import path to the connector module. It will be imported with `importlib.import_module`."
    )
    # Scheduler-side class name, looked up within connector_module.
    connector_scheduler_class: str = Field(
        ..., description="The class name of the scheduler within the module.")
    # Worker-side class name, looked up within connector_module.
    connector_worker_class: str = Field(
        ..., description="The class name of the worker within the module.")
class LayerwiseBenchmarksConfig(StrictBaseModel):
    """
    Configuration for layer-wise benchmarks calibration.
    """
    # "NONE" disables calibration; "MARK" and "COLLECT" select the
    # calibrator's two active modes.
    calibration_mode: Literal["NONE", "MARK", "COLLECT"] = Field(
        default="NONE",
        description=
        "Instruct the layer-wise benchmarks calibrator to work on MARK mode, or COLLECT mode",
        status="prototype")
    # Must be non-empty when calibration_mode == "COLLECT" (enforced by the
    # validator below).
    calibration_file_path: Optional[str] = Field(
        default=None,
        description=
        "The file path which the layer-wise benchmarks calibrator saves to or loads from",
        status="prototype")
    calibration_layer_indices: Optional[List[int]] = Field(
        default=None,
        description=
        "Layer indices to filter. If None, all layers are collected in COLLECT mode.",
        status="prototype")
    @model_validator(mode='after')
    def validate_calibration_file_path(self) -> 'LayerwiseBenchmarksConfig':
        """Require a calibration_file_path whenever COLLECT mode is selected."""
        if self.calibration_mode == "COLLECT" and not self.calibration_file_path:
            raise ValueError(
                f"Expect calibration_file_path not to be empty when work on {self.calibration_mode} mode"
            )
        return self
class MedusaDecodingConfig(DecodingBaseConfig):
    """Speculative decoding configuration for Medusa."""
    decoding_type: Literal["Medusa"] = "Medusa"
    medusa_choices: Optional[List[List[int]]] = Field(
        default=None,
        description=
        "Tree structure for Medusa draft token generation. Each sublist represents a path in the tree where elements are token indices at each level. "
        "For example, [[0], [0, 0], [1], [0, 1]] defines multiple branches.")
    num_medusa_heads: Optional[int] = Field(
        default=None,
        description=
        "Number of Medusa prediction heads to use. Each head predicts a draft token at a different position in parallel. "
        "If not specified, defaults to the 'medusa_num_heads' value from the Medusa model's config.json."
    )
    @model_validator(mode="after")
    def set_max_total_draft_tokens(self):
        """Force a linear draft tree: total draft tokens == max_draft_len."""
        self.max_total_draft_tokens = self.max_draft_len  # Current Medusa only supports linear tree
        return self
    def supports_backend(self, backend: str) -> bool:
        # Medusa is not available on the PyTorch or AutoDeploy backends.
        return backend not in ("pytorch", "_autodeploy")
class EagleDecodingConfig(DecodingBaseConfig):
    """Speculative decoding configuration for EAGLE / EAGLE-2 / EAGLE-3."""
    decoding_type: Literal["Eagle"] = "Eagle"
    # Deprecated static tree (see validate_eagle_choices warning below);
    # mutually exclusive with use_dynamic_tree.
    eagle_choices: Optional[List[List[int]]] = Field(
        default=None,
        description=
        "Static tree structure for draft token generation. Each sublist represents a path in the tree. Mutually exclusive with use_dynamic_tree."
    )
    greedy_sampling: Optional[bool] = Field(
        default=True,
        description=
        "Whether to use greedy sampling (Top-1 with token equality acceptance) or typical acceptance with multinomial sampling."
    )
    posterior_threshold: Optional[float] = Field(
        default=None,
        description=
        "Minimum token probability threshold for typical acceptance. Corresponds to epsilon in https://arxiv.org/pdf/2401.10774."
    )
    # Eagle-2 dynamic tree; mutually exclusive with eagle_choices.
    use_dynamic_tree: Optional[bool] = Field(
        default=False,
        description=
        "Whether to use dynamic tree (Eagle-2 algorithm). Mutually exclusive with eagle_choices."
    )
    dynamic_tree_max_topK: Optional[int] = Field(
        default=None,
        description="The topK value for each layer when dynamic tree is enabled."
    )
    # TRT-flow-only knob, kept for config compatibility.
    num_eagle_layers: Optional[int] = Field(
        default=None,
        description=
        "The number of eagle layers. Will not be used in pytorch flow, just for compatibility with TRT flow."
    )
    max_non_leaves_per_layer: Optional[int] = Field(
        default=None, description="The number of non-leaves in each layer.")
    eagle3_one_model: Optional[bool] = Field(
        default=True,
        description=
        "Whether to use the faster one-model implementation (draft as submodule) or the two-model implementation."
    )
    eagle3_layers_to_capture: Optional[Set[int]] = Field(
        default=None,
        description=
        "Target model layer indices to capture hidden states from for the EAGLE3 draft model. Defaults to {1, num_layers//2-1, num_layers-4}."
    )
    eagle3_model_arch: Literal["llama3", "mistral_large3"] = Field(
        default="llama3",
        description="The model architecture of the eagle3 model.")
@field_validator('eagle_choices', mode='before')
@classmethod
def validate_eagle_choices(cls, v):
if v is not None:
logger.warning(
"The eagle_choices/static tree feature is deprecated and will be removed in release 1.4."
)
if not isinstance(v, list):
if isinstance(v, str):
v = ast.literal_eval(v.replace(" ", ""))
else:
raise ValueError(
"Wrong eagle choices type. Eagle choices should be a List[List[int]] or a string like [[0], [1], [2], [0, 0], [0, 1]]."
)
return v
@model_validator(mode='after')
def validate_eagle_config(self) -> 'EagleDecodingConfig':