diff --git a/3rdparty/CMakeLists.txt b/3rdparty/CMakeLists.txt index 59076e14c9a2..93565ae099b0 100644 --- a/3rdparty/CMakeLists.txt +++ b/3rdparty/CMakeLists.txt @@ -38,8 +38,8 @@ FetchContent_Declare( FetchContent_Declare( deepgemm - GIT_REPOSITORY https://github.com/ruoqianguo/DeepGEMM - GIT_TAG 6cb8161516302550785d9af924d2778afef1f3f6 # swapab_sm100 branch + GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM + GIT_TAG 4ff3f54d9b7ed3129e4f36f9871232ea7ecab86b # nv_dev branch GIT_SUBMODULES_RECURSE ON SOURCE_SUBDIR diff --git a/cpp/tensorrt_llm/common/customAllReduceUtils.h b/cpp/tensorrt_llm/common/customAllReduceUtils.h index 4115ac150fdf..d718cbd188e9 100644 --- a/cpp/tensorrt_llm/common/customAllReduceUtils.h +++ b/cpp/tensorrt_llm/common/customAllReduceUtils.h @@ -40,50 +40,12 @@ inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept { return common::getEnvAllReduceWorkspaceSize(); } - if (worldSize <= 2) + char const* envWorkspaceSize = std::getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE"); + if (envWorkspaceSize != nullptr) { - return 16 * 1000 * 1000; - } - return 8 * 1000 * 1000; -} - -// (SM major_version, TP_size) -> (NCCL_num_token_threshold, TWO_SHOT_numel_threshold) -inline std::unordered_map<int, std::unordered_map<int, std::pair<size_t, size_t>>> HeuristicThresholdLP{ - {90, - { - {2, {4096, 4096 * 4096}}, - {4, {4096, 1024 * 1024}}, - {8, {2048, 512 * 512}}, - }}, - {100, - { - {2, {4096, 4096 * 4096}}, - {4, {4096, 1024 * 2048}}, - {8, {4096, 1024 * 1024}}, - }}, -}; - -inline AllReduceStrategyType SelectStrategyLP(size_t seq_len, size_t hidden_size, int world_size, AllReduceFusionOp op) -{ - // The heuristic is based on the following assumptions: - // __________________________________ - // | \ TWO-SHOT zone | - // | ONE-SHOT zone \ | NCCL zone - // |_______________________\______|___ - // sm_major is 90 or 100 - - auto const sm_major = std::min(100, std::max(90, tensorrt_llm::common::getSMVersion())); - - auto const [nccl_num_token_threshold, two_shot_numel_threshold] = HeuristicThresholdLP[sm_major][world_size]; - auto const message_size = seq_len * hidden_size; - if (message_size >= two_shot_numel_threshold) - { - return AllReduceStrategyType::TWOSHOT; - } - else - { - return AllReduceStrategyType::ONESHOT; + return static_cast<size_t>(std::atoi(envWorkspaceSize)); } + return 67108864; // 64 MiB } // use 1D vector to store the best strategy instead of a map for each sm version diff --git a/cpp/tensorrt_llm/deep_gemm/CMakeLists.txt b/cpp/tensorrt_llm/deep_gemm/CMakeLists.txt index d2a900bf05ad..1a884befcaf3 100644 --- a/cpp/tensorrt_llm/deep_gemm/CMakeLists.txt +++ b/cpp/tensorrt_llm/deep_gemm/CMakeLists.txt @@ -38,7 +38,13 @@ foreach(SOURCE_FILE ${DEEP_GEMM_ALL_FILES}) if(FILE_EXT STREQUAL ".py") # Read file content and replace module imports for Python files file(READ ${SOURCE_FILE} _content) - string(REPLACE "deep_gemm_cpp" "tensorrt_llm.deep_gemm_cpp_tllm" _content + string(REPLACE "from . import _C" "import tensorrt_llm.deep_gemm_cpp_tllm" + _content "${_content}") + string(REPLACE ".._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content + "${_content}") + string(REPLACE "._C" "tensorrt_llm.deep_gemm_cpp_tllm" _content + "${_content}") + string(REPLACE "_C." "tensorrt_llm.deep_gemm_cpp_tllm." 
_content "${_content}") # Add adaptation header diff --git a/docs/source/commands/trtllm-serve/trtllm-serve.rst b/docs/source/commands/trtllm-serve/trtllm-serve.rst index 7e09872a9b6f..b26e45de9242 100644 --- a/docs/source/commands/trtllm-serve/trtllm-serve.rst +++ b/docs/source/commands/trtllm-serve/trtllm-serve.rst @@ -299,6 +299,8 @@ To configure the nested level arguments like ``moe_config.backend``, the yaml fi Syntax ------ +This section lists all command-line arguments for the ``trtllm-serve`` subcommands. Some arguments are accompanied by a stability tag indicating their development status. Refer to our `API Reference `__ for details. + .. click:: tensorrt_llm.commands.serve:main :prog: trtllm-serve :nested: full diff --git a/examples/models/core/deepseek_v3/README.md b/examples/models/core/deepseek_v3/README.md index b25c58fbea7e..81a0e95f50b2 100644 --- a/examples/models/core/deepseek_v3/README.md +++ b/examples/models/core/deepseek_v3/README.md @@ -90,7 +90,6 @@ To quickly run DeepSeek-V3, [examples/llm-api/quickstart_advanced.py](../llm-api cd examples/llm-api python quickstart_advanced.py --model_dir --tp_size 8 ``` -Please include `--tokens_per_block 64` when running DeepSeek-V3.2-Exp, as this model uses the deep_gemm.fp8_paged_mqa_logits kernel, which requires a KV cache block size of 64. The model will be run by PyTorch backend and generate outputs like: ``` @@ -108,7 +107,7 @@ cd examples/llm-api python quickstart_advanced.py --model_dir --spec_decode_algo MTP --spec_decode_max_draft_len N ``` -`N` is the number of MTP modules. When `N` is equal to `0`, which means that MTP is not used (default). When `N` is greater than `0`, which means that `N` MTP modules are enabled. In the current implementation, the weight of each MTP module is shared. Please include `--tokens_per_block 64` when running DeepSeek-V3.2-Exp. +`N` is the number of MTP modules. When `N` is equal to `0`, MTP is not used (default). When `N` is greater than `0`, `N` MTP modules are enabled. In the current implementation, the weight of each MTP module is shared. #### Relaxed acceptance **NOTE: This feature can only be used for DeepSeek R1.** diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index b30edc1aa634..80e37bba7d62 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -785,7 +785,6 @@ def on_update_kv_lens(self): # After changing the kv_lens/kv_lens_cuda, we may need to update other metadatas. # Especially for the changes in the _preprocess_inputs() of model_engine.py. if self.num_generations > 0: - tokens_per_block = self.kv_cache_manager.indexer_k_cache_tokens_per_block torch.cumsum( self.kv_lens_cuda[self.num_contexts:self.
num_seqs], # num_contexts should be 0 @@ -800,7 +799,7 @@ def on_update_kv_lens(self): out=self.gen_cached_token_indptr[1:self.num_generations + 1]) scheduler_metadata_buffer = get_paged_mqa_logits_metadata( self.kv_lens_cuda[self.num_contexts:self.num_seqs], - tokens_per_block, self.num_sms) + self.kv_cache_manager.tokens_per_block, self.num_sms) self.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer, non_blocking=True) if self.use_expanded_buffers_for_mtp: @@ -827,7 +826,6 @@ def on_update_kv_lens(self): def update_for_spec_dec(self): super().update_for_spec_dec() - self.kv_cache_manager.indexer_k_cache_tokens_per_block # host self.max_ctx_kv_len = 0 self.num_ctx_cached_tokens = 0 @@ -1030,7 +1028,7 @@ def prepare(metadata: DSAtrtllmAttentionMetadata): request_ids = metadata.request_ids seq_lens = metadata.seq_lens head_dim = metadata.kv_cache_manager.index_head_dim - tokens_per_block = metadata.kv_cache_manager.indexer_k_cache_tokens_per_block + tokens_per_block = metadata.kv_cache_manager.tokens_per_block quant_block_size = metadata.kv_cache_manager.quant_block_size cached_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq total_tokens = seq_lens.sum().item() @@ -1750,9 +1748,6 @@ def __init__( ) -> None: self.quant_block_size = 128 self.index_head_dim = sparse_attn_config.index_head_dim - # Use a fixed tokens_per_block for indexer k cache due to DG kernel constraints - self.indexer_k_cache_tokens_per_block = 64 - assert self.indexer_k_cache_tokens_per_block == tokens_per_block, "tokens_per_block must be set to 64 for DeepSeek v3.2" super().__init__( kv_cache_config, @@ -1778,7 +1773,7 @@ def __init__( self.num_blocks = self.blocks_in_primary_pool # Indexer K cache pool for DSA attention - # Shape: [num_blocks, self.indexer_k_cache_tokens_per_block * (index_head_dim + scale_size)] + # Shape: [num_blocks, self.tokens_per_block * (index_head_dim + scale_size)] # Non-interleaved layout: [fp8_tok0 | fp8_tok1 | ... | scale_tok0 | scale_tok1 | ...] 
# Store FP8-quantized k values from the indexer self.indexer_k_cache_pool_per_layer = [ @@ -1805,9 +1800,7 @@ def get_cache_size_per_token(model_config: ModelConfig, mapping: Mapping, config = model_config.pretrained_config sparse_attn_config = model_config.sparse_attention_config index_head_dim = sparse_attn_config.index_head_dim - tokens_per_block = kwargs['tokens_per_block'] quant_block_size = 128 - indexer_k_cache_tokens_per_block = 64 # get kv cache dtype bytes mem_per_token = 2 @@ -1827,8 +1820,7 @@ def get_cache_size_per_token(model_config: ModelConfig, mapping: Mapping, # 1 for K, others for indexer K cache head_dim_factor = (index_head_dim + index_head_dim // quant_block_size * 4) / head_dim - tokens_per_block_factor = indexer_k_cache_tokens_per_block / tokens_per_block - kv_factor = 1 + head_dim_factor * tokens_per_block_factor + kv_factor = 1 + head_dim_factor mem_per_token *= kv_factor return mem_per_token @@ -1836,8 +1828,7 @@ def get_cache_bytes_per_token(self): # self.kv_factor for K, others for indexer K cache head_dim_factor = (self.index_head_dim + self.index_head_dim // self.quant_block_size * 4) / self.head_dim - tokens_per_block_factor = self.indexer_k_cache_tokens_per_block / self.tokens_per_block - kv_factor = self.kv_factor + head_dim_factor * tokens_per_block_factor + kv_factor = self.kv_factor + head_dim_factor cache_size_per_token = math.ceil( kv_factor * sum(self.num_kv_heads_per_layer) * self.head_dim) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index ffd39896ac79..9f6b885d3b41 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -23,6 +23,7 @@ from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures from tensorrt_llm._torch.auto_deploy.utils._graph import get_input_embeddings, get_lm_head_weights +from tensorrt_llm._torch.autotuner import AutoTuner from tensorrt_llm._torch.models.modeling_speculative import Eagle3ForCausalLM from tensorrt_llm._torch.pyexecutor._util import ( _create_kv_cache_manager, @@ -1008,6 +1009,10 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer torch.cuda.set_device(rank) port = mpi_dist.broadcast(dist.get_free_port()) # use MPI broadcast to pick a free port dist.initialize_or_skip(rank, world_size, port) + + # Setup AutoTuner with distributed state for allreduce autotuning + AutoTuner.get().setup_distributed_state(dist_mapping, mpi_dist) + # some config assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported" diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index cbcc2926805f..01193e68b790 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -2,6 +2,7 @@ import contextlib import copy import enum +import fcntl import inspect import itertools import json @@ -18,6 +19,7 @@ import tensorrt_llm from tensorrt_llm._torch.distributed import Distributed +from tensorrt_llm._utils import nvtx_range from tensorrt_llm.bindings.internal.runtime import delay_kernel from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping @@ -266,15 +268,11 @@ def autotune(tune_mode: bool = True, cache_path: str = None): tune_required = tune_mode if cache_path is not None: # check if the rank-specific file exists - cache_path_no_ext = os.path.splitext(cache_path)[0] - cache_path_no_ext_rank = cache_path_no_ext + f".rank{rank}.json" # if the 
rank-specific file exists, load it - file_exists = os.path.exists(cache_path_no_ext_rank) - # if the rank-specific file exists, do not enable tuning mode + file_exists = os.path.exists(cache_path) if file_exists: - logger.info( - f"[Autotuner] Loading cache from {cache_path_no_ext_rank}") - autotuner.profiling_cache.load_cache(cache_path_no_ext_rank) + logger.info(f"[Autotuner] Loading cache from {cache_path}") + autotuner.profiling_cache.load_cache(cache_path, rank) # record the old tuning mode old_mode = autotuner.is_tuning_mode @@ -293,8 +291,8 @@ def autotune(tune_mode: bool = True, cache_path: str = None): # save cache if cache_path is not None: - logger.info(f"[Autotuner] Saving cache to {cache_path_no_ext_rank}") - autotuner.profiling_cache.save_cache(cache_path_no_ext_rank) + logger.info(f"[Autotuner] Saving cache to {cache_path}") + autotuner.profiling_cache.save_cache(cache_path, rank) @dataclass @@ -439,7 +437,7 @@ def merge_cache_data(self, cache_data: Dict[Tuple, Tuple]): def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]: return {k: v for k, v in self.cache.items() if k[0] == custom_op} - def save_cache(self, file_path: Union[str, Path]) -> None: + def save_cache(self, file_path: Union[str, Path], rank: int) -> None: """Save the profiling cache to disk in JSON format. Args: @@ -456,9 +454,21 @@ def save_cache(self, file_path: Union[str, Path]) -> None: file_path.parent.mkdir(parents=True, exist_ok=True) try: - serializable_cache = self._serialize_cache_to_json() - with open(file_path, 'w') as f: - json.dump(serializable_cache, f, indent=2, default=str) + serialized_rank_cache_data = self._serialize_cache_data() + with open(file_path, 'a+') as f: + fcntl.flock(f, fcntl.LOCK_EX) + f.seek(0) + content = f.read() + if content.strip(): + current_cache = json.loads(content) + else: + current_cache = { + "metadata": self._serialize_metadata(), + } + f.seek(0) + f.truncate() + current_cache[f"rank_{rank}"] = serialized_rank_cache_data + json.dump(current_cache, f, indent=2, default=str) logger.info( f"[AutoTuner] Successfully saved cache to {file_path} using JSON format" ) @@ -466,7 +476,7 @@ def save_cache(self, file_path: Union[str, Path]) -> None: logger.error(f"[AutoTuner] Failed to save cache with JSON: {e}") raise - def load_cache(self, file_path: Union[str, Path]) -> None: + def load_cache(self, file_path: Union[str, Path], rank: int) -> None: """Load the profiling cache from disk in JSON format. 
Args: @@ -486,8 +496,12 @@ def load_cache(self, file_path: Union[str, Path]) -> None: try: with open(file_path, 'r') as f: - serializable_cache = json.load(f) - self.cache = self._deserialize_cache_from_json(serializable_cache) + fcntl.flock(f, fcntl.LOCK_SH) + current_cache_contents = json.load(f) + self._deserialize_metadata(current_cache_contents["metadata"]) + assert f"rank_{rank}" in current_cache_contents, f"Rank {rank} cache not found in {file_path}" + self.cache = self._deserialize_cache_data( + current_cache_contents[f'rank_{rank}']) logger.info( f"[AutoTuner] Successfully loaded cache from {file_path} using JSON format" ) @@ -495,7 +509,21 @@ def load_cache(self, file_path: Union[str, Path]) -> None: logger.error(f"[AutoTuner] Failed to load cache with JSON: {e}") raise - def _serialize_cache_to_json(self) -> Dict[str, Any]: + def _serialize_metadata(self) -> Dict[str, Any]: + return { + "lib_version": self.lib_version, + "creation_timestamp": self.creation_timestamp, + "device_name": self.device_name, + "device_capability": self.device_capability, + } + + def _deserialize_metadata(self, metadata: Dict[str, Any]) -> None: + self.lib_version = metadata["lib_version"] + self.creation_timestamp = metadata["creation_timestamp"] + self.device_name = metadata["device_name"] + self.device_capability = metadata["device_capability"] + + def _serialize_cache_data(self) -> Dict[str, Any]: """Convert the profiling cache to a JSON-serializable format. Returns: @@ -505,15 +533,7 @@ def _serialize_cache_to_json(self) -> Dict[str, Any]: This method handles the conversion of complex objects to JSON-compatible representations. Some type information may be lost in the conversion. """ - serializable_cache = { - "metadata": { - "lib_version": self.lib_version, - "creation_timestamp": self.creation_timestamp, - "device_name": self.device_name, - "device_capability": self.device_capability, - }, - "cache_data": {}, - } + serializable_cache = {} for key, value in self.cache.items(): # Convert any simple object to string for JSON compatibility @@ -529,7 +549,7 @@ def _serialize_cache_to_json(self) -> Dict[str, Any]: f"[AutoTuner] Could not serialize tactic: {tactic_str} for cache key {key_str} due to {e}. Deserialization may fail.", key=tactic_str) - serializable_cache["cache_data"][key_str] = { + serializable_cache[key_str] = { "runner_id": runner_id, "tactic": tactic_str, "min_time": min_time, @@ -537,8 +557,8 @@ def _serialize_cache_to_json(self) -> Dict[str, Any]: return serializable_cache - def _deserialize_cache_from_json( - self, serializable_cache: Dict[str, Any]) -> Dict[Tuple, Tuple]: + def _deserialize_cache_data( + self, cache_data: Dict[str, Any]) -> Dict[Tuple, Tuple]: """Convert JSON-serialized cache back to the original format. Args: @@ -551,14 +571,7 @@ def _deserialize_cache_from_json( This attempts to reconstruct the original data structures but may not perfectly preserve all type information, especially for complex tactic objects. 
""" - metadata = serializable_cache["metadata"] - self.lib_version = metadata["lib_version"] - self.creation_timestamp = metadata["creation_timestamp"] - self.device_name = metadata["device_name"] - self.device_capability = metadata["device_capability"] - cache = {} - cache_data = serializable_cache["cache_data"] for key_str, value in cache_data.items(): # Reconstruct the tuple key safely @@ -844,8 +857,10 @@ def choose_one( custom_op, runners, p.get_opt_shapes(), tuning_config) if not is_cache_hit: # Initialize runner and tactic as None in case of no valid tactic or runners are found - best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners( - custom_op, runners, tensors, p, tuning_config, **kwargs) + with nvtx_range(f"{custom_op}, shape {p.get_opt_shapes()}"): + best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners( + custom_op, runners, tensors, p, tuning_config, + **kwargs) new_tuning_failure_occurred = new_tuning_failure_occurred or has_tuning_failure_occurred self._maybe_sync_cache_data(tuning_config.distributed_tuning_strategy, @@ -882,6 +897,13 @@ def _profile_runners( tuning_config: TuningConfig, **kwargs, ) -> float: + """Profile runners and select the best tactic. + + For multi-rank profiling, only rank 0 performs the actual profiling + to avoid sync issues when different ranks select different tactics. + The results are then broadcasted to all other ranks. + """ + min_time = float('inf') has_tuning_failure_occurred = False best_runner_id, best_tactic = None, None @@ -909,14 +931,15 @@ def _profile_runners( for tac in valid_tactics: try: - time_measured = self._profile_single_kernel( - runner=runner, - inputs=input_tensors, - tactic=tac, - tuning_config=tuning_config, - use_cuda_graph=tuning_config.use_cuda_graph, - **kwargs, - ) + with nvtx_range(f"r{runner_id}, tactic {tac}"): + time_measured = self._profile_single_kernel( + runner=runner, + inputs=input_tensors, + tactic=tac, + tuning_config=tuning_config, + use_cuda_graph=tuning_config.use_cuda_graph, + **kwargs, + ) except Exception as e: # Handle None tensors for optional inputs shapes = self._get_input_sizes(input_tensors) @@ -1026,10 +1049,13 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int): ) stream.synchronize() + if tuning_config.distributed_tuning_strategy == DistributedTuningStrategy.MERGE: + # Currently only AllReduce will use this strategy, and only MPI parallel will enable tuning. + # TODO: Unified tp barrier for both MPIDist and TorchDist. + if hasattr(self._dist, "tp_comm"): + self._dist.tp_comm.barrier() # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling. - # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops) - # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity. 
if use_cuda_graph: delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream) else: @@ -1052,6 +1078,7 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int): return start.elapsed_time(end) / repeat + # warm up, no timing for _ in range(self.warmup): runner(input_tensor_batches[-1], tactic=tactic, **kwargs) diff --git a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py index e5dd01638460..55e79f72a119 100644 --- a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py +++ b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py @@ -1,5 +1,5 @@ from operator import getitem -from typing import List, Optional +from typing import Callable, List, Optional import torch from torch._inductor.pattern_matcher import (MULTIPLE, CallFunction, Ignored, @@ -14,13 +14,13 @@ from tensorrt_llm.mapping import Mapping -def register_ar_residual_norm(custom_pass: PatternMatcherPass, - mapping: Mapping): +def register_ar_residual_norm(custom_pass: PatternMatcherPass, mapping: Mapping, + allreduce_func: Callable): residual_key = KeywordArg("residual") trtllm_allreduce_default = CallFunction( - torch.ops.trtllm.allreduce.default, KeywordArg("input"), None, None, - None, None, KeywordArg("workspace"), mapping.tp_group, - KeywordArg("strategy"), int(AllReduceFusionOp.NONE), Ignored(), + allreduce_func.default, KeywordArg("input"), None, None, None, None, + KeywordArg("workspace"), mapping.tp_group, KeywordArg("strategy"), + int(AllReduceFusionOp.NONE), Ignored(), KeywordArg("trigger_completion_at_end")) getitem_x = CallFunction(getitem, trtllm_allreduce_default, 0) add_Tensor = CallFunction(aten.add.Tensor, @@ -56,7 +56,7 @@ def target_pattern( eps: float, trigger_completion_at_end: bool, ): - all_reduce_output = torch.ops.trtllm.allreduce( + all_reduce_output = allreduce_func( input, residual, norm_weight, None, None, workspace, mapping.tp_group, int(strategy), int(AllReduceFusionOp.RESIDUAL_RMS_NORM), float(eps), @@ -111,10 +111,11 @@ def check_non_ub_strategy(match, strategy_node) -> bool: def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass, - mapping: Mapping): + mapping: Mapping, + allreduce_func: Callable): input_node = KeywordArg("input") strategy_node = KeywordArg("strategy") - allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default, + allreduce_default = CallFunction(allreduce_func.default, input_node, KeywordArg("residual"), KeywordArg("gamma"), @@ -165,7 +166,7 @@ def target_pattern( scale: torch.Tensor, trigger_completion_at_end: bool, ): - allreduce = torch.ops.trtllm.allreduce( + allreduce = allreduce_func( input, residual, gamma, scale, None, workspace, mapping.tp_group, int(strategy), int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8), float(eps), @@ -188,10 +189,11 @@ def extra_check(match: Match) -> bool: def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass, - mapping: Mapping): + mapping: Mapping, + allreduce_func: Callable): input_node = KeywordArg("input") strategy_node = KeywordArg("strategy") - allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default, + allreduce_default = CallFunction(allreduce_func.default, input_node, KeywordArg("residual"), KeywordArg("gamma"), @@ -242,7 +244,7 @@ def target_pattern( scale: torch.Tensor, trigger_completion_at_end: bool, ): - allreduce = torch.ops.trtllm.allreduce( + allreduce = allreduce_func( input, residual, gamma, scale, None, workspace, mapping.tp_group, int(strategy), 
int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8), float(eps), trigger_completion_at_end) @@ -264,10 +266,11 @@ def extra_check(match: Match) -> bool: def register_ar_residual_norm_out_fp4_quant(custom_pass: PatternMatcherPass, - mapping: Mapping): + mapping: Mapping, + allreduce_func: Callable): input_node = KeywordArg("input") strategy_node = KeywordArg("strategy") - allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default, + allreduce_default = CallFunction(allreduce_func.default, input_node, KeywordArg("residual"), KeywordArg("gamma"), @@ -313,7 +316,7 @@ def target_pattern( scale: torch.Tensor, trigger_completion_at_end: bool, ): - allreduce = torch.ops.trtllm.allreduce( + allreduce = allreduce_func( input, residual, gamma, scale, None, workspace, mapping.tp_group, int(strategy), int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4), @@ -336,10 +339,11 @@ def extra_check(match: Match) -> bool: def register_ar_residual_norm_fp4_quant(custom_pass: PatternMatcherPass, - mapping: Mapping): + mapping: Mapping, + allreduce_func: Callable): input_node = KeywordArg("input") strategy_node = KeywordArg("strategy") - allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default, + allreduce_default = CallFunction(allreduce_func.default, input_node, KeywordArg("residual"), KeywordArg("gamma"), @@ -385,7 +389,7 @@ def target_pattern( scale: torch.Tensor, trigger_completion_at_end: bool, ): - allreduce = torch.ops.trtllm.allreduce( + allreduce = allreduce_func( input, residual, gamma, scale, None, workspace, mapping.tp_group, int(strategy), int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4), float(eps), trigger_completion_at_end) @@ -407,17 +411,20 @@ def extra_check(match: Match) -> bool: def register_ub_patterns(custom_passes: List[PatternMatcherPass], - mapping: Mapping): + mapping: Mapping, allreduce_func: Callable): def register_convert_supported_ar_to_ub(custom_pass: PatternMatcherPass): strategy = int(AllReduceStrategy.AUTO) input_node = KeywordArg('input') fusion = KeywordArg('fusion_op') - trtllm_allreduce_default = CallFunction( - torch.ops.trtllm.allreduce.default, input_node, - KeywordArg('residual_in'), KeywordArg('gamma'), KeywordArg('scale'), - None, Ignored(), mapping.tp_group, strategy, fusion, - KeywordArg('eps'), Ignored()) + trtllm_allreduce_default = CallFunction(allreduce_func.default, + input_node, + KeywordArg('residual_in'), + KeywordArg('gamma'), + KeywordArg('scale'), None, + Ignored(), mapping.tp_group, + strategy, fusion, + KeywordArg('eps'), Ignored()) def empty_convert_supported_ar_to_ub( input: torch.Tensor, @@ -667,7 +674,7 @@ def register_ub_finalize_patterns(custom_pass: PatternMatcherPass): torch.ops.trtllm.userbuffers_allreduce_finalize.default, KeywordArg("sharded_residual"), False) trtllm_allreduce_default = CallFunction( - torch.ops.trtllm.allreduce.default, KeywordArg("input"), + torch.ops.trtllm.allreduce, KeywordArg("input"), trtllm_userbuffers_allreduce_finalize_default, KeywordArg("gamma"), KeywordArg("scale"), Ignored(), Ignored(), mapping.tp_group, int(AllReduceStrategy.UB), KeywordArg("fusion_op"), @@ -718,15 +725,28 @@ def target_finalize_pattern( def register_ar_fusions(custom_passes: List[PatternMatcherPass], mapping: Mapping, enable_ub: bool): - register_ar_residual_norm(custom_passes[-1], mapping) + register_ar_residual_norm(custom_passes[-1], mapping, + torch.ops.trtllm.allreduce) + register_ar_residual_norm(custom_passes[-1], mapping, + torch.ops.trtllm.tunable_allreduce) custom_passes.append(PatternMatcherPass()) - 
register_ar_residual_norm_fp8_quant(custom_passes[-1], mapping) - register_ar_residual_norm_fp4_quant(custom_passes[-1], mapping) - # AR-Residual-Norm-Out-Quant-X is not supported by Userbuffers kernel. - if not enable_ub: - register_ar_residual_norm_out_fp8_quant(custom_passes[-1], mapping) - register_ar_residual_norm_out_fp4_quant(custom_passes[-1], mapping) + for allreduce_func in [ + torch.ops.trtllm.allreduce, torch.ops.trtllm.tunable_allreduce + ]: + register_ar_residual_norm_fp8_quant(custom_passes[-1], mapping, + allreduce_func) + register_ar_residual_norm_fp4_quant(custom_passes[-1], mapping, + allreduce_func) + + # AR-Residual-Norm-Out-Quant-X is not supported by Userbuffers kernel. + if not enable_ub: + register_ar_residual_norm_out_fp8_quant(custom_passes[-1], mapping, + allreduce_func) + register_ar_residual_norm_out_fp4_quant(custom_passes[-1], mapping, + allreduce_func) if enable_ub: - register_ub_patterns(custom_passes, mapping) + register_ub_patterns(custom_passes, mapping, torch.ops.trtllm.allreduce) + register_ub_patterns(custom_passes, mapping, + torch.ops.trtllm.tunable_allreduce) diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 89fb0d05132f..76368412c42e 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -15,18 +15,18 @@ def _register_fake(): @torch.library.register_fake("trtllm::allreduce") def allreduce( - input, - residual, - norm_weight, - scale, - bias, - workspace, - group, - strategy, - op, - eps, - trigger_completion_at_end, - ): + input: torch.Tensor, + residual: Optional[torch.Tensor], + norm_weight: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + workspace: Optional[torch.Tensor], + group: List[int], + strategy: int, + op: int, + eps: float, + trigger_completion_at_end: bool, + ) -> List[torch.Tensor]: from tensorrt_llm.functional import AllReduceFusionOp if op == int(AllReduceFusionOp.NONE): return [torch.empty_like(input)] @@ -61,19 +61,19 @@ def allreduce( @torch.library.register_fake("trtllm::allreduce_pg") def _( - input, - residual, - norm_weight, - scale, - bias, - workspace, - group, - rank, + input: torch.Tensor, + residual: Optional[torch.Tensor], + norm_weight: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + workspace: Optional[torch.Tensor], + group: List[int], + rank: int, pg, - strategy, - op, - eps, - trigger_completion_at_end, + strategy: int, + op: int, + eps: float, + trigger_completion_at_end: bool, ): return allreduce(input, residual, norm_weight, scale, bias, workspace, group, strategy, op, eps, trigger_completion_at_end) diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index b26b687ced02..2ee8d29ccca9 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -8,7 +8,9 @@ import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils from tensorrt_llm import deep_gemm from tensorrt_llm._utils import get_sm_version +from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy from tensorrt_llm.logger import logger +from tensorrt_llm.plugin.plugin import CustomAllReduceHelper from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy, DynamicTensorSpec, OptimizationProfile, TunableRunner, @@ -1652,6 +1654,173 @@ def _( return x.new_empty((b, d), 
dtype=o_dtype) +class AllReduceRunner(TunableRunner): + tuning_config = TuningConfig( + dynamic_tensor_specs=(DynamicTensorSpec( + 0, 0, get_last_power_of_2_num_tokens_buckets(8192), + last_positive_power_of_2), ), + constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ), + distributed_tuning_strategy=DistributedTuningStrategy.MERGE, + ) + + def __init__( + self, + tp_size: int, + group: List[int], + op: int, + eps: float, + trigger_completion_at_end: bool, + ): + self.tp_size = tp_size + self.op = op + self.group = group + self.eps = eps + self.trigger_completion_at_end = trigger_completion_at_end + + def unique_id(self): + return ( + self.tp_size, + self.op, + ) + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + **kwargs, + ) -> List[int]: + valid_strategies = [ + # TODO: NCCL_SYMMETRIC will cause hang during tuning process + # AllReduceStrategy.NCCL_SYMMETRIC.value, + AllReduceStrategy.NCCL.value, + ] + # Fallback in allreduceOp is set to NCCL_SYMMETRIC as default + # So we need to check if the workspace size is too large to avoid hanging. + workspace_size = inputs[0].numel() * inputs[0].element_size() + max_workspace_size = CustomAllReduceHelper.max_workspace_size_auto( + self.tp_size, + support_deterministic=False, + ) + if workspace_size > max_workspace_size: + return valid_strategies + + valid_strategies.append(AllReduceStrategy.ONESHOT.value) + + # Additional restrictions for TWOSHOT strategy + if inputs[0].shape[0] >= self.tp_size: + valid_strategies.append(AllReduceStrategy.TWOSHOT.value) + + return valid_strategies + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + ) -> torch.Tensor: + input, residual, norm_weight, scale, bias, workspace = inputs + if tactic == -1: + # TODO: Use NCCL instead of NCCL_SYMMETRIC to avoid hanging during tuning process + tactic = AllReduceStrategy.NCCL.value + + return torch.ops.trtllm.allreduce( + input, + residual, + norm_weight, + scale, + bias, + workspace, + self.group, + tactic, + self.op, + self.eps, + self.trigger_completion_at_end, + ) + + +@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=()) +def tunable_allreduce( + input: torch.Tensor, + residual: Optional[torch.Tensor], + norm_weight: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + workspace: Optional[torch.Tensor], + group: List[int], + strategy: int, + op: int, + eps: float, + trigger_completion_at_end: bool, +) -> List[torch.Tensor]: + + tuner = AutoTuner.get() + + allreduce_runner = AllReduceRunner( + len(group), + group, + op, + eps, + trigger_completion_at_end, + ) + + _, best_tactic = tuner.choose_one( + "trtllm::tunable_allreduce::allreduce", + [allreduce_runner], + AllReduceRunner.tuning_config, + [input, residual, norm_weight, scale, bias, workspace], + ) + + return allreduce_runner( + [input, residual, norm_weight, scale, bias, workspace], + tactic=best_tactic, + ) + + +@tunable_allreduce.register_fake +def _( + input: torch.Tensor, + residual: Optional[torch.Tensor], + norm_weight: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + workspace: Optional[torch.Tensor], + group: List[int], + strategy: int, + op: int, + eps: float, + trigger_completion_at_end: bool, +) -> List[torch.Tensor]: + if op == int(AllReduceFusionOp.NONE): + return [torch.empty_like(input)] + elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM): + norm_out = torch.empty_like(input) + residual_out = 
torch.empty_like(input) + return [norm_out, residual_out] + elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8): + quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn) + residual_out = torch.empty_like(input) + return [quant_out, residual_out] + elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8): + norm_out = torch.empty_like(input) + quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn) + residual_out = torch.empty_like(input) + return [norm_out, quant_out, residual_out] + elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4): + fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16) + quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8) + scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8) + residual_out = torch.empty_like(input) + return [quant_fp4, scale_fp4, residual_out] + elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4): + fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16) + quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8) + scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8) + norm_out = torch.empty_like(input) + residual_out = torch.empty_like(input) + return [norm_out, quant_fp4, scale_fp4, residual_out] + else: + return [torch.empty_like(input)] + + def get_event(event_idx: int): from ..utils import get_model_extra_attrs extra_attrs = get_model_extra_attrs() diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 51de18c7b165..713d728566b8 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -692,7 +692,6 @@ def __init__(self, self._disable_mpi = mpi_disabled() self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce - if self.mapping.tp_size > 1: # Initialize Symmetric Memory AllReduce if needed (before workspace allocation) if self.strategy == AllReduceStrategy.SYMM_MEM: @@ -788,6 +787,7 @@ def forward( input = input.contiguous() # Underlying op requires contiguous input allreduce_strategy = self.strategy + if all_reduce_params is None: all_reduce_params = AllReduceParams() @@ -831,21 +831,42 @@ def forward( "pg": pg.boxed(), } - output = self.all_reduce_op( - input=input, - residual=all_reduce_params.residual, - norm_weight=all_reduce_params.norm_weight, - scale=all_reduce_params.scale, - bias=all_reduce_params.bias, - workspace=self.workspace, - group=self.mapping.tp_group, - strategy=allreduce_strategy, - op=all_reduce_params.fusion_op, - eps=all_reduce_params.eps, - trigger_completion_at_end=all_reduce_params. - trigger_completion_at_end, - **additional_args, - ) + # In case that AutoTuner brings potential perf regression + # TODO: Remove this if no perf regression is observed. + disable_allreduce_autotune = os.environ.get( + "TLLM_DISABLE_ALLREDUCE_AUTOTUNE", "0") == "1" + + if allreduce_strategy == AllReduceStrategy.AUTO and not disable_allreduce_autotune and not self._disable_mpi: + output = torch.ops.trtllm.tunable_allreduce( + input=input, + residual=all_reduce_params.residual, + norm_weight=all_reduce_params.norm_weight, + scale=all_reduce_params.scale, + bias=all_reduce_params.bias, + workspace=self.workspace, + group=self.mapping.tp_group, + strategy=allreduce_strategy, + op=all_reduce_params.fusion_op, + eps=all_reduce_params.eps, + trigger_completion_at_end=all_reduce_params. 
+ trigger_completion_at_end, + ) + else: + output = self.all_reduce_op( + input=input, + residual=all_reduce_params.residual, + norm_weight=all_reduce_params.norm_weight, + scale=all_reduce_params.scale, + bias=all_reduce_params.bias, + workspace=self.workspace, + group=self.mapping.tp_group, + strategy=allreduce_strategy, + op=all_reduce_params.fusion_op, + eps=all_reduce_params.eps, + trigger_completion_at_end=all_reduce_params. + trigger_completion_at_end, + **additional_args, + ) return output if len(output) > 1 else output[0] diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 5397ba9d575d..459034ec0ec5 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -161,7 +161,7 @@ def get_all_reduce_strategy(strategy: str = "AUTO"): "TWOSHOT": AllReduceStrategy.TWOSHOT, "LOWPRECISION": AllReduceStrategy.LOWPRECISION, "MNNVL": AllReduceStrategy.MNNVL, - "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC + "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC, } key = strategy.upper() return maps[key] if key in maps else AllReduceStrategy.AUTO diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index c09abcb1da43..464a446cb375 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -657,7 +657,8 @@ def __init__( eps=config.rms_norm_eps, dtype=config.torch_dtype) - self.all_reduce = AllReduce(mapping=model_config.mapping) + self.all_reduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy) self.next_layer_layernorm: RMSNorm = None self.next_attn: LlamaAttention = None diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 8fc4a473d8c4..5bf435c2ff8c 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1143,6 +1143,11 @@ def _set_up_spec_metadata( return self.spec_metadata def __del__(self) -> None: + self.model = None + self.model_loader = None + self._release_cuda_graphs() + self.input_processor = None + self.input_processor_with_hash = None if getattr(self, 'ub_buffers', None): for u in self.ub_buffers: ub.ub_deallocate(u.addr) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index ebde6164b735..bd1857dda27e 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -221,7 +221,7 @@ def create_py_executor( tokenizer: Optional[TokenizerBase] = None, profiling_stage_data: Optional[dict] = None, ) -> PyExecutor: - + torch.cuda.set_per_process_memory_fraction(1.0) garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold lora_config = llm_args.lora_config kv_connector_config = llm_args.kv_connector_config diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 0dd86b66ac3d..ff189e3be91b 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -7,7 +7,7 @@ import subprocess # nosec B404 import sys from pathlib import Path -from typing import Any, Dict, Mapping, Optional, Sequence +from typing import Any, Dict, Literal, Mapping, Optional, Sequence import click import torch @@ -42,6 +42,13 @@ _child_p_global: Optional[subprocess.Popen] = None +def help_info_with_stability_tag( + help_str: str, tag: Literal["stable", "beta", "prototype", + 
"deprecated"]) -> str: + """Append stability info to help string.""" + return f":tag:`{tag}` {help_str}" + + def _signal_handler_cleanup_child(signum, frame): """Signal handler to clean up the child process.""" global _child_p_global @@ -279,27 +286,33 @@ def convert(self, value: Any, param: Optional["click.Parameter"], @click.option("--tokenizer", type=str, default=None, - help="Path | Name of the tokenizer." - "Specify this value only if using TensorRT engine as model.") + help=help_info_with_stability_tag("Path | Name of the tokenizer.", + "beta")) @click.option( "--custom_tokenizer", type=str, default=None, - help= - "Custom tokenizer type: alias (e.g., 'deepseek_v32') or Python import path " - "(e.g., 'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer'). [Experimental]" -) + help=help_info_with_stability_tag( + "Custom tokenizer type: alias (e.g., 'deepseek_v32') or Python import path " + "(e.g., 'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer').", + "prototype")) @click.option("--host", type=str, default="localhost", - help="Hostname of the server.") -@click.option("--port", type=int, default=8000, help="Port of the server.") + help=help_info_with_stability_tag("Hostname of the server.", + "beta")) +@click.option("--port", + type=int, + default=8000, + help=help_info_with_stability_tag("Port of the server.", "beta")) @click.option( "--backend", type=ChoiceWithAlias(["pytorch", "tensorrt", "_autodeploy"], {"trt": "tensorrt"}), default="pytorch", - help="The backend to use to serve the model. Default is pytorch backend.") + help=help_info_with_stability_tag( + "The backend to use to serve the model. Default is pytorch backend.", + "beta")) @click.option( "--custom_module_dirs", type=click.Path(exists=True, @@ -308,143 +321,170 @@ def convert(self, value: Any, param: Optional["click.Parameter"], resolve_path=True), default=None, multiple=True, - help="Paths to custom module directories to import.", + help=help_info_with_stability_tag( + "Paths to custom module directories to import.", "prototype"), ) @click.option('--log_level', type=click.Choice(severity_map.keys()), default='info', - help="The logging level.") + help=help_info_with_stability_tag("The logging level.", "beta")) @click.option("--max_beam_width", type=int, default=BuildConfig.model_fields["max_beam_width"].default, - help="Maximum number of beams for beam search decoding.") + help=help_info_with_stability_tag( + "Maximum number of beams for beam search decoding.", "beta")) @click.option("--max_batch_size", type=int, default=BuildConfig.model_fields["max_batch_size"].default, - help="Maximum number of requests that the engine can schedule.") + help=help_info_with_stability_tag( + "Maximum number of requests that the engine can schedule.", + "beta")) @click.option( "--max_num_tokens", type=int, default=BuildConfig.model_fields["max_num_tokens"].default, - help= - "Maximum number of batched input tokens after padding is removed in each batch." -) + help=help_info_with_stability_tag( + "Maximum number of batched input tokens after padding is removed in each batch.", + "beta")) @click.option( "--max_seq_len", type=int, default=BuildConfig.model_fields["max_seq_len"].default, - help="Maximum total length of one request, including prompt and outputs. " - "If unspecified, the value is deduced from the model config.") + help=help_info_with_stability_tag( + "Maximum total length of one request, including prompt and outputs. 
" + "If unspecified, the value is deduced from the model config.", "beta")) @click.option("--tensor_parallel_size", "--tp_size", type=int, default=1, - help='Tensor parallelism size.') + help=help_info_with_stability_tag('Tensor parallelism size.', + 'beta')) @click.option("--pipeline_parallel_size", "--pp_size", type=int, default=1, - help='Pipeline parallelism size.') + help=help_info_with_stability_tag('Pipeline parallelism size.', + 'beta')) @click.option("--context_parallel_size", "--cp_size", type=int, default=1, - help='Context parallelism size.') + help=help_info_with_stability_tag('Context parallelism size.', + 'beta')) @click.option("--moe_expert_parallel_size", "--ep_size", type=int, default=None, - help="expert parallelism size") + help=help_info_with_stability_tag("expert parallelism size", + "beta")) @click.option("--moe_cluster_parallel_size", "--cluster_size", type=int, default=None, - help="expert cluster parallelism size") -@click.option("--gpus_per_node", - type=int, - default=None, - help="Number of GPUs per node. Default to None, and it will be " - "detected automatically.") + help=help_info_with_stability_tag( + "expert cluster parallelism size", "beta")) +@click.option( + "--gpus_per_node", + type=int, + default=None, + help=help_info_with_stability_tag( + "Number of GPUs per node. Default to None, and it will be detected automatically.", + "beta")) @click.option("--free_gpu_memory_fraction", "--kv_cache_free_gpu_memory_fraction", type=float, default=0.9, - help="Free GPU memory fraction reserved for KV Cache, " - "after allocating model weights and buffers.") -@click.option( - "--num_postprocess_workers", - type=int, - default=0, - help="[Experimental] Number of workers to postprocess raw responses " - "to comply with OpenAI protocol.") + help=help_info_with_stability_tag( + "Free GPU memory fraction reserved for KV Cache, " + "after allocating model weights and buffers.", "beta")) +@click.option("--num_postprocess_workers", + type=int, + default=0, + help=help_info_with_stability_tag( + "Number of workers to postprocess raw responses " + "to comply with OpenAI protocol.", "prototype")) @click.option("--trust_remote_code", is_flag=True, default=False, - help="Flag for HF transformers.") + help=help_info_with_stability_tag("Flag for HF transformers.", + "beta")) @click.option("--revision", type=str, default=None, - help="The revision to use for the HuggingFace model " - "(branch name, tag name, or commit id).") + help=help_info_with_stability_tag( + "The revision to use for the HuggingFace model " + "(branch name, tag name, or commit id).", "beta")) @click.option( "--config", "--extra_llm_api_options", "extra_llm_api_options", type=str, default=None, - help= - "Path to a YAML file that overwrites the parameters specified by trtllm-serve. " - "Can be specified as either --config or --extra_llm_api_options.") + help=help_info_with_stability_tag( + "Path to a YAML file that overwrites the parameters specified by trtllm-serve. 
" + "Can be specified as either --config or --extra_llm_api_options.", + "prototype")) @click.option( "--reasoning_parser", type=click.Choice(ReasoningParserFactory.parsers.keys()), default=None, - help="[Experimental] Specify the parser for reasoning models.", + help=help_info_with_stability_tag( + "Specify the parser for reasoning models.", "prototype"), ) @click.option( "--tool_parser", type=click.Choice(ToolParserFactory.parsers.keys()), default=None, - help="[Experimental] Specify the parser for tool models.", + help=help_info_with_stability_tag("Specify the parser for tool models.", + "prototype"), ) @click.option("--metadata_server_config_file", type=str, default=None, - help="Path to metadata server config file") + help=help_info_with_stability_tag( + "Path to metadata server config file", "prototype")) @click.option( "--server_role", type=str, default=None, - help="Server role. Specify this value only if running in disaggregated mode." -) + help=help_info_with_stability_tag( + "Server role. Specify this value only if running in disaggregated mode.", + "prototype")) @click.option( "--fail_fast_on_attention_window_too_large", is_flag=True, default=False, - help= - "Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache." -) + help=help_info_with_stability_tag( + "Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache.", + "prototype")) @click.option("--otlp_traces_endpoint", type=str, default=None, - help="Target URL to which OpenTelemetry traces will be sent.") + help=help_info_with_stability_tag( + "Target URL to which OpenTelemetry traces will be sent.", + "prototype")) @click.option("--disagg_cluster_uri", type=str, default=None, - help="URI of the disaggregated cluster.") + help=help_info_with_stability_tag( + "URI of the disaggregated cluster.", "prototype")) @click.option("--enable_chunked_prefill", is_flag=True, default=False, - help="Enable chunked prefill") + help=help_info_with_stability_tag("Enable chunked prefill", + "prototype")) @click.option("--media_io_kwargs", type=str, default=None, - help="Keyword arguments for media I/O.") + help=help_info_with_stability_tag( + "Keyword arguments for media I/O.", "prototype")) @click.option("--chat_template", type=str, default=None, - help="[Experimental] Specify a custom chat template. " - "Can be a file path or one-liner template string") + help=help_info_with_stability_tag( + "Specify a custom chat template. 
" + "Can be a file path or one-liner template string", + "prototype")) def serve( model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str], host: str, port: int, log_level: str, backend: str, max_beam_width: int, diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 54b958993b7b..8555407c6e99 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -2264,6 +2264,10 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, # task.evaluate(llm, # extra_evaluator_kwargs=dict(apply_chat_template=True)) + import gc + gc.collect() + torch.cuda.empty_cache() + @skip_pre_blackwell @pytest.mark.parametrize( "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend", @@ -2629,14 +2633,12 @@ def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, if get_sm_version() == 100 or get_sm_version() == 103: moe_backend = "DEEPGEMM" if moe_backend == "_DEFAULT" else moe_backend moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384) - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6, - tokens_per_block=64) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6) else: if moe_backend != "_DEFAULT": pytest.skip("Not supported MoE backend!") moe_config = MoeConfig() - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7, - tokens_per_block=64) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7) pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, @@ -2707,8 +2709,7 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, pytest.skip(f"{moe_backend} backend does not support SM 120 or 121") moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384) - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7, - tokens_per_block=64) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7) cuda_graph_config = CudaGraphConfig( enable_padding=True, max_batch_size=max_batch_size) if cuda_graph else None @@ -2771,8 +2772,7 @@ def test_nvfp4_multi_gpus_chunked_prefill(self, tp_size, pp_size, ep_size, pytest.skip(f"{moe_backend} backend does not support SM 120 or 121") moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384) - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7, - tokens_per_block=64) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7) cuda_graph_config = CudaGraphConfig( enable_padding=True, max_batch_size=max_batch_size) if cuda_graph else None diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 8fb4928504a9..318f621025b6 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -15,6 +15,7 @@ l0_dgx_b200: backend: pytorch orchestrator: mpi tests: + - unittest/_torch/misc/test_autotuner.py::test_autotuner_distributed_strategy - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency] - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL] - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_nvfp4[enable_configurable_moe-disable_finalize_fusion-TRTLLM-dtype1] diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml index 
57c3b6fd8106..1174f6066c95 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml @@ -32,10 +32,10 @@ l0_gb200_multi_nodes: backend: pytorch tests: - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] TIMEOUT (180) - - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] TIMEOUT (180) - - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen] TIMEOUT (180) - - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp] TIMEOUT (180) - - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen_adp_lmtp] TIMEOUT (180) + - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] TIMEOUT (180) ISOLATION + - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen] TIMEOUT (180) ISOLATION + - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp] TIMEOUT (180) ISOLATION + - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen_adp_lmtp] TIMEOUT (180) ISOLATION - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cutlass] TIMEOUT (90) - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] TIMEOUT (90) - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_attention_dp] TIMEOUT (90) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 6cecb313529b..7513bcf8998d 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -321,7 +321,6 @@ accuracy/test_llm_api_pytorch.py::TestNemotronH_47B_Base::test_reasoning_fp8_pre accuracy/test_llm_api_pytorch.py::TestQwQ_32B::test_auto_dtype_tp4 SKIP (https://nvbugs/5640697) test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True] SKIP (https://nvbugs/5648560) test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False] SKIP (https://nvbugs/5648560) -accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen_adp_lmtp] SKIP (https://nvbugs/5629136) examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5568052) accuracy/test_llm_api_pytorch_multimodal.py::TestNVILA_8B::test_auto_dtype SKIP (https://nvbugs/5648441) accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype SKIP (https://nvbugs/5648441) @@ -381,29 +380,16 @@ accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8ep accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-cuda_graph=True] SKIP (https://nvbugs/5707145) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto] SKIP (https://nvbugs/5596343) unittest/_torch/speculative/test_spec_gate.py::test_spec_gate_e2e SKIP (https://nvbugs/5710045) -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5769815) -accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp_trtllm] SKIP (https://nvbugs/5715568) -accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] SKIP 
(https://nvbugs/5715568) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5721661) -accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] SKIP (https://nvbugs/5715568) unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py::test_flashinfer_attention_op_context_input_pos[cuda-dtype0-4-8-seq6] SKIP (https://nvbugs/5721907) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5722629) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_2gpus[cutlass-two_model-overlap_scheduler] SKIP (https://nvbugs/5702826) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-two_model-overlap_scheduler] SKIP (https://nvbugs/5702826) unittest/llmapi/test_llm_pytorch.py::test_llm_reward_model SKIP (https://nvbugs/5670458) accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cutlass] SKIP (https://nvbugs/5740377) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5740377) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5740377) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5740377) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5740377) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5740377) accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=False] SKIP (https://nvbugs/5740377) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5740087) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5740087) accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2] SKIP (https://nvbugs/5740075) accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2] SKIP (https://nvbugs/5740075) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5740075) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP 
(https://nvbugs/5740075) unittest/_torch/modeling/test_modeling_out_of_tree.py::TestOutOfTree::test_llm_api[False] SKIP (https://nvbugs/5739981) unittest/_torch/modeling/test_modeling_out_of_tree.py::TestOutOfTree::test_llm_api[True] SKIP (https://nvbugs/5739981) unittest/_torch/modeling/test_modeling_out_of_tree.py::TestOutOfTree::test_serve[True] SKIP (https://nvbugs/5739981) @@ -411,24 +397,14 @@ full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5596337) accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] SKIP (https://nvbugs/5721672) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True] SKIP (https://nvbugs/5741304) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5740377, https://nvbugs/5740075) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5740087) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5740075) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5740075) unittest/_torch/multi_gpu/test_allreduce.py::test_allreduce_fusion_patterns[2-residual_rms_norm_out_quant_fp8-hidden:7168-seqlen:8192] SKIP (https://nvbugs/5741392) unittest/executor/test_rpc.py::TestRpcCorrectness::test_incremental_task_async SKIP (https://nvbugs/5741476) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5740377) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5740377) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5740377) accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False] SKIP (https://nvbugs/5740377) examples/test_phi.py::test_phi_fp8_with_bf16_lora[phi-2] SKIP (https://nvbugs/5744293) examples/test_phi.py::test_llm_phi_1node_2gpus_summary[Phi-3.5-MoE-instruct-nb:1] SKIP (https://nvbugs/5744293) examples/test_phi.py::test_llm_phi_quantization_1gpu[phi-2-fp8-bfloat16] SKIP (https://nvbugs/5744293) test_e2e.py::test_trtllm_bench_llmapi_launch[pytorch_backend-llama-v3-llama3-8b] SKIP (https://nvbugs/5744432) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5740087) 
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5740075) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5740075) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5740075) test_e2e.py::test_trtllm_serve_multimodal_example SKIP (https://nvbugs/5747920) test_e2e.py::test_trtllm_serve_example SKIP (https://nvbugs/5747938) unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py::test_build_ad[meta-llama/Llama-4-Scout-17B-16E-Instruct-llm_extra_args8] SKIP (https://nvbugs/5747878) @@ -445,7 +421,6 @@ accuracy/test_cli_flow.py::TestPhi3Small128kInstruct::test_auto_dtype SKIP (http accuracy/test_cli_flow.py::TestPhi3_5MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5744293) unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py::test_build_run_llama4_vlm SKIP (https://nvbugs/5747878) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=TRTLLM] SKIP (https://nvbugs/5740377) -unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py::test_flashinfer_fused_moe_matches_torch_moe SKIP (https://nvbugs/5752521) cpp/test_multi_gpu.py::TestDisagg::test_symmetric_executor[gpt-2proc-mpi_kvcache-90] SKIP (https://nvbugs/5755941) accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0] SKIP (https://nvbugs/5748600) examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-BertForQuestionAnswering-bert/bert-base-cased-squad2] SKIP (https://nvbugs/5608979) @@ -485,7 +460,6 @@ unittest/_torch/sampler/test_return_logits.py SKIP (https://nvbugs/5764627) examples/serve/test_serve.py::test_config_file_loading[--config] SKIP (https://nvbugs/5754977) full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5740075) examples/test_ray.py::test_ray_disaggregated_serving[tp2] SKIP (https://nvbugs/5612502) -accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen] SKIP (https://nvbugs/5715568) unittest/executor/test_rpc_proxy.py SKIP (https://nvbugs/5605741) unittest/executor/test_rpc_worker.py SKIP (https://nvbugs/5605741) triton_server/test_triton.py::test_gpt_gather_logits[gpt-gather-logits] SKIP (https://nvbugs/5766960) @@ -540,3 +514,7 @@ disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backen disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf] SKIP (https://nvbugs/5769890,https://nvbugs/5748683) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_pp4_mtp] SKIP (https://nvbugs/5779536) perf/test_perf_sanity.py::test_e2e[disagg_upload-deepseek-r1-fp4_1k1k_ctx1_gen1_dep8_bs768_eplb0_mtp0_ccb-UCX] SKIP (https://nvbugs/5778381) 
+unittest/_torch/attention/test_flashinfer_star_attn.py::TestStarAttention::test_flashinfer_star_attention[num_layers:2-num_heads:32-num_kv_heads:8-head_dim:64-anchor_size:64-block_size:64-dtype:torch.float16] SKIP (https://nvbugs/5781389) +unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py::test_reducescatter_pg_op[var_len:True-seqlen:16-hidden:128] SKIP (https://nvbugs/5781383) +cpp/test_e2e.py::test_model[-mamba-86] SKIP (https://nvbugs/5781665) +unittest/llmapi/test_llm_multi_gpu_pytorch.py::test_tinyllama_logits_processor_tp2pp2 SKIP (https://nvbugs/5781731) diff --git a/tests/microbenchmarks/all_reduce.py b/tests/microbenchmarks/all_reduce.py index bd5ceb8826b2..d2a9adf453cb 100644 --- a/tests/microbenchmarks/all_reduce.py +++ b/tests/microbenchmarks/all_reduce.py @@ -27,17 +27,21 @@ import tensorrt_llm as tllm from tensorrt_llm import Mapping -from tensorrt_llm._torch.distributed import AllReduce, AllReduceFusionOp +from tensorrt_llm._torch.autotuner import AutoTuner, autotune +from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp, + MPIDist, TorchDist) from tensorrt_llm._torch.modules.rms_norm import RMSNorm from tensorrt_llm._utils import (get_sm_version, local_mpi_rank, local_mpi_size, - nvtx_range) + mpi_disabled, nvtx_range) from tensorrt_llm.bindings.internal.runtime import delay_kernel from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy +from tensorrt_llm.logger import logger from tensorrt_llm.plugin.plugin import CustomAllReduceHelper def profile_allreduce( mapping: Mapping, + dist: TorchDist | MPIDist, enable_cudagraph: bool = False, inner_loop=200, outer_loop=10, @@ -49,7 +53,6 @@ def profile_allreduce( scale=None, bias=None, ): - tllm.logger.set_level('error') allreduce_params = AllReduceParams( fusion_op=fusion, @@ -62,8 +65,8 @@ def profile_allreduce( allreduce = AllReduce(mapping=mapping, strategy=strategy) - def func(x): - for _ in range(inner_loop): + def func(x, loop_num=inner_loop): + for _ in range(loop_num): output = allreduce(x, all_reduce_params=allreduce_params) return output if fusion == AllReduceFusionOp.NONE else output[0] @@ -75,16 +78,17 @@ def func(x): with torch.cuda.stream(stream), nvtx_range( f"allreudce: shape={input.size(0)}x{input.size(1)} fusion={fusion} strategy={strategy}" ): + with autotune(): + func(input, loop_num=1) + if enable_cudagraph: # CUDA graph warmup then capture for _ in range(2): - func(input) + func(input, loop_num=1) with torch.cuda.graph(graph, stream=stream): output = func(input) - # warmup for no cuda graph - func(input) - tllm.mpi_barrier() + dist.barrier() # add delay to avoid the effect of host time overhead delay_kernel(20000, stream) @@ -124,7 +128,6 @@ def allreduce_benchmark( save_csv: str = None, enable_auto: bool = False, ): - tllm.logger.set_level('error') world_size = tllm.mpi_world_size() rank = tllm.mpi_rank() local_rank = local_mpi_rank() @@ -134,6 +137,15 @@ def allreduce_benchmark( cudart.cudaSetDevice(local_rank) mapping = Mapping(world_size, rank, gpus_per_node, tp_size=world_size) + if mpi_disabled(): + dist = TorchDist(mapping=mapping) + else: + dist = MPIDist(mapping=mapping) + + logger.set_rank(mapping.rank) + + AutoTuner.get().setup_distributed_state(mapping, dist) + sm_version = get_sm_version() if world_size == 1: @@ -148,11 +160,12 @@ def allreduce_benchmark( shape_list = [] if explore_2d: - num_seqs_list = [ + num_tokens_list = [ 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 ] hidden_size_list = [128, 256, 512, 1024, 2048, 4096, 8192] 
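[Editor's aside, not part of the patch] The microbenchmark hunks above wire the AllReduce benchmark into the AutoTuner: a communicator (`TorchDist` when MPI is disabled, `MPIDist` otherwise) is registered through `AutoTuner.get().setup_distributed_state()`, and one iteration runs under `autotune()` before any CUDA-graph capture or timing. The sketch below only illustrates that pattern under stated assumptions (launched with `mpirun`, one GPU per rank); the tensor shape, dtype, and loop count are arbitrary illustration values, not taken from the patch.

```python
import torch

import tensorrt_llm as tllm
from tensorrt_llm import Mapping
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                              MPIDist, TorchDist)
from tensorrt_llm._utils import local_mpi_rank, mpi_disabled
from tensorrt_llm.functional import AllReduceParams

# One process per tensor-parallel rank (e.g. launched with mpirun).
world_size, rank = tllm.mpi_world_size(), tllm.mpi_rank()
torch.cuda.set_device(local_mpi_rank())

mapping = Mapping(world_size=world_size, rank=rank, tp_size=world_size, pp_size=1)
# Pick the communicator the same way the benchmark does.
dist = TorchDist(mapping=mapping) if mpi_disabled() else MPIDist(mapping=mapping)
AutoTuner.get().setup_distributed_state(mapping, dist)

allreduce = AllReduce(mapping=mapping)
params = AllReduceParams(fusion_op=AllReduceFusionOp.NONE)
x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")  # illustrative shape/dtype

# Single tuning pass: populates the AutoTuner profiling cache.
with autotune():
    allreduce(x, all_reduce_params=params)

# Steady-state loop using the cached tactic (time it or capture it in a CUDA graph).
for _ in range(10):
    allreduce(x, all_reduce_params=params)
dist.barrier()  # sync ranks before reading timings
```

Keeping the tuning pass outside the captured/timed loop appears to be the point of the change above: profiling kernels stay out of the CUDA graph, and later replays reuse the cached tactic.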
- for num_tokens, hidden_size in product(num_seqs_list, hidden_size_list): + for num_tokens, hidden_size in product(num_tokens_list, + hidden_size_list): shape_list.append((num_tokens, hidden_size)) else: min_size, max_size, ratio = [int(i) for i in test_range.split(",")] @@ -214,6 +227,7 @@ def allreduce_benchmark( median_ms = profile_allreduce( mapping=mapping, + dist=dist, enable_cudagraph=enable_cudagraph, inner_loop=inner_loop, outer_loop=outer_loop, @@ -240,6 +254,13 @@ def allreduce_benchmark( }) ]) + # print the new record in a single line instead of a dataframe + if mapping.rank == 0: + print( + f"num_tokens: {num_tokens}, hidden_size: {hidden_size}, strategy: {strategy.name}, fusion: {fusion.name}, time (us): {median_ms * 1000}" + ) + + AutoTuner.get().print_profiling_cache() # print the dataframe if mapping.rank == 0: pd.set_option('display.max_rows', None) diff --git a/tests/scripts/allreduce_perf/allreduce_perf_viz.py b/tests/scripts/allreduce_perf/allreduce_perf_viz.py index 4f89a97fb890..57e08f934689 100644 --- a/tests/scripts/allreduce_perf/allreduce_perf_viz.py +++ b/tests/scripts/allreduce_perf/allreduce_perf_viz.py @@ -570,7 +570,6 @@ def main(): if not os.path.exists(os.path.join(args.data_dir, "viz")): os.makedirs(os.path.join(args.data_dir, "viz")) - os.makedirs(os.path.join(args.data_dir, "viz", fusion_op)) for tp_size in tp_size_list: case_name = f"benchmark.tp{tp_size}.sm{get_sm_version()}" @@ -581,7 +580,10 @@ def main(): df = pd.read_csv(Path(fname)) for fusion_op in fusion_op_list: - os.makedirs(os.path.join(args.data_dir, "viz", fusion_op)) + # if not exists, create the directory + if not os.path.exists(os.path.join(args.data_dir, "viz", + fusion_op)): + os.makedirs(os.path.join(args.data_dir, "viz", fusion_op)) if df is not None: print(f"\n=== TP Size: {tp_size} ===") @@ -599,7 +601,7 @@ def main(): save_path=viz_path_best_strategy) # Create strategy difference heatmaps and save to data/viz directory - viz_path_diff = f"{os.path.dirname(__file__)}/{args.data_dir}/viz/{fusion_op}/{case_name}_strategy_difference_heatmap.png" + viz_path_diff = f"{args.data_dir}/viz/{fusion_op}/{case_name}_strategy_difference_heatmap.png" visualize_strategy_difference_heatmaps(df, fusion_op, save_path=viz_path_diff) diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py index a6116d544f27..32d0581e6e3b 100644 --- a/tests/unittest/_torch/misc/test_autotuner.py +++ b/tests/unittest/_torch/misc/test_autotuner.py @@ -17,8 +17,10 @@ FakeTensor, OptimizationProfile, StaticDim, TunableRunner, TuningConfig, autotune) +from tensorrt_llm._torch.distributed.communicator import MPIDist, TorchDist from tensorrt_llm._torch.utils import (get_power_of_2_num_tokens_buckets, next_positive_power_of_2) +from tensorrt_llm._utils import mpi_disabled from tensorrt_llm.bindings.internal.runtime import delay_kernel from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping @@ -323,8 +325,9 @@ def test_multiple_dynamic_shapes_cache(): # Do tuning with a sample input x = torch.randn(3, 64) temp_dir = tempfile.TemporaryDirectory() - with autotune(cache_path=os.path.join(temp_dir.name, - "test_multiple_dynamic_shapes.json")): + cache_path = os.path.join(temp_dir.name, + "test_multiple_dynamic_shapes.json") + with autotune(cache_path=cache_path): tuner = AutoTuner.get() runner, tactic = tuner.choose_one("test_multiple_dynamic_shapes", runners, tuning_config, [x, w]) @@ -336,8 +339,7 @@ def test_multiple_dynamic_shapes_cache(): # Verify 
cache size - should have 12 entries (3x4 combinations) # We also test the cache serialization and deserialization here. AutoTuner.get().profiling_cache.clear() - AutoTuner.get().profiling_cache.load_cache( - os.path.join(temp_dir.name, "test_multiple_dynamic_shapes.rank0.json")) + AutoTuner.get().profiling_cache.load_cache(cache_path, rank=0) cache_entries = tuner.profiling_cache.get_specific_custom_op( "test_multiple_dynamic_shapes") @@ -427,8 +429,9 @@ def test_autotuner_tuning_configs(): use_cuda_graph=False, ) temp_dir = tempfile.TemporaryDirectory() - with autotune(cache_path=os.path.join( - temp_dir.name, "test_autotuner_tactic_configs.json")): + cache_path = os.path.join(temp_dir.name, + "test_autotuner_tactic_configs.json") + with autotune(cache_path=cache_path): tuner = AutoTuner.get() runner, best_tactic = tuner.choose_one("test_autotuner_tactic_configs", runners, tuning_config, [x, w]) @@ -437,8 +440,7 @@ def test_autotuner_tuning_configs(): # Test if the tactic can be loaded from cache correctly AutoTuner.get().profiling_cache.clear() - AutoTuner.get().profiling_cache.load_cache( - os.path.join(temp_dir.name, "test_autotuner_tactic_configs.rank0.json")) + AutoTuner.get().profiling_cache.load_cache(cache_path, rank=0) # No further tuning should be performed. runner, deserialized_tactic = tuner.choose_one( @@ -646,9 +648,14 @@ def _distributed_worker_function(world_size, strategy): rank=rank, tp_size=world_size, pp_size=1) + if mpi_disabled(): + dist = TorchDist(mapping=mapping) + else: + dist = MPIDist(mapping=mapping) + tuner = AutoTuner.get() tuner.clear_cache() - tuner.setup_distributed_state(mapping) + tuner.setup_distributed_state(mapping, dist) x = torch.randn(16, 32, device='cuda') w = torch.randn(32, 64, device='cuda') @@ -663,12 +670,28 @@ def _distributed_worker_function(world_size, strategy): runner = DistributedGemmRunner(prefer_tactics=prefer_tactics) config = TuningConfig(distributed_tuning_strategy=strategy) - cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None) - with autotune(tune_mode=True, cache_path=cache_path): + if rank == 0: + temp_dir = tempfile.TemporaryDirectory() + # rank 0 should broadcast the cache path to all ranks + cache_path = os.path.join(temp_dir.name, "test_distributed_tuning.json") + dist.broadcast(cache_path, root=0) + else: + cache_path = dist.broadcast(None, root=0) + + with autotune(cache_path=cache_path): tuner.choose_one(custom_op=f"test_distributed_{strategy}", runners=[runner], tuning_config=config, inputs=inputs) + + # Check only one file is created in the cache path + assert len(os.listdir(os.path.dirname( + cache_path))) == 1, "Only one rank file should be created" + + # Check cache for distributed tuning + AutoTuner.get().profiling_cache.clear() + AutoTuner.get().profiling_cache.load_cache(cache_path, rank) + selected_runner, best_tactic = tuner.choose_one( custom_op=f"test_distributed_{strategy}", runners=[runner], @@ -706,8 +729,7 @@ def _distributed_worker_function(world_size, strategy): ], ) @pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) -def test_distributed_broadcast_strategy(strategy, mpi_pool_executor): - """Test broadcast strategy with real MPI processes.""" +def test_autotuner_distributed_strategy(strategy, mpi_pool_executor): world_size = 2 # Use MPIPoolExecutor to run distributed test results = mpi_pool_executor.map( diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py index 5051998c5a6c..b531d826435f 100644 --- 
a/tests/unittest/_torch/multi_gpu/test_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py @@ -21,14 +21,17 @@ import pytest import torch from mpi4py import MPI -from utils.util import skip_pre_blackwell +from utils.util import check_accuracy, skip_pre_blackwell import tensorrt_llm +from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, AllReduceStrategy, - MoEAllReduce, MoEAllReduceParams) + MoEAllReduce, MoEAllReduceParams, + MPIDist, TorchDist) from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode from tensorrt_llm._torch.modules.rms_norm import RMSNorm +from tensorrt_llm._utils import mpi_disabled from tensorrt_llm.mapping import Mapping sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -62,8 +65,16 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): return y -def run_single_rank(tensor_parallel_size, single_rank_forward_func, input, - residual, weights, hidden_size, dtype, fusion_op): +def run_single_rank( + tensor_parallel_size, + single_rank_forward_func, + input, + residual, + weights, + hidden_size, + dtype, + fusion_op, +): rank = tensorrt_llm.mpi_rank() torch.cuda.set_device(rank) try: @@ -91,10 +102,16 @@ def run_moe_single_rank(tensor_parallel_size, single_rank_forward_func, @torch.inference_mode() -def run_allreduce_op(x: torch.Tensor, residual: torch.Tensor, hidden_size: int, - dtype: torch.dtype, tensor_parallel_size: int, - tensor_parallel_rank: int, weights: torch.Tensor, - fusion_op: AllReduceFusionOp): +def run_allreduce_op( + x: torch.Tensor, + residual: torch.Tensor, + hidden_size: int, + dtype: torch.dtype, + tensor_parallel_size: int, + tensor_parallel_rank: int, + weights: torch.Tensor, + fusion_op: AllReduceFusionOp, +): def e2m1_and_ufp8sf_scale_to_float_v2(e2m1_tensor, ufp8_scale_tensor, @@ -116,6 +133,12 @@ def e2m1_and_ufp8sf_scale_to_float_v2(e2m1_tensor, tp_size=tensor_parallel_size, rank=tensor_parallel_rank, ) + if mpi_disabled(): + dist = TorchDist(mapping=mapping) + else: + dist = MPIDist(mapping=mapping) + + AutoTuner.get().setup_distributed_state(mapping, dist) linear = Linear( in_features=hidden_size, out_features=hidden_size, @@ -128,14 +151,12 @@ def e2m1_and_ufp8sf_scale_to_float_v2(e2m1_tensor, allreduce = AllReduce(mapping=mapping) norm = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype).cuda() + allreduce = AllReduce(mapping=mapping).cuda() + scale = torch.tensor(1.0, dtype=torch.float32).cuda() linear.load_weights([dict(weight=weights[0])]) norm.weight.data.copy_(norm_weight) - def calc_allreduce(x, res): - linear_out = linear(x) - return [linear_out] - def calc_fused_allreduce(x, res): linear_out = linear( x, all_reduce_params=AllReduceParams(enable_allreduce=False)) @@ -150,7 +171,7 @@ def calc_fused_allreduce(x, res): eps=eps, ), ) - return output + return [output] if fusion_op == AllReduceFusionOp.NONE else output def calc_residual_rms_norm_quant_fp8(x, res): quant_fp8, residual_out = calc_fused_allreduce(x, res) @@ -215,7 +236,7 @@ def ref_residual_rms_norm_out_quant_nvfp4(x, res): return norm_out, dequant_fp4, residual_out fusion_op_to_func = { - AllReduceFusionOp.NONE: (calc_allreduce, ref_allreduce), + AllReduceFusionOp.NONE: (calc_fused_allreduce, ref_allreduce), AllReduceFusionOp.RESIDUAL_RMS_NORM: (calc_fused_allreduce, ref_residual_rms_norm), AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8: @@ -234,26 +255,19 @@ def 
ref_residual_rms_norm_out_quant_nvfp4(x, res): # common allreduce path xs = torch.chunk(x.clone(), tensor_parallel_size, dim=-1) - calc_output = calc_func(xs[tensor_parallel_rank], residual) + + # trigger autotune + with autotune(): + calc_output = calc_func(xs[tensor_parallel_rank], residual) + ref_output = ref_func(xs[tensor_parallel_rank], residual) for calc_output_tensor, ref_output_tensor in zip(calc_output, ref_output): - rtol, atol = 0.05, 0.15 - try: - torch.testing.assert_close( - calc_output_tensor, - ref_output_tensor, - rtol=rtol, - atol=atol, - ) - except AssertionError: - # Calculate percentage of mismatched elements - mismatched = torch.abs(calc_output_tensor - ref_output_tensor) > ( - rtol * torch.abs(ref_output_tensor) + atol) - mismatch_percentage = (mismatched.sum() / mismatched.numel()) - - # If more than 1% elements mismatch, raise the error - assert mismatch_percentage < 0.01, f"Large mismatched elements encountered" + check_accuracy(calc_output_tensor, + ref_output_tensor, + atol=0.05, + rtol=0.15, + percent=0.99) @pytest.mark.skipif(torch.cuda.device_count() < 2, diff --git a/tests/unittest/_torch/multi_gpu/test_user_buffers.py b/tests/unittest/_torch/multi_gpu/test_user_buffers.py index 6de03d190862..00abc2229b96 100644 --- a/tests/unittest/_torch/multi_gpu/test_user_buffers.py +++ b/tests/unittest/_torch/multi_gpu/test_user_buffers.py @@ -460,7 +460,7 @@ def run_single_rank_ub_pass( # 3 AR_NORM replacement # 3 Scaled MM Prologue # 2 UB Finalize Removal - assert backend.match_count == [3, 0, 2, 0, 3, 0, 3, 0, 2, 0] + assert backend.match_count == [3, 0, 2, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0] torch.cuda.synchronize() if rank == 0: @@ -759,7 +759,7 @@ def run_single_rank_ub_mm_add_pass(tensor_parallel_size, num_tokens, # 3 AR_NORM replacement # 3 Prologue # 1 UB Finalize Removal - assert backend.match_count == [3, 0, 0, 3, 0, 3, 0, 1, 0] + assert backend.match_count == [3, 0, 0, 0, 0, 0, 3, 0, 3, 0, 1, 0] torch.cuda.synchronize() if rank == 0: @@ -993,7 +993,7 @@ def block_scale_unswizzled(scale): # 3 AR_NORM replacement # 3 Scaled MM Prologue # 2 UB Finalize Removal - assert backend.match_count == [3, 0, 2, 0, 3, 0, 3, 0, 2, 0] + assert backend.match_count == [3, 0, 2, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0] torch.cuda.synchronize() torch.testing.assert_close(output_fused, output_ref,