Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
MoEAllReduce, MoEAllReduceParams, allgather,
cp_allgather)
MoEAllReduce, MoEAllReduceParams, allgather)
from ..model_config import ModelConfig
from ..modules.attention import MLA
from ..modules.attention import (MLA, maybe_allgather_for_helix_cp,
maybe_slice_for_helix_cp)
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, MoE,
Expand Down Expand Up @@ -1177,6 +1177,8 @@ def __init__(self,
mapping_with_cp: Optional[Mapping] = None):
super().__init__()
self.model_config = model_config
self.layer_idx = layer_idx
self.mapping_with_cp = mapping_with_cp
self.config = model_config.pretrained_config
config = self.config

Expand Down Expand Up @@ -1300,7 +1302,6 @@ def __init__(self,
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.layer_idx = layer_idx
self.next_layer_layernorm: RMSNorm = None

def _get_decoder_layer_quant_config(
Expand Down Expand Up @@ -1369,15 +1370,17 @@ def forward(
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, residual = self.self_attn(
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
residual=residual,
**kwargs,
)
residual = maybe_slice_for_helix_cp(residual, attn_metadata,
self.mapping_with_cp,
self.layer_idx)
if isinstance(self.mlp, Deepseekv3MoE):
if spec_metadata is not None and spec_metadata.is_layer_capture(
self.layer_idx):
Expand Down Expand Up @@ -1637,13 +1640,12 @@ def norm_hidden():
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, residual = self.self_attn(
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
residual=residual,
**kwargs,
)

Expand Down Expand Up @@ -1757,17 +1759,9 @@ def forward(
spec_metadata=spec_metadata,
)

# With CP helix, the last layer's reduce-scatter leaves each rank
# with only its chunk of tokens. AllGather restores the full token
# count so the LM head (and norm) see every token.
if (self.mapping_with_cp is not None
and self.mapping_with_cp.has_cp_helix()
and self.mapping_with_cp.enable_attention_dp):
hidden_states = cp_allgather(hidden_states,
self.mapping_with_cp,
dim=0)
hidden_states = hidden_states[:attn_metadata.num_tokens]

hidden_states = maybe_allgather_for_helix_cp(hidden_states,
attn_metadata,
self.mapping_with_cp)
return hidden_states


Expand Down
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/models/modeling_qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..modules.attention import (maybe_allgather_for_helix_cp,
maybe_slice_for_helix_cp)
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
Expand Down Expand Up @@ -95,6 +97,7 @@ def __init__(
self.layer_idx = layer_idx
config = model_config.pretrained_config
self.mapping = model_config.mapping
self.mapping_with_cp = mapping_with_cp
self.enable_attention_dp = self.mapping.enable_attention_dp

# When enable_attention_dp is True, TP reduction is skipped since each DP rank
Expand Down Expand Up @@ -168,6 +171,9 @@ def forward(
mrope_config=mrope_config,
**kwargs,
)
residual = maybe_slice_for_helix_cp(residual, attn_metadata,
self.mapping_with_cp,
self.layer_idx)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
Expand Down Expand Up @@ -198,6 +204,7 @@ def __init__(self,
mapping_with_cp: Optional[Mapping] = None):
super().__init__(model_config)
config = self.model_config
self.mapping_with_cp = mapping_with_cp

self.embed_tokens = Embedding(
config.pretrained_config.vocab_size,
Expand Down Expand Up @@ -256,6 +263,9 @@ def forward(
)

hidden_states, _ = self.norm(hidden_states, residual)
hidden_states = maybe_allgather_for_helix_cp(hidden_states,
attn_metadata,
self.mapping_with_cp)
return hidden_states


Expand Down
Loading