diff --git a/cpp/tensorrt_llm/kernels/quantization.cu b/cpp/tensorrt_llm/kernels/quantization.cu index 4d54e6fcbd04..3941277dfa01 100644 --- a/cpp/tensorrt_llm/kernels/quantization.cu +++ b/cpp/tensorrt_llm/kernels/quantization.cu @@ -178,7 +178,7 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = false; + attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale, reinterpret_cast(output), @@ -213,7 +213,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = false; + attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, @@ -388,7 +388,7 @@ void computePerTokenGlobalScaleForFP4Quantization(int b, int m, int n, T const* config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = false; + attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); config.numAttrs = 1; config.attrs = attrs; TLLM_CUDA_CHECK(cudaLaunchKernelEx( diff --git a/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh b/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh index 377b63452d29..7f60e787bf24 100644 --- a/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh +++ b/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh @@ -236,7 +236,6 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output, if (!weight_warp) { cudaGridDependencySynchronize(); - cudaTriggerProgrammaticLaunchCompletion(); } for (int ki = 0; ki < K_LOOPS_DMA; ki++) @@ -441,6 +440,9 @@ __global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output, if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0) profile[blockIdx.x].complete = gclock64(); + + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) + cudaTriggerProgrammaticLaunchCompletion(); } } #endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) diff --git a/jenkins/BuildDockerImage.groovy b/jenkins/BuildDockerImage.groovy index f26d5537ed6d..9de7e4e7dd31 100644 --- a/jenkins/BuildDockerImage.groovy +++ b/jenkins/BuildDockerImage.groovy @@ -703,6 +703,11 @@ pipeline { trtllm_utils.llmExecStepWithRetry(this, script: "pip3 install --upgrade pip") trtllm_utils.llmExecStepWithRetry(this, script: "pip3 install --upgrade requests") def nspect_commit = "4cb9c0c42d44ebeeba1e40d2c3eb6aab6fb90173" + def override_commit = env."NSPECT_OVERRIDE_${nspect_commit}" + if (override_commit) { + echo "Overriding nspect_commit with value from environment variable \$NSPECT_OVERRIDE_${nspect_commit}: ${override_commit}" + nspect_commit = override_commit + } withCredentials([string(credentialsId: "TRTLLM_NSPECT_REPO", variable: "NSPECT_REPO")]) { trtllm_utils.checkoutSource("${NSPECT_REPO}", nspect_commit, "nspect") } diff --git a/tensorrt_llm/_mnnvl_utils.py b/tensorrt_llm/_mnnvl_utils.py index 5d168447f900..2436f30c82b3 100644 --- 
a/tensorrt_llm/_mnnvl_utils.py +++ b/tensorrt_llm/_mnnvl_utils.py @@ -370,15 +370,32 @@ def get_comm(cls, mapping: Mapping): if cls.comm is not None: return cls.comm comm = mpi_comm().Split( - mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size - + mapping.tp_rank * mapping.moe_tp_size - + mapping.moe_tp_rank, + mapping.pp_rank * mapping.tp_size + mapping.tp_rank, mapping.cp_rank, ) cls.comm = comm return comm +def init_helix_cp_comm(mapping: Mapping) -> None: + """Pre-initialize the Helix CP communicator. + + This function MUST be called during model initialization when all ranks + are synchronized (before any PP pipeline divergence). The MPI Split operation + is collective and requires all ranks in the communicator to participate. + + In PP (pipeline parallel) mode, different PP stages execute different parts + of the model at different times. If the communicator is initialized lazily + during the first forward pass, ranks in different PP stages may not reach + the Split operation at the same time, causing a deadlock. + + Args: + mapping: The mapping object containing parallelism configuration. + """ + if mapping.has_cp_helix() and not mapping.cp_config.get("use_nccl_for_alltoall", True): + HelixCpMnnvlMemory.get_comm(mapping) + + @dataclass class MoEAlltoallInfo: local_gather_indices: torch.Tensor diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 53446b0e8a85..ad590a86fee7 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -48,6 +48,8 @@ transforms: match_rope_layout: stage: pattern_matcher expected_layout: bsnd + match_rmsnorm_pattern: + stage: pattern_matcher ############################################################################################ # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION ############################################################################################ diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py index 7ce9b7befa85..39792a14fa9d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py @@ -88,6 +88,65 @@ def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: return torch.empty_like(input) +@torch.library.custom_op("auto_deploy::torch_rmsnorm_gated", mutates_args=()) +def torch_rmsnorm_gated( + x: torch.Tensor, + weight: torch.Tensor, + gate: torch.Tensor | None, + eps: float, + group_size: int, + norm_before_gate: bool = False, +) -> torch.Tensor: + """Custom operator for Torch gated RMSNorm implementation. + + Group RMSNorm with optional SiLU gating, using pure PyTorch operations. + + Args: + x: Input tensor of shape [..., H]. + weight: Scaling weights of shape [H]. + gate: Optional gate tensor with same shape as x, or None. + eps: Small constant for numerical stability. + group_size: Size of groups for grouped normalization. H must be divisible by group_size. + norm_before_gate: If True, apply gating after normalization. If False, apply before. + + Returns: + Normalized and optionally gated tensor of shape like x. 
+ """ + dtype = x.dtype + weight = weight.float() + x = x.float() + z = gate.float() if gate is not None else gate + + if z is not None and not norm_before_gate: + x = x * F.silu(z) + + if group_size is None: + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = x * rstd * weight + else: + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) + out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight + + if z is not None and norm_before_gate: + out *= F.silu(z) + + return out.to(dtype) + + +@torch_rmsnorm_gated.register_fake +def _( + x: torch.Tensor, + weight: torch.Tensor, + gate: torch.Tensor | None, + eps: float, + group_size: int, + norm_before_gate: bool = False, +) -> torch.Tensor: + """Fake implementation for the custom operator during tracing.""" + return x.new_empty(x.shape, dtype=x.dtype) + + @torch.library.custom_op("auto_deploy::triton_rmsnorm_gated", mutates_args=()) def triton_rmsnorm_gated( x: torch.Tensor, diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index b5fb106e10b9..97a34e481bf8 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + """The model factory interface used by auto-deploy to build custom models.""" import copy @@ -12,6 +28,7 @@ from torch.fx import GraphModule from ..custom_ops.attention_interface import CacheConfig +from ..utils.cuda_mem_tracker import get_mem_info_in_mb from ..utils.logger import ad_logger DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension @@ -285,11 +302,20 @@ def load_or_random_init(self, model: nn.Module, device: DeviceLikeType): """ ad_logger.info("Loading and initializing weights.") + free_mem_pre, _ = get_mem_info_in_mb() + ad_logger.info(f"Free memory before loading weights (MB): {free_mem_pre}") self._to_maybe_random(model, device) + params_size = sum(p.numel() * p.element_size() for p in model.parameters()) + total_size_GB = params_size / (1024**3) + ad_logger.info(f"Estimated parameters memory: {total_size_GB:.2f} GB") + if not self.skip_loading_weights: self.prefetch_checkpoint(force=True) self._load_checkpoint(model, device) + ad_logger.info("Loading and initializing weights. 
Done.") + free_mem_post, _ = get_mem_info_in_mb() + ad_logger.info(f"Free memory after loading weights (MB): {free_mem_post}") @staticmethod def _to_maybe_random(model: nn.Module, device: DeviceLikeType): diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py index 85dc6c48bed3..57685ea28389 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py @@ -32,7 +32,10 @@ def _make_allreduce_residual_rmsnorm_pattern( add_order: str = "residual_first", strategy: str = "AUTO" ): - """Factory function to create pattern functions for allreduce+residual+rmsnorm fusion. + """Factory function to create pattern functions for allreduce+residual+torch_rmsnorm fusion. + + This pattern matches the graph after match_rmsnorm_pattern has replaced + RMSNorm patterns with torch_rmsnorm ops. Args: add_order: Either "residual_first" (residual + x) or "x_first" (x + residual) @@ -45,15 +48,14 @@ def _make_allreduce_residual_rmsnorm_pattern( def pattern_fn( x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253 ): - """Pattern: trtllm_dist_all_reduce(x) -> add residual -> RMSNorm + """Pattern: trtllm_dist_all_reduce(x) -> add residual -> torch_rmsnorm Reference PyTorch composition: y = trtllm_dist_all_reduce(x) z = residual + y (or y + residual) - normed = RMSNorm(z, weight, eps) + normed = torch_rmsnorm(z, weight, eps) Returns (normed, z) """ - input_dtype = x.dtype hidden_states = torch.ops.auto_deploy.trtllm_dist_all_reduce(x, strategy) # Handle addition order @@ -62,11 +64,8 @@ def pattern_fn( else: # x_first add = hidden_states + residual - hidden_states = add.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + eps) - - normed = weight * hidden_states.to(input_dtype) + # Use torch_rmsnorm op (already replaced by match_rmsnorm_pattern) + normed = torch.ops.auto_deploy.torch_rmsnorm(add, weight, eps) return normed, add @@ -94,6 +93,9 @@ class FuseAllreduceResidualRMSNorm(BaseTransform): This transform only applies when TRT-LLM ops are used (MPI mode), as it provides optimized fused kernels. The torch backend (demollm mode) does not benefit from this fusion and uses unfused operations. + + Note: This transform expects torch_rmsnorm ops in the graph, which are created + by the match_rmsnorm_pattern transform that runs earlier in the pipeline. 
""" def _apply( @@ -114,7 +116,6 @@ def _apply( 0.1253, # eps ] - op_ignore_types = {torch.ops.aten.to.dtype: (torch.dtype,)} scalar_workaround = {"eps": 0.1253} # ============================================================================ @@ -139,7 +140,6 @@ def _apply( replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy), patterns=patterns, dummy_args=dummy_args, - op_ignore_types=op_ignore_types, scalar_workaround=scalar_workaround, ) @@ -149,7 +149,6 @@ def _apply( replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy), patterns=patterns, dummy_args=dummy_args, - op_ignore_types=op_ignore_types, scalar_workaround=scalar_workaround, ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index 1207bda245ab..15d6f2599291 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + """Graph transformation to automatically add kv cache into fused MHA op.""" import inspect @@ -21,6 +37,7 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils._graph import add_graph_input +from ...utils.cuda_mem_tracker import get_mem_info_in_mb from ...utils.node_utils import is_op from ..interface import ( BaseTransform, @@ -288,11 +305,7 @@ def _apply_to_full_model( ) -> Tuple[nn.Module, TransformInfo]: free_mem_ratio = self.config.free_mem_ratio - def _get_mem_info_in_mb(): - free_mem, total_mem = torch.cuda.mem_get_info() - return free_mem // 1024**2, total_mem // 1024**2 - - free_mem, total_mem = _get_mem_info_in_mb() + free_mem, total_mem = get_mem_info_in_mb(empty_cache=True) self._log_info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}") current_cache_size = cm.current_cache_size_bytes() current_kv_cache_size = getattr(cm, "current_kv_cache_size_bytes", None) @@ -301,7 +314,7 @@ def _get_mem_info_in_mb(): ) current_num_pages = cm.info.num_pages self._log_info( - f"Current cache size (MB): {current_cache_size // 1024 // 1024}, " + f"Current cache size (MB): {current_cache_size // 1024**2}, " f"Current num pages: {current_num_pages}" ) if current_kv_cache_size != current_cache_size: @@ -320,12 +333,32 @@ def _get_mem_info_in_mb(): # Let's run a forward pass to get the memory usage cm.info.set_max_num_tokens_sample() - free_mem_pre, _ = _get_mem_info_in_mb() + free_mem_pre, _ = get_mem_info_in_mb(empty_cache=True) self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}") - mod(**cm.named_args) + # Reset peak memory stats to get the extra memory used during the forward pass + torch.cuda.reset_peak_memory_stats() + memory_allocated_before_forward_pass_mb = torch.cuda.memory_allocated() // 1024**2 + try: + 
mod(**cm.named_args) + except torch.OutOfMemoryError as e: + self.ad_logger.error( + f"OutOfMemoryError in forward pass while trying to resize the kv-cache:\n{e}" + ) + raise e + + peak_memory_during_forward_pass_mb = torch.cuda.max_memory_allocated() // 1024**2 + mem_used_during_forward_pass_mb = ( + peak_memory_during_forward_pass_mb - memory_allocated_before_forward_pass_mb + ) + self._log_info( + f"Peak memory usage during forward pass (MB): {peak_memory_during_forward_pass_mb}" + ) + self._log_info( + f"Extra memory used during forward pass (MB): {mem_used_during_forward_pass_mb}" + ) - free_mem_post, _ = _get_mem_info_in_mb() + free_mem_post, _ = get_mem_info_in_mb(empty_cache=True) self._log_info(f"Free memory after forward pass (MB): {free_mem_post}") memory_for_forward_pass = free_mem_pre - free_mem_post diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py index 860b5b7de5b9..5e068429894b 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py @@ -1,17 +1,17 @@ """Graph transform to optimize RMSNorm execution using FlashInfer.""" -from functools import partial from typing import Tuple, Type import torch from pydantic import Field -from torch.fx import GraphModule +from torch.fx import GraphModule, Node from ...custom_ops.rms_norm import gated_rms_norm_ref from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface # It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher +from ...utils.node_utils import is_op from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern from ..interface import ( BaseTransform, @@ -66,6 +66,22 @@ def _rms_norm_pattern_float32_weights( return (weight.to(torch.float32) * data).to(input_dtype) +def _rms_norm_to_torch_rmsnorm( + data: torch.Tensor, weight: torch.Tensor, eps: float +) -> torch.Tensor: + """Replace RMSNorm pattern with torch_rmsnorm op (standardized representation). + + Args: + data: Input tensor to normalize. + weight: Scaling weights for the normalized output. + eps: Small constant for numerical stability. + + Returns: + Normalized and scaled tensor using torch_rmsnorm. + """ + return torch.ops.auto_deploy.torch_rmsnorm(data, weight, eps) + + def _rms_norm_replacement( data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str ) -> torch.Tensor: @@ -87,41 +103,26 @@ def _rms_norm_replacement( return _BACKEND_OPS[backend.lower()](data, weight, eps) -class FuseRMSNormConfig(TransformConfig): - """Configuration for the RMSNorm fusion transform.""" - - rmsnorm_backend: str = Field( - default="flashinfer", - description="Backend to use for RMSNorm computation ('flashinfer', 'triton', or 'torch').", - ) - gated_rmsnorm_backend: str = Field( - default="triton", - description="Backend to use for gated RMSNorm computation (currently only 'triton').", - ) - - -@TransformRegistry.register("fuse_rmsnorm") -class FuseRMSNorm(BaseTransform): - """Matches and replaces RMSNorm patterns (regular and gated) in the graph with optimized implementations. +@TransformRegistry.register("match_rmsnorm_pattern") +class MatchRMSNormPattern(BaseTransform): + """Matches RMSNorm patterns in the graph and replaces them with torch_rmsnorm op. 
- This function sets up pattern matching to identify both regular and gated RMSNorm operations in the graph - and replaces them with optimized implementations. It uses dummy tensors to register - the pattern matching rules. + This transform runs in the pattern_matcher stage and standardizes RMSNorm patterns + to use torch_rmsnorm op, which can later be fused to a specific backend in the + post_load_fusion stage. Args: gm: Input graph module to transform. - rmsnorm_backend: Backend to use for regular RMSNorm computation ("flashinfer", "triton", or "torch"). - gated_rmsnorm_backend: Backend to use for gated RMSNorm computation (currently only "triton"). Returns: - Transformed graph module with optimized RMSNorm operations. + Transformed graph module with standardized torch_rmsnorm operations. """ - config: FuseRMSNormConfig + config: TransformConfig @classmethod def get_config_class(cls) -> Type[TransformConfig]: - return FuseRMSNormConfig + return TransformConfig def _apply( self, @@ -130,19 +131,6 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - # Validate rmsnorm_backend - if self.config.rmsnorm_backend.lower() not in _BACKEND_OPS: - raise ValueError( - f"Invalid rmsnorm_backend, must be one of {list(_BACKEND_OPS)}, got {self.config.rmsnorm_backend}" - ) - - # Validate gated_rmsnorm_backend (currently only triton is supported) - if self.config.gated_rmsnorm_backend.lower() != "triton": - raise ValueError( - f"""Invalid gated_rmsnorm_backend, currently only 'triton' is supported, - got {self.config.gated_rmsnorm_backend}""" - ) - graph = gm.graph patterns = ADPatternMatcherPass() @@ -164,7 +152,7 @@ def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = (torch.float32, torch.float32), ] - # Register patterns for each configuration + # Register patterns for each configuration - replace with torch_rmsnorm search_fns = [ _rms_norm_pattern, _rms_norm_pattern_float32_weights, @@ -173,7 +161,7 @@ def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = for input_dtype, weight_dtype in configs: register_ad_pattern( search_fn=search_fn, - replace_fn=partial(_rms_norm_replacement, backend=self.config.rmsnorm_backend), + replace_fn=_rms_norm_to_torch_rmsnorm, patterns=patterns, dummy_args=dummy_args(input_dtype, weight_dtype), op_ignore_types={}, @@ -198,10 +186,10 @@ def make_dummy_args_gated(group_size: int, eps: float) -> list: torch.ops.aten.to.dtype: (torch.dtype,), } - # Register pattern for gated RMSNorm + # Register pattern for gated RMSNorm - replace with torch_rmsnorm_gated register_ad_pattern( search_fn=_gated_rmsnorm_pattern_ref, - replace_fn=_gated_rmsnorm_replacement, + replace_fn=_gated_rmsnorm_to_torch_rmsnorm_gated, patterns=patterns, dummy_args=make_dummy_args_gated(group_size, eps), op_ignore_types=op_ignore_types, @@ -218,6 +206,103 @@ def make_dummy_args_gated(group_size: int, eps: float) -> list: return gm, info +class FuseRMSNormConfig(TransformConfig): + """Configuration for the RMSNorm fusion transform.""" + + rmsnorm_backend: str = Field( + default="flashinfer", + description="Backend to use for RMSNorm computation ('flashinfer', 'triton', or 'torch').", + ) + gated_rmsnorm_backend: str = Field( + default="triton", + description="Backend to use for gated RMSNorm computation (currently only 'triton').", + ) + + +@TransformRegistry.register("fuse_rmsnorm") +class FuseRMSNorm(BaseTransform): + """Fuses torch_rmsnorm ops with the selected backend implementation. 
+ + This transform runs in the post_load_fusion stage and replaces torch_rmsnorm ops + with the specified backend implementation (flashinfer, triton, or torch). + + Args: + gm: Input graph module to transform. + rmsnorm_backend: Backend to use for regular RMSNorm computation ("flashinfer", "triton", or "torch"). + gated_rmsnorm_backend: Backend to use for gated RMSNorm computation (currently only "triton"). + + Returns: + Transformed graph module with backend-specific RMSNorm operations. + """ + + config: FuseRMSNormConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return FuseRMSNormConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + # Validate rmsnorm_backend + if self.config.rmsnorm_backend.lower() not in _BACKEND_OPS: + raise ValueError( + f"Invalid rmsnorm_backend, must be one of {list(_BACKEND_OPS)}, got {self.config.rmsnorm_backend}" + ) + + # Validate gated_rmsnorm_backend (currently only triton is supported) + if self.config.gated_rmsnorm_backend.lower() != "triton": + raise ValueError( + f"""Invalid gated_rmsnorm_backend, currently only 'triton' is supported, + got {self.config.gated_rmsnorm_backend}""" + ) + + graph = gm.graph + backend = self.config.rmsnorm_backend.lower() + target_op = _BACKEND_OPS[backend] + cnt = 0 + + # Replace torch_rmsnorm ops with the selected backend + for node in list(graph.nodes): + if is_op(node, torch.ops.auto_deploy.torch_rmsnorm): + # Replace with the selected backend op + with graph.inserting_after(node): + new_node: Node = graph.call_function( + target_op, + args=node.args, + kwargs=node.kwargs, + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + cnt += 1 + + # Replace torch_rmsnorm_gated ops with triton_rmsnorm_gated + for node in list(graph.nodes): + if is_op(node, torch.ops.auto_deploy.torch_rmsnorm_gated): + # Replace with triton_rmsnorm_gated op + with graph.inserting_after(node): + new_node: Node = graph.call_function( + torch.ops.auto_deploy.triton_rmsnorm_gated, + args=node.args, + kwargs=node.kwargs, + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + cnt += 1 + + gm.recompile() + + info = TransformInfo( + skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0 + ) + + return gm, info + + def _gated_rmsnorm_pattern_ref( x: torch.Tensor, # [B, S, H] weight: torch.Tensor, # [H] @@ -239,13 +324,25 @@ def _gated_rmsnorm_pattern_ref( return y -def _gated_rmsnorm_replacement( +def _gated_rmsnorm_to_torch_rmsnorm_gated( x: torch.Tensor, weight: torch.Tensor, gate: torch.Tensor, eps: float, group_size: int, ) -> torch.Tensor: - return torch.ops.auto_deploy.triton_rmsnorm_gated( + """Replace gated RMSNorm pattern with torch_rmsnorm_gated op (standardized representation). + + Args: + x: Input tensor to normalize. + weight: Scaling weights for the normalized output. + gate: Gate tensor for gated normalization. + eps: Small constant for numerical stability. + group_size: Size of groups for grouped normalization. + + Returns: + Normalized and gated tensor using torch_rmsnorm_gated. 
+ """ + return torch.ops.auto_deploy.torch_rmsnorm_gated( x, weight, gate, float(eps), int(group_size), False ) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py b/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py index ddf57a6e6f7f..e73cec39e7c9 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py @@ -1,5 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import gc from contextlib import contextmanager +from typing import Tuple import torch @@ -24,3 +41,12 @@ def cuda_memory_tracker(logger=ad_logger): leaked = mem_after - mem_before if leaked > 0: logger.warning(f"Potential memory leak detected, leaked memory: {leaked} bytes") + + +def get_mem_info_in_mb(empty_cache: bool = True) -> Tuple[int, int]: + if empty_cache: + # Clear the memory cache to get the exact free memory + torch.cuda.empty_cache() + free_mem, total_mem = torch.cuda.mem_get_info() + MB = 1024**2 + return free_mem // MB, total_mem // MB diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py index 09bbc234ee27..20401b5a24a9 100644 --- a/tensorrt_llm/_torch/distributed/communicator.py +++ b/tensorrt_llm/_torch/distributed/communicator.py @@ -16,6 +16,7 @@ except Exception: MPI = None # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True +from tensorrt_llm._mnnvl_utils import init_helix_cp_comm from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm, mpi_disabled, mpi_isend, mpi_isend_object, mpi_recv, mpi_recv_object, mpi_send, @@ -888,6 +889,7 @@ def init_pp_comm(mapping): _pp_comm = PPCommTorch(mapping) else: _pp_comm = PPCommNCCL(mapping) + init_helix_cp_comm(mapping) @TorchDist.log_op diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 2ba2ee1bfee7..14bb0cb81129 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -871,6 +871,7 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_blackwell @pytest.mark.skip_less_device(8) @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2), (2, 1, 2)], diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index b7a31e57d952..a914a00c5348 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -71,7 +71,7 @@ l0_dgx_b200: backend: pytorch orchestrator: mpi tests: - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) + - 
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60) - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60) @@ -101,7 +101,7 @@ l0_dgx_b200: backend: pytorch orchestrator: mpi tests: - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60) - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 2ede4acdf4ad..5000daf63384 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -228,7 +228,6 @@ unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP] SKIP unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL] SKIP (https://nvbugs/5664904) test_e2e.py::test_ptp_quickstart_advanced[Nemotron-Super-49B-v1-FP8-nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1-FP8] SKIP (https://nvbugs/5670469) test_e2e.py::test_ptp_quickstart_advanced[Nemotron-Super-49B-v1-NVFP4-nvfp4-quantized/Llama-3_3-Nemotron-Super-49B-v1_nvfp4_hf] SKIP (https://nvbugs/5670469) -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto] SKIP (https://nvbugs/5673610) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto] SKIP (https://nvbugs/5756804) examples/test_llama.py::test_llm_llama_1gpu_fp4[llama-3.1-70b-instruct-enable_norm_quant_fusion-enable_fused_quant-fp4_plugin-bfloat16] SKIP (https://nvbugs/5451216) accuracy/test_llm_api_pytorch.py::TestNemotronNas::test_auto_dtype_tp8 SKIP (https://nvbugs/5673527) @@ -312,14 +311,11 @@ full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5768068) test_e2e.py::test_eagle3_output_consistency_4gpus[Qwen3/saved_models_Qwen3-235B-A22B_fp8_hf-Qwen3/qwen3-235B-eagle3] SKIP (https://nvbugs/5685010) examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5769855) -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8] SKIP (https://nvbugs/5772396) -full:sm100/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-auto] SKIP (https://nvbugs/5772396) accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model_trtllm] SKIP (https://nvbugs/5772360) accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model] SKIP (https://nvbugs/5772993) accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_gather_generation_logits_cuda_graph SKIP (https://nvbugs/5772995) 
test_e2e.py::test_eagle3_output_consistency_4gpus[Qwen3/Qwen3-30B-A3B-Qwen3/Qwen3-30B-eagle3] SKIP (https://nvbugs/5685010) full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp2] SKIP (https://nvbugs/5773047) -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[enable_configurable_moe-dp4-trtllm-fp8] SKIP (https://nvbugs/5773201) unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py::test_sharding[GQA_Block-torch_dist_all_reduce-True-False-2] SKIP (https://nvbugs/5766982) accuracy/test_cli_flow.py::TestVicuna7B::test_eagle_2[cuda_graph=True-chunked_context=True] SKIP (https://nvbugs/5773185) accuracy/test_cli_flow.py::TestVicuna7B::test_eagle_2[cuda_graph=True-chunked_context=False] SKIP (https://nvbugs/5773185) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py index f9595cde7f05..f7b58ee349de 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py @@ -87,6 +87,9 @@ def _test_allreduce_fusion(port: int, ModuleCls, strategy: str): gm_transformed = InferenceOptimizer( None, { + "match_rmsnorm_pattern": { + "stage": "pattern_matcher", + }, "detect_sharding": { "stage": "post_export", "allreduce_strategy": strategy, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py index d354f9d50fe5..c62e5d5396e4 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py @@ -64,6 +64,9 @@ def checker(gm): gm_transformed = InferenceOptimizer( None, { + "match_rmsnorm_pattern": { + "stage": "pattern_matcher", + }, "fuse_rmsnorm": { "stage": "post_load_fusion", "gated_rmsnorm_backend": "triton",
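Illustrative note (not part of the patch): the RMSNorm-related hunks above split the old fuse_rmsnorm transform into two stages, where match_rmsnorm_pattern (pattern_matcher stage) rewrites eager RMSNorm math into torch.ops.auto_deploy.torch_rmsnorm / torch_rmsnorm_gated, and fuse_rmsnorm (post_load_fusion stage) later swaps those standardized ops for a backend implementation. The sketch below mirrors the pure-PyTorch semantics of the new torch_rmsnorm_gated op (group RMSNorm with optional SiLU gating) for the grouped case only, written without the einops dependency so it runs stand-alone; the helper name rmsnorm_gated_ref is ours and does not appear in the patch.

# Minimal stand-alone sketch of group RMSNorm with optional SiLU gating,
# matching the gating order used by the patch (gate applied before the norm
# when norm_before_gate=False, after the norm when norm_before_gate=True).
import torch
import torch.nn.functional as F


def rmsnorm_gated_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    gate: torch.Tensor | None,
    eps: float,
    group_size: int,
    norm_before_gate: bool = False,
) -> torch.Tensor:
    dtype = x.dtype
    x = x.float()
    w = weight.float()
    z = gate.float() if gate is not None else None

    # Apply the gate before normalization when norm_before_gate=False.
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)

    # Grouped RMS statistics over the last dimension, split into groups of group_size.
    xg = x.unflatten(-1, (-1, group_size))            # [..., num_groups, group_size]
    rstd = torch.rsqrt(xg.square().mean(-1, keepdim=True) + eps)
    out = (xg * rstd).flatten(-2) * w                 # back to [..., H], scaled by weight

    # Apply the gate after normalization when norm_before_gate=True.
    if z is not None and norm_before_gate:
        out = out * F.silu(z)
    return out.to(dtype)


if __name__ == "__main__":
    x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
    g = torch.randn_like(x)
    w = torch.ones(8)
    y = rmsnorm_gated_ref(x, w, g, eps=1e-6, group_size=4)
    print(y.shape, y.dtype)  # torch.Size([2, 4, 8]) torch.bfloat16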