Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,16 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
self.pp_size = config.mapping.pp_size
self.has_custom_lm_head = False

if config.mapping.enable_lm_head_tp_in_adp:
has_mtp = (config.spec_config is not None
and (config.spec_config.spec_dec_mode.is_mtp_one_model()
or config.spec_config.spec_dec_mode.is_mtp_eagle()))
if not has_mtp:
logger.warning(
"enable_lm_head_tp_in_adp is set but MTP speculative "
"decoding is not configured. This option only takes effect "
"with MTP enabled. The flag will have no effect.")

if config.mapping.enable_attention_dp and not config.mapping.enable_lm_head_tp_in_adp:
self.lm_head = LMHead(
vocab_size,
Expand Down
72 changes: 54 additions & 18 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,14 @@ def validate_attention_dp_config(self) -> 'AttentionDpConfig':


class CpConfig(StrictBaseModel):
"""
Configuration for context parallelism.
"""Configuration for context parallelism.

Available cp_types:
- ULYSSES: Default when cp_size=1 (no context parallelism). Not supported
by the attention layer when cp_size > 1.
- HELIX: Required for context parallelism (cp_size > 1). Designed for
models using Multi-head Latent Attention (MLA) such as DeepSeek-V3/R1.
Requires Blackwell GPUs.
"""
# TODO: given that multiple fields here are only used with specific cp_types, consider
# making this a Pydantic discriminated union.
Expand Down Expand Up @@ -2370,7 +2376,10 @@ class BaseLlmArgs(StrictBaseModel):
moe_cluster_parallel_size: Optional[int] = Field(
default=None,
description="The cluster parallel size for MoE model's expert weights.",
status="beta")
status="deprecated",
deprecated=
"moe_cluster_parallel_size is deprecated and will be removed in a future release."
)

moe_tensor_parallel_size: Optional[int] = Field(
default=None,
Expand All @@ -2381,13 +2390,15 @@ class BaseLlmArgs(StrictBaseModel):
description="The expert parallel size for MoE model's expert weights.")

enable_attention_dp: bool = Field(
default=False,
description="Enable attention data parallel.",
status="beta")
default=False, description="Enable attention data parallel.")

enable_lm_head_tp_in_adp: bool = Field(
default=False,
description="Enable LM head TP in attention dp.",
description=
"Enable tensor parallelism for the LM head when attention data parallelism (ADP) is enabled. "
"This reduces LM head latency in wide expert-parallel MoE scenarios by splitting the LM head "
"weight matrix across GPUs. Requires enable_attention_dp=True and MTP speculative decoding. "
"Supported for MoE models with MTP enabled (e.g. DeepSeek-V3/R1, GLM-4-MoE, ExaoneMoE, NemotronH).",
status="prototype")

pp_partition: Optional[List[int]] = Field(
Expand All @@ -2398,7 +2409,10 @@ class BaseLlmArgs(StrictBaseModel):

cp_config: Optional[CpConfig] = Field(
default=None,
description="Context parallel config.",
description=
"Context parallel config. Available cp_types: ULYSSES (default when cp_size=1) "
"and HELIX (required for context parallelism with cp_size > 1, designed for models using "
"Multi-head Latent Attention such as DeepSeek-V3/R1, requires Blackwell GPUs).",
status="prototype")

load_format: Literal['auto', 'dummy'] = Field(
Expand Down Expand Up @@ -2435,12 +2449,12 @@ class BaseLlmArgs(StrictBaseModel):
iter_stats_max_iterations: Optional[int] = Field(
default=None,
description="The maximum number of iterations for iter stats.",
status="prototype")
status="beta")

request_stats_max_iterations: Optional[int] = Field(
default=None,
description="The maximum number of iterations for request stats.",
status="prototype")
status="beta")

# A handful of options from PretrainedConfig
peft_cache_config: Optional[PeftCacheConfig] = Field(
Expand Down Expand Up @@ -2525,7 +2539,7 @@ class BaseLlmArgs(StrictBaseModel):
default=None,
description="Target URL to which OpenTelemetry traces will be sent.",
alias="otlp_traces_endpoint",
status="prototype")
status="beta")

backend: Optional[str] = Field(
default=None,
Expand All @@ -2537,12 +2551,12 @@ class BaseLlmArgs(StrictBaseModel):

return_perf_metrics: bool = Field(default=False,
description="Return perf metrics.",
status="prototype")
status="beta")

perf_metrics_max_requests: NonNegativeInt = Field(
default=0,
description=
"The maximum number of requests for perf metrics. Must also set return_perf_metrics to true to get perf metrics.",
"The maximum number of responses to retain perf metrics for. Must also set return_perf_metrics to true to get perf metrics.",
status="prototype")

orchestrator_type: Optional[Literal["rpc", "ray"]] = Field(
Expand Down Expand Up @@ -2629,6 +2643,16 @@ def normalize_optional_fields_to_defaults(self):

@model_validator(mode="after")
def validate_parallel_config(self):
if self.enable_lm_head_tp_in_adp and not self.enable_attention_dp:
logger.warning(
"enable_lm_head_tp_in_adp has no effect without enable_attention_dp=True."
)

if "moe_cluster_parallel_size" in self.model_fields_set:
logger.warning(
"moe_cluster_parallel_size is deprecated and will be removed in a future release."
)

if self.moe_cluster_parallel_size is None:
self.moe_cluster_parallel_size = -1

Expand Down Expand Up @@ -2765,7 +2789,10 @@ class TrtLlmArgs(BaseLlmArgs):
default=False,
description=
"Fail fast when attention window is too large to fit even a single sequence in the KV cache.",
status="prototype")
status="deprecated",
deprecated=
"fail_fast_on_attention_window_too_large is deprecated and will be removed in a future release."
)

# Once set, the model will reuse the build_cache
enable_build_cache: Union[BuildCacheConfig,
Expand Down Expand Up @@ -2812,6 +2839,14 @@ class TrtLlmArgs(BaseLlmArgs):
_convert_checkpoint_options: Dict[str,
Any] = PrivateAttr(default_factory=dict)

@model_validator(mode="after")
def warn_deprecated_fields(self):
    """Log a warning when a deprecated field appears in ``model_fields_set``.

    Currently only ``fail_fast_on_attention_window_too_large`` is checked.

    Returns:
        This same instance, unchanged.
    """
    provided_fields = self.model_fields_set
    # Guard clause: nothing to report unless the deprecated field is present.
    if "fail_fast_on_attention_window_too_large" not in provided_fields:
        return self

    logger.warning(
        "fail_fast_on_attention_window_too_large is deprecated and will be removed in a future release."
    )
    return self

@model_validator(mode="after")
def init_build_config(self):
"""
Expand Down Expand Up @@ -3231,13 +3266,13 @@ class TorchLlmArgs(BaseLlmArgs):
status="prototype")

torch_compile_config: Optional[TorchCompileConfig] = Field(
default=None, description="Torch compile config.", status="prototype")
default=None, description="Torch compile config.", status="beta")

enable_autotuner: bool = Field(
default=True,
description=
"Enable autotuner for all tunable ops. This flag is for debugging purposes only, and the performance may significantly degrade if set to false.",
status="prototype")
status="beta")

enable_layerwise_nvtx_marker: bool = Field(
default=False,
Expand Down Expand Up @@ -3297,12 +3332,13 @@ class TorchLlmArgs(BaseLlmArgs):
checkpoint_format: Optional[str] = Field(
default=None,
description=
"The format of the provided checkpoint. You may use a custom checkpoint format by subclassing "
"The format of the provided checkpoint. Available formats: 'HF', 'mistral', 'mistral_large_3'. "
"You may use a custom checkpoint format by subclassing "
"`BaseCheckpointLoader` and registering it with `register_checkpoint_loader`.\n"
"If neither checkpoint_format nor checkpoint_loader are provided, checkpoint_format will be set to HF "
"and the default HfCheckpointLoader will be used.\n"
"If checkpoint_format and checkpoint_loader are both provided, checkpoint_loader will be ignored.",
status="prototype",
status="beta",
)

kv_connector_config: Optional[KvCacheConnectorConfig] = Field(
Expand Down
2 changes: 0 additions & 2 deletions tensorrt_llm/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class CpType(StrEnum):
ULYSSES = "ULYSSES"
# CP type for star attention
STAR = "STAR"
# CP type for ring attention
RING = "RING"
# CP type for helix parallelism
HELIX = "HELIX"

Expand Down
20 changes: 8 additions & 12 deletions tests/unittest/api_stability/references/llm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@ methods:
moe_cluster_parallel_size:
annotation: Optional[int]
default: null
status: beta
enable_attention_dp:
annotation: bool
default: False
status: beta
status: deprecated
cp_config:
annotation: Optional[tensorrt_llm.llmapi.llm_args.CpConfig]
default: null
Expand All @@ -26,15 +22,15 @@ methods:
iter_stats_max_iterations:
annotation: Optional[int]
default: null
status: prototype
status: beta
request_stats_max_iterations:
annotation: Optional[int]
default: null
status: prototype
status: beta
return_perf_metrics:
annotation: bool
default: False
status: prototype
status: beta
# Bindings and mirrored configs
peft_cache_config:
annotation: Optional[tensorrt_llm.llmapi.llm_args.PeftCacheConfig]
Expand Down Expand Up @@ -94,7 +90,7 @@ methods:
checkpoint_format:
annotation: Optional[str]
default: null
status: prototype
status: beta
mm_encoder_only:
annotation: bool
default: False
Expand Down Expand Up @@ -158,11 +154,11 @@ methods:
torch_compile_config:
annotation: Optional[tensorrt_llm.llmapi.llm_args.TorchCompileConfig]
default: null
status: prototype
status: beta
enable_autotuner:
annotation: bool
default: True
status: prototype
status: beta
enable_layerwise_nvtx_marker:
annotation: bool
default: False
Expand Down Expand Up @@ -202,7 +198,7 @@ methods:
otlp_traces_endpoint:
annotation: Optional[str]
default: null
status: prototype
status: beta
ray_worker_extension_cls:
annotation: Optional[str]
default: null
Expand Down
3 changes: 3 additions & 0 deletions tests/unittest/api_stability/references_committed/llm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ methods:
moe_expert_parallel_size:
annotation: Optional[int]
default: null
enable_attention_dp:
annotation: bool
default: False
# LoRA
enable_lora:
annotation: bool
Expand Down
Loading