6 changes: 3 additions & 3 deletions cpp/tensorrt_llm/kernels/quantization.cu
@@ -178,7 +178,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
@@ -213,7 +213,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config,
@@ -388,7 +388,7 @@ void computePerTokenGlobalScaleForFP4Quantization(int b, int m, int n, T const*
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
4 changes: 3 additions & 1 deletion cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh
@@ -236,7 +236,6 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,
if (!weight_warp)
{
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}

for (int ki = 0; ki < K_LOOPS_DMA; ki++)
@@ -441,6 +440,9 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output,

if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
profile[blockIdx.x].complete = gclock64();

if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
cudaTriggerProgrammaticLaunchCompletion();
}
}
#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
5 changes: 5 additions & 0 deletions jenkins/BuildDockerImage.groovy
@@ -703,6 +703,11 @@ pipeline {
trtllm_utils.llmExecStepWithRetry(this, script: "pip3 install --upgrade pip")
trtllm_utils.llmExecStepWithRetry(this, script: "pip3 install --upgrade requests")
def nspect_commit = "4cb9c0c42d44ebeeba1e40d2c3eb6aab6fb90173"
def override_commit = env."NSPECT_OVERRIDE_${nspect_commit}"
if (override_commit) {
echo "Overriding nspect_commit with value from environment variable \$NSPECT_OVERRIDE_${nspect_commit}: ${override_commit}"
nspect_commit = override_commit
}
withCredentials([string(credentialsId: "TRTLLM_NSPECT_REPO", variable: "NSPECT_REPO")]) {
trtllm_utils.checkoutSource("${NSPECT_REPO}", nspect_commit, "nspect")
}
23 changes: 20 additions & 3 deletions tensorrt_llm/_mnnvl_utils.py
@@ -370,15 +370,32 @@ def get_comm(cls, mapping: Mapping):
if cls.comm is not None:
return cls.comm
comm = mpi_comm().Split(
mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size
+ mapping.tp_rank * mapping.moe_tp_size
+ mapping.moe_tp_rank,
mapping.pp_rank * mapping.tp_size + mapping.tp_rank,
mapping.cp_rank,
)
cls.comm = comm
return comm


def init_helix_cp_comm(mapping: Mapping) -> None:
"""Pre-initialize the Helix CP communicator.

This function MUST be called during model initialization when all ranks
are synchronized (before any PP pipeline divergence). The MPI Split operation
is collective and requires all ranks in the communicator to participate.

In PP (pipeline parallel) mode, different PP stages execute different parts
of the model at different times. If the communicator is initialized lazily
during the first forward pass, ranks in different PP stages may not reach
the Split operation at the same time, causing a deadlock.

Args:
mapping: The mapping object containing parallelism configuration.
"""
if mapping.has_cp_helix() and not mapping.cp_config.get("use_nccl_for_alltoall", True):
HelixCpMnnvlMemory.get_comm(mapping)


@dataclass
class MoEAlltoallInfo:
local_gather_indices: torch.Tensor
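The docstring above spells out why eager initialization matters: `mpi_comm().Split` is a collective call, so creating the Helix CP communicator lazily inside the first forward pass can deadlock when PP stages reach it at different times. A minimal sketch of the intended call pattern, assuming a simplified engine-setup function (the `setup_model` name is illustrative, not the PR's actual call site):

```python
from tensorrt_llm._mnnvl_utils import init_helix_cp_comm
from tensorrt_llm.mapping import Mapping

def setup_model(mapping: Mapping):
    # Every rank runs this during initialization, before any PP stage starts
    # executing its own slice of the model, so all ranks enter the collective
    # Split inside HelixCpMnnvlMemory.get_comm() at the same time.
    init_helix_cp_comm(mapping)
    # ... continue with weight loading, graph transforms, etc.
```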
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -48,6 +48,8 @@ transforms:
match_rope_layout:
stage: pattern_matcher
expected_layout: bsnd
match_rmsnorm_pattern:
stage: pattern_matcher
############################################################################################
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
############################################################################################
59 changes: 59 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
@@ -88,6 +88,65 @@ def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
return torch.empty_like(input)


@torch.library.custom_op("auto_deploy::torch_rmsnorm_gated", mutates_args=())
def torch_rmsnorm_gated(
x: torch.Tensor,
weight: torch.Tensor,
gate: torch.Tensor | None,
eps: float,
group_size: int,
norm_before_gate: bool = False,
) -> torch.Tensor:
"""Custom operator for Torch gated RMSNorm implementation.

Group RMSNorm with optional SiLU gating, using pure PyTorch operations.

Args:
x: Input tensor of shape [..., H].
weight: Scaling weights of shape [H].
gate: Optional gate tensor with same shape as x, or None.
eps: Small constant for numerical stability.
group_size: Size of groups for grouped normalization. H must be divisible by group_size.
norm_before_gate: If True, apply gating after normalization. If False, apply before.

Returns:
Normalized and optionally gated tensor of shape like x.
"""
dtype = x.dtype
weight = weight.float()
x = x.float()
z = gate.float() if gate is not None else gate

if z is not None and not norm_before_gate:
x = x * F.silu(z)

if group_size is None:
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = x * rstd * weight
else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight

if z is not None and norm_before_gate:
out *= F.silu(z)

return out.to(dtype)


@torch_rmsnorm_gated.register_fake
def _(
x: torch.Tensor,
weight: torch.Tensor,
gate: torch.Tensor | None,
eps: float,
group_size: int,
norm_before_gate: bool = False,
) -> torch.Tensor:
"""Fake implementation for the custom operator during tracing."""
return x.new_empty(x.shape, dtype=x.dtype)


@torch.library.custom_op("auto_deploy::triton_rmsnorm_gated", mutates_args=())
def triton_rmsnorm_gated(
x: torch.Tensor,
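A short usage sketch of the op registered above (shapes, dtype, and `eps` are illustrative); with `norm_before_gate=True` the SiLU gate is applied after the grouped normalization, as described in the docstring:

```python
import torch
# Assumes the auto_deploy custom ops module above has been imported, which
# registers torch.ops.auto_deploy.torch_rmsnorm_gated.

x = torch.randn(2, 16, 256, dtype=torch.bfloat16)
weight = torch.ones(256, dtype=torch.bfloat16)
gate = torch.randn_like(x)

# Grouped RMSNorm over groups of 128 channels, SiLU-gated after the norm.
y = torch.ops.auto_deploy.torch_rmsnorm_gated(x, weight, gate, 1e-6, 128, True)
assert y.shape == x.shape and y.dtype == x.dtype
```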
26 changes: 26 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/models/factory.py
@@ -1,3 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""The model factory interface used by auto-deploy to build custom models."""

import copy
@@ -12,6 +28,7 @@
from torch.fx import GraphModule

from ..custom_ops.attention_interface import CacheConfig
from ..utils.cuda_mem_tracker import get_mem_info_in_mb
from ..utils.logger import ad_logger

DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension
@@ -285,11 +302,20 @@ def load_or_random_init(self, model: nn.Module, device: DeviceLikeType):

"""
ad_logger.info("Loading and initializing weights.")
free_mem_pre, _ = get_mem_info_in_mb()
ad_logger.info(f"Free memory before loading weights (MB): {free_mem_pre}")
self._to_maybe_random(model, device)
params_size = sum(p.numel() * p.element_size() for p in model.parameters())
total_size_GB = params_size / (1024**3)
ad_logger.info(f"Estimated parameters memory: {total_size_GB:.2f} GB")

if not self.skip_loading_weights:
self.prefetch_checkpoint(force=True)
self._load_checkpoint(model, device)

ad_logger.info("Loading and initializing weights. Done.")
free_mem_post, _ = get_mem_info_in_mb()
ad_logger.info(f"Free memory after loading weights (MB): {free_mem_post}")

@staticmethod
def _to_maybe_random(model: nn.Module, device: DeviceLikeType):
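`get_mem_info_in_mb` is imported from `..utils.cuda_mem_tracker` and shared with the kv-cache transform below; its body is not part of this diff. Based on the local helper it replaces in `kvcache.py`, a plausible sketch (an assumption, not necessarily the PR's exact implementation) is:

```python
import gc
from typing import Tuple

import torch

def get_mem_info_in_mb(empty_cache: bool = False) -> Tuple[int, int]:
    """Return (free, total) CUDA memory in MB, optionally flushing the caching allocator first."""
    if empty_cache:
        gc.collect()
        torch.cuda.empty_cache()
    free_mem, total_mem = torch.cuda.mem_get_info()
    return free_mem // 1024**2, total_mem // 1024**2
```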
23 changes: 11 additions & 12 deletions tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py
@@ -32,7 +32,10 @@
def _make_allreduce_residual_rmsnorm_pattern(
add_order: str = "residual_first", strategy: str = "AUTO"
):
"""Factory function to create pattern functions for allreduce+residual+rmsnorm fusion.
"""Factory function to create pattern functions for allreduce+residual+torch_rmsnorm fusion.

This pattern matches the graph after match_rmsnorm_pattern has replaced
RMSNorm patterns with torch_rmsnorm ops.

Args:
add_order: Either "residual_first" (residual + x) or "x_first" (x + residual)
@@ -45,15 +48,14 @@ def _make_allreduce_residual_rmsnorm_pattern(
def pattern_fn(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253
):
"""Pattern: trtllm_dist_all_reduce(x) -> add residual -> RMSNorm
"""Pattern: trtllm_dist_all_reduce(x) -> add residual -> torch_rmsnorm

Reference PyTorch composition:
y = trtllm_dist_all_reduce(x)
z = residual + y (or y + residual)
normed = RMSNorm(z, weight, eps)
normed = torch_rmsnorm(z, weight, eps)
Returns (normed, z)
"""
input_dtype = x.dtype
hidden_states = torch.ops.auto_deploy.trtllm_dist_all_reduce(x, strategy)

# Handle addition order
@@ -62,11 +64,8 @@ def pattern_fn(
else: # x_first
add = hidden_states + residual

hidden_states = add.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)

normed = weight * hidden_states.to(input_dtype)
# Use torch_rmsnorm op (already replaced by match_rmsnorm_pattern)
normed = torch.ops.auto_deploy.torch_rmsnorm(add, weight, eps)

return normed, add

@@ -94,6 +93,9 @@ class FuseAllreduceResidualRMSNorm(BaseTransform):
This transform only applies when TRT-LLM ops are used (MPI mode), as it provides
optimized fused kernels. The torch backend (demollm mode) does not benefit from
this fusion and uses unfused operations.

Note: This transform expects torch_rmsnorm ops in the graph, which are created
by the match_rmsnorm_pattern transform that runs earlier in the pipeline.
"""

def _apply(
@@ -114,7 +116,6 @@ def _apply(
0.1253, # eps
]

op_ignore_types = {torch.ops.aten.to.dtype: (torch.dtype,)}
scalar_workaround = {"eps": 0.1253}

# ============================================================================
@@ -139,7 +140,6 @@
replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy),
patterns=patterns,
dummy_args=dummy_args,
op_ignore_types=op_ignore_types,
scalar_workaround=scalar_workaround,
)

@@ -149,7 +149,6 @@
replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy),
patterns=patterns,
dummy_args=dummy_args,
op_ignore_types=op_ignore_types,
scalar_workaround=scalar_workaround,
)

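For reference, the inlined float32 math removed from the pattern above is exactly what `match_rmsnorm_pattern` now folds into a single `torch_rmsnorm` op, which is also why the `aten.to.dtype` ignore entry is no longer needed. A sketch of the equivalence (the reference function below is illustrative, not part of the PR):

```python
import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # The composition previously written out inside the fusion pattern.
    hidden = x.to(torch.float32)
    variance = hidden.pow(2).mean(-1, keepdim=True)
    hidden = hidden * torch.rsqrt(variance + eps)
    return weight * hidden.to(x.dtype)

# After match_rmsnorm_pattern runs, the graph instead contains the equivalent call:
#   torch.ops.auto_deploy.torch_rmsnorm(x, weight, eps)
```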
51 changes: 42 additions & 9 deletions tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
@@ -1,3 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Graph transformation to automatically add kv cache into fused MHA op."""

import inspect
@@ -21,6 +37,7 @@
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import add_graph_input
from ...utils.cuda_mem_tracker import get_mem_info_in_mb
from ...utils.node_utils import is_op
from ..interface import (
BaseTransform,
@@ -288,11 +305,7 @@ def _apply_to_full_model(
) -> Tuple[nn.Module, TransformInfo]:
free_mem_ratio = self.config.free_mem_ratio

def _get_mem_info_in_mb():
free_mem, total_mem = torch.cuda.mem_get_info()
return free_mem // 1024**2, total_mem // 1024**2

free_mem, total_mem = _get_mem_info_in_mb()
free_mem, total_mem = get_mem_info_in_mb(empty_cache=True)
self._log_info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
current_cache_size = cm.current_cache_size_bytes()
current_kv_cache_size = getattr(cm, "current_kv_cache_size_bytes", None)
@@ -301,7 +314,7 @@ def _get_mem_info_in_mb():
)
current_num_pages = cm.info.num_pages
self._log_info(
f"Current cache size (MB): {current_cache_size // 1024 // 1024}, "
f"Current cache size (MB): {current_cache_size // 1024**2}, "
f"Current num pages: {current_num_pages}"
)
if current_kv_cache_size != current_cache_size:
@@ -320,12 +333,32 @@

# Let's run a forward pass to get the memory usage
cm.info.set_max_num_tokens_sample()
free_mem_pre, _ = _get_mem_info_in_mb()
free_mem_pre, _ = get_mem_info_in_mb(empty_cache=True)
self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}")

mod(**cm.named_args)
# Reset peak memory stats to get the extra memory used during the forward pass
torch.cuda.reset_peak_memory_stats()
memory_allocated_before_forward_pass_mb = torch.cuda.memory_allocated() // 1024**2
try:
mod(**cm.named_args)
except torch.OutOfMemoryError as e:
self.ad_logger.error(
f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}"
)
raise e

peak_memory_during_forward_pass_mb = torch.cuda.max_memory_allocated() // 1024**2
mem_used_during_forward_pass_mb = (
peak_memory_during_forward_pass_mb - memory_allocated_before_forward_pass_mb
)
self._log_info(
f"Peak memory uasge during forward pass (MB): {peak_memory_during_forward_pass_mb}"
)
self._log_info(
f"Extra memory used during forward pass (MB): {mem_used_during_forward_pass_mb}"
)

free_mem_post, _ = _get_mem_info_in_mb()
free_mem_post, _ = get_mem_info_in_mb(empty_cache=True)
self._log_info(f"Free memory after forward pass (MB): {free_mem_post}")

memory_for_forward_pass = free_mem_pre - free_mem_post