diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index fa9a84713d4..6a3ed6d4837 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -34,11 +34,14 @@
 from tensorrt_llm.llmapi.llm_args import TorchLlmArgs, TrtLlmArgs
 from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict
 from tensorrt_llm.llmapi.mpi_session import find_free_ipc_addr
-from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
+from tensorrt_llm.llmapi.reasoning_parser import (ReasoningParserFactory,
+                                                  resolve_auto_reasoning_parser)
 from tensorrt_llm.logger import logger, severity_map
 from tensorrt_llm.mapping import CpType
 from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer
 from tensorrt_llm.serve.tool_parser import ToolParserFactory
+from tensorrt_llm.serve.tool_parser.tool_parser_factory import \
+    resolve_auto_tool_parser
 from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir
 
 # Global variable to store the Popen object of the child process
@@ -659,17 +662,19 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
                                       "prototype"))
 @click.option(
     "--reasoning_parser",
-    type=click.Choice(ReasoningParserFactory.keys()),
+    type=click.Choice(["auto"] + list(ReasoningParserFactory.keys())),
     default=None,
     help=help_info_with_stability_tag(
-        "Specify the parser for reasoning models.", "prototype"),
+        "Specify the parser for reasoning models. "
+        "Use 'auto' to automatically select based on the model.", "prototype"),
 )
 @click.option(
     "--tool_parser",
-    type=click.Choice(ToolParserFactory.parsers.keys()),
+    type=click.Choice(["auto"] + list(ToolParserFactory.parsers.keys())),
     default=None,
-    help=help_info_with_stability_tag("Specify the parser for tool models.",
-                                      "prototype"),
+    help=help_info_with_stability_tag(
+        "Specify the parser for tool models. "
+        "Use 'auto' to automatically select based on the model.", "prototype"),
 )
 @click.option("--metadata_server_config_file",
               type=str,
@@ -762,6 +767,34 @@ def serve(
     """
     logger.set_level(log_level)
 
+    if tool_parser == "auto":
+        resolved = resolve_auto_tool_parser(model)
+        if resolved is None:
+            raise click.BadParameter(
+                f"Cannot auto-detect tool parser for model '{model}'. "
+                f"Supported model types for auto-detection: qwen2, qwen3, "
+                f"qwen3_moe, qwen3_5, qwen3_5_moe, qwen3_next, deepseek_v3, "
+                f"deepseek_v32, kimi_k2, kimi_k25, glm4. "
+                f"Please specify a parser explicitly: "
+                f"{list(ToolParserFactory.parsers.keys())}",
+                param_hint="--tool_parser")
+        logger.info(f"Auto-detected tool parser: {resolved}")
+        tool_parser = resolved
+
+    if reasoning_parser == "auto":
+        resolved = resolve_auto_reasoning_parser(model)
+        if resolved is None:
+            raise click.BadParameter(
+                f"Cannot auto-detect reasoning parser for model '{model}'. "
+                f"Supported model types for auto-detection: qwen3, qwen3_moe, "
+                f"qwen3_5, qwen3_5_moe, qwen3_next, deepseek_v3 (R1 only), "
+                f"deepseek_v32 (R1 only), nemotron_h. "
+                f"Please specify a parser explicitly: "
+                f"{list(ReasoningParserFactory.keys())}",
+                param_hint="--reasoning_parser")
+        logger.info(f"Auto-detected reasoning parser: {resolved}")
+        reasoning_parser = resolved
+
     for custom_module_dir in custom_module_dirs:
         try:
             import_custom_module_from_dir(custom_module_dir)
diff --git a/tensorrt_llm/llmapi/reasoning_parser.py b/tensorrt_llm/llmapi/reasoning_parser.py
index eb10011d16e..048f0218963 100644
--- a/tensorrt_llm/llmapi/reasoning_parser.py
+++ b/tensorrt_llm/llmapi/reasoning_parser.py
@@ -1,5 +1,7 @@
+import json
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Any, Optional, Type
 
 
@@ -161,6 +163,41 @@ def parse_delta(self, delta_text: str) -> ReasoningParserResult:
             "Unreachable code reached in `DeepSeekR1Parser.parse_delta`")
 
 
+MODEL_TYPE_TO_REASONING_PARSER: dict[str, str] = {
+    "qwen3": "qwen3",
+    "qwen3_moe": "qwen3",
+    "qwen3_5": "qwen3",
+    "qwen3_5_moe": "qwen3",
+    "qwen3_next": "qwen3",
+    "deepseek_v3": "deepseek-r1",
+    "deepseek_v32": "deepseek-r1",
+    "nemotron_h": "nano-v3",
+}
+
+
+def resolve_auto_reasoning_parser(model: str) -> Optional[str]:
+    """Resolve 'auto' reasoning parser by reading the model's HF config.
+
+    For DeepSeek models, only maps to deepseek-r1 if the model path
+    suggests it is a reasoning model (contains 'R1' in the name).
+    """
+    config_path = Path(model) / "config.json"
+    if not config_path.exists():
+        return None
+
+    with open(config_path) as f:
+        config = json.load(f)
+
+    model_type = config.get("model_type", "")
+
+    if model_type in ("deepseek_v3", "deepseek_v32"):
+        model_name = Path(model).name.lower()
+        if "r1" not in model_name:
+            return None
+
+    return MODEL_TYPE_TO_REASONING_PARSER.get(model_type)
+
+
 @register_reasoning_parser("nano-v3")
 class NemotronV3ReasoningParser(DeepSeekR1Parser):
     """Reasoning parser for Nemotron Nano v3.
diff --git a/tensorrt_llm/serve/tool_parser/tool_parser_factory.py b/tensorrt_llm/serve/tool_parser/tool_parser_factory.py
index f3bf95cd941..9b22031ca79 100644
--- a/tensorrt_llm/serve/tool_parser/tool_parser_factory.py
+++ b/tensorrt_llm/serve/tool_parser/tool_parser_factory.py
@@ -1,4 +1,6 @@
-from typing import Type
+import json
+from pathlib import Path
+from typing import Optional, Type
 
 from .base_tool_parser import BaseToolParser
 from .deepseekv3_parser import DeepSeekV3Parser
@@ -9,6 +11,33 @@
 from .qwen3_coder_parser import Qwen3CoderToolParser
 from .qwen3_tool_parser import Qwen3ToolParser
 
+MODEL_TYPE_TO_TOOL_PARSER: dict[str, str] = {
+    "qwen2": "qwen3",
+    "qwen3": "qwen3",
+    "qwen3_moe": "qwen3",
+    "qwen3_5": "qwen3",
+    "qwen3_5_moe": "qwen3",
+    "qwen3_next": "qwen3",
+    "deepseek_v3": "deepseek_v3",
+    "deepseek_v32": "deepseek_v32",
+    "kimi_k2": "kimi_k2",
+    "kimi_k25": "kimi_k2",
+    "glm4": "glm4",
+}
+
+
+def resolve_auto_tool_parser(model: str) -> Optional[str]:
+    """Resolve 'auto' tool parser by reading the model's HF config."""
+    config_path = Path(model) / "config.json"
+    if not config_path.exists():
+        return None
+
+    with open(config_path) as f:
+        config = json.load(f)
+
+    model_type = config.get("model_type", "")
+    return MODEL_TYPE_TO_TOOL_PARSER.get(model_type)
+
 
 class ToolParserFactory:
     parsers: dict[str, Type[BaseToolParser]] = {