Skip to content

Commit b398e02

Browse files
committed
[TRTLLM-10076,TRTLLM-10079,TRTLLM-10229,TRTLLM-10078][feat] Serve CLI improvements: renames, new flags, and mm_embedding_serve enhancements
- TRTLLM-10076: Update --tokenizer description for PyTorch backend, add --hf_revision alias for --revision with deprecation warning, support hf_revision key in YAML config, add --enable_attention_dp flag
- TRTLLM-10079: mm_embedding_serve: add --config alias for --extra_encoder_options, expose --hf_revision, --free_gpu_memory_fraction, --tensor_parallel_size
- TRTLLM-10229: Add --config alias for --config_file in disaggregated and disaggregated_mpi_worker commands
- TRTLLM-10078: Improve --server_role help message with role descriptions

Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
Made-with: Cursor
1 parent 2afe11d commit b398e02

File tree

2 files changed

+87
-24
lines changed

2 files changed

+87
-24
lines changed

tensorrt_llm/commands/serve.py

Lines changed: 84 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import socket
77
import subprocess # nosec B404
88
import sys
9+
import warnings
910
from pathlib import Path
1011
from typing import Any, Dict, Literal, Mapping, Optional, Sequence
1112

@@ -154,6 +155,7 @@ def get_llm_args(
154155
fail_fast_on_attention_window_too_large: bool = False,
155156
otlp_traces_endpoint: Optional[str] = None,
156157
enable_chunked_prefill: bool = False,
158+
enable_attention_dp: bool = False,
157159
**llm_args_extra_dict: Any):
158160

159161
if gpus_per_node is None:
@@ -228,6 +230,8 @@ def get_llm_args(
228230
num_postprocess_workers,
229231
"enable_chunked_prefill":
230232
enable_chunked_prefill,
233+
"enable_attention_dp":
234+
enable_attention_dp,
231235
"revision":
232236
revision,
233237
"reasoning_parser":
@@ -508,11 +512,13 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
508512

509513
@click.command("serve")
510514
@click.argument("model", type=str)
511-
@click.option("--tokenizer",
512-
type=str,
513-
default=None,
514-
help=help_info_with_stability_tag("Path | Name of the tokenizer.",
515-
"beta"))
515+
@click.option(
516+
"--tokenizer",
517+
type=str,
518+
default=None,
519+
help=help_info_with_stability_tag(
520+
"Path or name of the tokenizer. When using the PyTorch backend, "
521+
"this replaces the default HuggingFace tokenizer.", "beta"))
516522
@click.option(
517523
"--custom_tokenizer",
518524
type=str,
@@ -641,12 +647,15 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
641647
default=False,
642648
help=help_info_with_stability_tag("Flag for HF transformers.",
643649
"beta"))
644-
@click.option("--revision",
650+
@click.option("--hf_revision",
651+
"--revision",
652+
"revision",
645653
type=str,
646654
default=None,
647655
help=help_info_with_stability_tag(
648656
"The revision to use for the HuggingFace model "
649-
"(branch name, tag name, or commit id).", "beta"))
657+
"(branch name, tag name, or commit id). "
658+
"Prefer --hf_revision over --revision.", "beta"))
650659
@click.option(
651660
"--config",
652661
"--extra_llm_api_options",
@@ -681,8 +690,9 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
681690
type=str,
682691
default=None,
683692
help=help_info_with_stability_tag(
684-
"Server role. Specify this value only if running in disaggregated mode.",
685-
"prototype"))
693+
"Server role for disaggregated serving. "
694+
"CONTEXT=prefill (prompt processing), GENERATION=decode (token generation). "
695+
"Required when using service registry.", "prototype"))
686696
@click.option(
687697
"--fail_fast_on_attention_window_too_large",
688698
is_flag=True,
@@ -706,6 +716,11 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
706716
default=False,
707717
help=help_info_with_stability_tag("Enable chunked prefill",
708718
"prototype"))
719+
@click.option("--enable_attention_dp",
720+
is_flag=True,
721+
default=False,
722+
help=help_info_with_stability_tag(
723+
"Enable attention data parallel.", "beta"))
709724
@click.option("--media_io_kwargs",
710725
type=str,
711726
default=None,
@@ -752,16 +767,22 @@ def serve(
752767
metadata_server_config_file: Optional[str], server_role: Optional[str],
753768
fail_fast_on_attention_window_too_large: bool,
754769
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
755-
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
756-
custom_module_dirs: list[Path], chat_template: Optional[str],
757-
grpc: bool, served_model_name: Optional[str],
770+
enable_attention_dp: bool, disagg_cluster_uri: Optional[str],
771+
media_io_kwargs: Optional[str], custom_module_dirs: list[Path],
772+
chat_template: Optional[str], grpc: bool,
773+
served_model_name: Optional[str],
758774
extra_visual_gen_options: Optional[str]):
759775
"""Running an OpenAI API compatible server
760776
761777
MODEL: model name | HF checkpoint path | TensorRT engine path
762778
"""
763779
logger.set_level(log_level)
764780

781+
if "--revision" in sys.argv:
782+
warnings.warn("--revision is deprecated, use --hf_revision instead.",
783+
DeprecationWarning,
784+
stacklevel=2)
785+
765786
for custom_module_dir in custom_module_dirs:
766787
try:
767788
import_custom_module_from_dir(custom_module_dir)
@@ -796,7 +817,8 @@ def _serve_llm():
796817
fail_fast_on_attention_window_too_large=
797818
fail_fast_on_attention_window_too_large,
798819
otlp_traces_endpoint=otlp_traces_endpoint,
799-
enable_chunked_prefill=enable_chunked_prefill)
820+
enable_chunked_prefill=enable_chunked_prefill,
821+
enable_attention_dp=enable_attention_dp)
800822

801823
llm_args_extra_dict = {}
802824
if extra_llm_api_options is not None:
@@ -923,33 +945,62 @@ def _serve_visual_gen():
923945
default=False,
924946
help="Flag for HF transformers.")
925947
@click.option(
948+
"--config",
926949
"--extra_encoder_options",
950+
"extra_encoder_options",
927951
type=str,
928952
default=None,
929953
help=
930-
"Path to a YAML file that overwrites the parameters specified by trtllm-serve."
931-
)
954+
"Path to a YAML file that overwrites the parameters specified by trtllm-serve. "
955+
"Prefer --config over --extra_encoder_options.")
956+
@click.option("--hf_revision",
957+
"--revision",
958+
"revision",
959+
type=str,
960+
default=None,
961+
help="The revision to use for the HuggingFace model "
962+
"(branch name, tag name, or commit id).")
963+
@click.option("--free_gpu_memory_fraction",
964+
type=float,
965+
default=0.9,
966+
help="Free GPU memory fraction reserved for KV Cache, "
967+
"after allocating model weights and buffers.")
968+
@click.option("--tensor_parallel_size",
969+
"--tp_size",
970+
type=int,
971+
default=1,
972+
help="Tensor parallelism size.")
932973
@click.option("--metadata_server_config_file",
933974
type=str,
934975
default=None,
935976
help="Path to metadata server config file")
936977
def serve_encoder(model: str, host: str, port: int, log_level: str,
937978
max_batch_size: int, max_num_tokens: int,
938979
gpus_per_node: Optional[int], trust_remote_code: bool,
939-
extra_encoder_options: Optional[str],
980+
extra_encoder_options: Optional[str], revision: Optional[str],
981+
free_gpu_memory_fraction: float, tensor_parallel_size: int,
940982
metadata_server_config_file: Optional[str]):
941983
"""Running an OpenAI API compatible server
942984
943985
MODEL: model name | HF checkpoint path | TensorRT engine path
944986
"""
945987
logger.set_level(log_level)
946988

947-
# TODO: expose more arguments progressively
948-
llm_args, _ = get_llm_args(model=model,
949-
max_batch_size=max_batch_size,
950-
max_num_tokens=max_num_tokens,
951-
gpus_per_node=gpus_per_node,
952-
trust_remote_code=trust_remote_code)
989+
if "--extra_encoder_options" in sys.argv:
990+
warnings.warn(
991+
"--extra_encoder_options is deprecated, use --config instead.",
992+
DeprecationWarning,
993+
stacklevel=2)
994+
995+
llm_args, _ = get_llm_args(
996+
model=model,
997+
max_batch_size=max_batch_size,
998+
max_num_tokens=max_num_tokens,
999+
gpus_per_node=gpus_per_node,
1000+
trust_remote_code=trust_remote_code,
1001+
revision=revision,
1002+
free_gpu_memory_fraction=free_gpu_memory_fraction,
1003+
tensor_parallel_size=tensor_parallel_size)
9531004

9541005
encoder_args_extra_dict = {}
9551006
if extra_encoder_options is not None:
@@ -966,10 +1017,12 @@ def serve_encoder(model: str, host: str, port: int, log_level: str,
9661017

9671018
@click.command("disaggregated")
9681019
@click.option("-c",
1020+
"--config",
9691021
"--config_file",
1022+
"config_file",
9701023
type=str,
9711024
default=None,
972-
help="Specific option for disaggregated mode.")
1025+
help="Path to the disaggregated serving configuration YAML file.")
9731026
@click.option("-m",
9741027
"--metadata_server_config_file",
9751028
type=str,
@@ -1009,6 +1062,11 @@ def disaggregated(
10091062

10101063
logger.set_level(log_level)
10111064

1065+
if "--config_file" in sys.argv:
1066+
warnings.warn("--config_file is deprecated, use --config instead.",
1067+
DeprecationWarning,
1068+
stacklevel=2)
1069+
10121070
disagg_cfg = parse_disagg_config_file(config_file)
10131071

10141072
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -1061,10 +1119,12 @@ def set_cuda_device():
10611119

10621120
@click.command("disaggregated_mpi_worker")
10631121
@click.option("-c",
1122+
"--config",
10641123
"--config_file",
1124+
"config_file",
10651125
type=str,
10661126
default=None,
1067-
help="Specific option for disaggregated mode.")
1127+
help="Path to the disaggregated serving configuration YAML file.")
10681128
@click.option('--log_level',
10691129
type=click.Choice(severity_map.keys()),
10701130
default='info',

tensorrt_llm/llmapi/llm_args.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3461,6 +3461,9 @@ def update_llm_args_with_extra_dict(
34613461
llm_args_dict: Dict,
34623462
extra_llm_api_options: Optional[str] = None) -> Dict:
34633463

3464+
if 'hf_revision' in llm_args_dict:
3465+
llm_args_dict.setdefault('revision', llm_args_dict.pop('hf_revision'))
3466+
34643467
# Deep merge kv_cache_config to prevent partial YAML kv_cache_config from replacing the complete kv_cache_config
34653468
if 'kv_cache_config' in llm_args and 'kv_cache_config' in llm_args_dict:
34663469
# Convert KvCacheConfig object to dict if necessary

0 commit comments

Comments
 (0)