Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions tensorrt_llm/_torch/models/modeling_qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import AllReduceParams
from ..distributed import AllReduceParams, cp_allgather
from ..model_config import ModelConfig
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
Expand Down Expand Up @@ -159,12 +159,13 @@ def forward(
hidden_states, residual)

# Self Attention
hidden_states = self.self_attn(
hidden_states, residual = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not self.disable_attn_allreduce),
residual=residual,
mrope_config=mrope_config,
**kwargs,
)
Expand Down Expand Up @@ -198,6 +199,7 @@ def __init__(self,
mapping_with_cp: Optional[Mapping] = None):
super().__init__(model_config)
config = self.model_config
self.mapping_with_cp = mapping_with_cp

self.embed_tokens = Embedding(
config.pretrained_config.vocab_size,
Expand Down Expand Up @@ -256,6 +258,15 @@ def forward(
)

hidden_states, _ = self.norm(hidden_states, residual)

if (self.mapping_with_cp is not None
and self.mapping_with_cp.has_cp_helix()
and self.mapping_with_cp.enable_attention_dp):
hidden_states = cp_allgather(hidden_states,
self.mapping_with_cp,
dim=0)
hidden_states = hidden_states[:attn_metadata.num_tokens]

return hidden_states


Expand Down
199 changes: 99 additions & 100 deletions tensorrt_llm/_torch/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,82 @@ def _helix_post_process(
gathered_o, gathered_stats, 1.0, 1)


def _helix_cp_pad(tensor: torch.Tensor, num_tokens: int,
cp_size: int) -> tuple[torch.Tensor, int]:
"""Pad tensor along dim-0 so its length is divisible by cp_size."""
chunk_size = math.ceil(num_tokens / cp_size)
padded_size = chunk_size * cp_size
if num_tokens < padded_size:
tensor = torch.nn.functional.pad(tensor,
(0, 0, 0, padded_size - num_tokens),
mode="constant",
value=0)
return tensor, chunk_size


def _helix_cp_slice(tensor: torch.Tensor, attn_metadata: AttentionMetadata,
                    mapping: Mapping) -> torch.Tensor:
    """Extract this CP rank's token chunk from a full-length tensor.

    The tensor is first padded so every rank owns an equally sized chunk,
    mirroring the layout a reduce-scatter would produce; the slice for
    ``mapping.cp_rank`` is then returned.
    """
    padded, per_rank = _helix_cp_pad(tensor, attn_metadata.num_tokens,
                                     mapping.cp_size)
    begin = per_rank * mapping.cp_rank
    end = begin + per_rank
    return padded[begin:end]


def _helix_cp_allgather_input(hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mapping: Mapping, layer_idx: int) -> torch.Tensor:
"""AllGather hidden states from CP group for layers after the first.

The first layer already has the full input from the embedding.
Subsequent layers need to undo the previous layer's reduce-scatter.
"""
if (mapping.has_cp_helix() and mapping.enable_attention_dp
and layer_idx > 0):
hidden_states = cp_allgather(hidden_states, mapping, dim=0)
hidden_states = hidden_states[:attn_metadata.num_tokens]
return hidden_states


def _helix_cp_output_projection(
o_proj: "Linear",
attn_output: torch.Tensor,
attn_metadata: AttentionMetadata,
all_reduce_params: Optional[AllReduceParams],
residual: Optional[torch.Tensor],
mapping: Mapping,
mapping_o: Mapping,
layer_idx: int,
lora_params: Optional[dict] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Apply output projection with reduce-scatter when Helix CP+DP is active.

Reduce-scatter sums partial sums across the CP group and scatters the
result so each CP rank processes a distinct token chunk through the MLP.
Falls back to the standard AllReduce path otherwise.
"""
if mapping.has_cp_helix() and mapping.enable_attention_dp:
attn_output = o_proj(
attn_output,
all_reduce_params=AllReduceParams(enable_allreduce=False),
lora_params=lora_params,
layer_idx=layer_idx)

attn_output, _ = _helix_cp_pad(attn_output, attn_metadata.num_tokens,
mapping.cp_size)
attn_output = reducescatter(attn_output, mapping_o, dim=0)

if (layer_idx == 0 and residual is not None and residual is not ...):
residual = _helix_cp_slice(residual, attn_metadata, mapping)
else:
attn_output = o_proj(attn_output,
all_reduce_params=all_reduce_params,
lora_params=lora_params,
layer_idx=layer_idx)

return attn_output, residual


class Attention(nn.Module):

def __init__(
Expand Down Expand Up @@ -397,6 +473,7 @@ def __init__(
gpus_per_node=self.mapping.gpus_per_node,
enable_attention_dp=self.mapping.enable_attention_dp,
)
self.mapping_o = mapping_o

self.o_proj = Linear(
tp_size * self.q_size,
Expand Down Expand Up @@ -718,8 +795,9 @@ def forward(
attention_window_size: Optional[int] = None,
attention_mask_data: Optional[torch.Tensor] = None,
attention_sinks: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""
Forward pass for the Attention module.

Expand All @@ -733,9 +811,16 @@ def forward(
lora_params (Optional[dict]): The LoRA parameters.
attention_window_size (Optional[int]): The attention window size.
attention_mask_data (Optional[torch.Tensor]): The attention mask data.
residual (Optional[torch.Tensor]): The residual tensor. When
provided, the output projection uses reduce-scatter (if
applicable) and the method returns (attn_output, residual).
Returns:
torch.Tensor: The output tensor.
torch.Tensor or (torch.Tensor, torch.Tensor): The output tensor,
or a tuple of (output, residual) when residual is provided.
"""
hidden_states = _helix_cp_allgather_input(hidden_states, attn_metadata,
self.mapping, self.layer_idx)

qkv = self.qkv_proj(hidden_states)

if bool(lora_params):
Expand Down Expand Up @@ -782,6 +867,13 @@ def forward(
gate = torch.sigmoid(gate)
attn_output = attn_output * gate

if residual is not None:
attn_output, residual = _helix_cp_output_projection(
self.o_proj, attn_output, attn_metadata, all_reduce_params,
residual, self.mapping, self.mapping_o, self.layer_idx,
lora_params)
return attn_output, residual

attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params,
lora_params=lora_params,
Expand Down Expand Up @@ -2470,98 +2562,6 @@ def forward_sparse_mla_kvcache_bf16(
f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")
return output

def _needs_cp_reduce_scatter(self) -> bool:
"""Check if we should use CP reduce-scatter instead of AllReduce."""
return (self.mapping.has_cp_helix()
and self.mapping.enable_attention_dp)

def _maybe_allgather_input(
self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
"""AllGather input hidden states from CP group if needed.

For the first layer (Embed -> Attn), all CP ranks already have the
full input, so this is a no-op. For subsequent layers, the previous
layer's reduce-scatter left each rank with a portion that must be
reconstructed before attention.
"""
if self._needs_cp_reduce_scatter() and self.layer_idx > 0:
hidden_states = cp_allgather(hidden_states, self.mapping, dim=0)
# Remove padding introduced by reduce-scatter alignment.
hidden_states = hidden_states[:attn_metadata.num_tokens]
return hidden_states

def _pad_for_cp(self, tensor: torch.Tensor,
num_tokens: int) -> tuple[torch.Tensor, int]:
"""Pad tensor along dim-0 so its length is divisible by cp_size.

Returns the (possibly padded) tensor and the per-rank chunk size.
"""
cp_size = self.mapping.cp_size
chunk_size = math.ceil(num_tokens / cp_size)
padded_size = chunk_size * cp_size

if num_tokens < padded_size:
tensor = torch.nn.functional.pad(
tensor, (0, 0, 0, padded_size - num_tokens),
mode="constant",
value=0)

return tensor, chunk_size

def _slice_for_cp(self, tensor: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
"""Slice a tensor to this CP rank's chunk, matching post-RS size.

Used for the first layer's residual: since there is no prior RS to
divide it, we manually extract this rank's portion so it aligns with
the reduce-scattered attention output.
"""
tensor, chunk_size = self._pad_for_cp(tensor, attn_metadata.num_tokens)
start = self.mapping.cp_rank * chunk_size
return tensor[start:start + chunk_size]

def _output_projection(
self,
attn_output: torch.Tensor,
attn_metadata: AttentionMetadata,
all_reduce_params: Optional[AllReduceParams],
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Apply output projection (o_proj) and reduce across parallel ranks.

With CP reduce-scatter, o_proj produces partial sums (each CP rank
contributes from its head partition). Reduce-scatter sums these
and divides the result among CP ranks for subsequent MoE processing.
Otherwise, o_proj uses the standard AllReduce path.

The residual is passed through unchanged unless this is the first
layer with CP reduce-scatter, in which case it is sliced to match
the post-RS token count.
"""
if self._needs_cp_reduce_scatter():
# Skip AllReduce in o_proj; use reduce-scatter instead.
attn_output = self.o_proj(
attn_output,
all_reduce_params=AllReduceParams(enable_allreduce=False))

# Pad to make token count divisible by cp_size for reduce-scatter.
attn_output, _ = self._pad_for_cp(attn_output,
attn_metadata.num_tokens)

# Reduce-scatter using mapping_o where tp_group = cp_group.
attn_output = reducescatter(attn_output, self.mapping_o, dim=0)

# For the first layer, the residual comes from the embedding and
# has not been through a prior RS. Slice it to match.
if self.layer_idx == 0 and residual is not ...:
residual = self._slice_for_cp(residual, attn_metadata)
else:
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params)

return attn_output, residual

def forward(
self,
position_ids: Optional[torch.Tensor],
Expand All @@ -2572,8 +2572,8 @@ def forward(
residual: Optional[torch.Tensor] = ...,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:

hidden_states = self._maybe_allgather_input(hidden_states,
attn_metadata)
hidden_states = _helix_cp_allgather_input(hidden_states, attn_metadata,
self.mapping, self.layer_idx)

attn_output = self.create_output(hidden_states,
attn_metadata.num_contexts)
Expand All @@ -2594,10 +2594,9 @@ def forward(
output=attn_output,
latent_cache_gen=latent_cache_gen)

attn_output, residual = self._output_projection(attn_output,
attn_metadata,
all_reduce_params,
residual)
attn_output, residual = _helix_cp_output_projection(
self.o_proj, attn_output, attn_metadata, all_reduce_params,
residual, self.mapping, self.mapping_o, self.layer_idx)
if residual is ...:
return attn_output
return attn_output, residual
Expand Down
Loading