66import socket
77import subprocess # nosec B404
88import sys
9+ import warnings
910from pathlib import Path
1011from typing import Any , Dict , Literal , Mapping , Optional , Sequence
1112
@@ -154,6 +155,7 @@ def get_llm_args(
154155 fail_fast_on_attention_window_too_large : bool = False ,
155156 otlp_traces_endpoint : Optional [str ] = None ,
156157 enable_chunked_prefill : bool = False ,
158+ enable_attention_dp : bool = False ,
157159 ** llm_args_extra_dict : Any ):
158160
159161 if gpus_per_node is None :
@@ -228,6 +230,8 @@ def get_llm_args(
228230 num_postprocess_workers ,
229231 "enable_chunked_prefill" :
230232 enable_chunked_prefill ,
233+ "enable_attention_dp" :
234+ enable_attention_dp ,
231235 "revision" :
232236 revision ,
233237 "reasoning_parser" :
@@ -508,11 +512,13 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
508512
509513@click .command ("serve" )
510514@click .argument ("model" , type = str )
511- @click .option ("--tokenizer" ,
512- type = str ,
513- default = None ,
514- help = help_info_with_stability_tag ("Path | Name of the tokenizer." ,
515- "beta" ))
515+ @click .option (
516+ "--tokenizer" ,
517+ type = str ,
518+ default = None ,
519+ help = help_info_with_stability_tag (
520+ "Path or name of the tokenizer. When using the PyTorch backend, "
521+ "this replaces the default HuggingFace tokenizer." , "beta" ))
516522@click .option (
517523 "--custom_tokenizer" ,
518524 type = str ,
@@ -641,12 +647,15 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
641647 default = False ,
642648 help = help_info_with_stability_tag ("Flag for HF transformers." ,
643649 "beta" ))
644- @click .option ("--revision" ,
650+ @click .option ("--hf_revision" ,
651+ "--revision" ,
652+ "revision" ,
645653 type = str ,
646654 default = None ,
647655 help = help_info_with_stability_tag (
648656 "The revision to use for the HuggingFace model "
649- "(branch name, tag name, or commit id)." , "beta" ))
657+ "(branch name, tag name, or commit id). "
658+ "Prefer --hf_revision over --revision." , "beta" ))
650659@click .option (
651660 "--config" ,
652661 "--extra_llm_api_options" ,
@@ -681,8 +690,9 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
681690 type = str ,
682691 default = None ,
683692 help = help_info_with_stability_tag (
684- "Server role. Specify this value only if running in disaggregated mode." ,
685- "prototype" ))
693+ "Server role for disaggregated serving. "
694+ "CONTEXT=prefill (prompt processing), GENERATION=decode (token generation). "
695+ "Required when using service registry." , "prototype" ))
686696@click .option (
687697 "--fail_fast_on_attention_window_too_large" ,
688698 is_flag = True ,
@@ -706,6 +716,11 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
706716 default = False ,
707717 help = help_info_with_stability_tag ("Enable chunked prefill" ,
708718 "prototype" ))
719+ @click .option ("--enable_attention_dp" ,
720+ is_flag = True ,
721+ default = False ,
722+ help = help_info_with_stability_tag (
723+ "Enable attention data parallel." , "beta" ))
709724@click .option ("--media_io_kwargs" ,
710725 type = str ,
711726 default = None ,
@@ -752,16 +767,22 @@ def serve(
752767 metadata_server_config_file : Optional [str ], server_role : Optional [str ],
753768 fail_fast_on_attention_window_too_large : bool ,
754769 otlp_traces_endpoint : Optional [str ], enable_chunked_prefill : bool ,
755- disagg_cluster_uri : Optional [str ], media_io_kwargs : Optional [str ],
756- custom_module_dirs : list [Path ], chat_template : Optional [str ],
757- grpc : bool , served_model_name : Optional [str ],
770+ enable_attention_dp : bool , disagg_cluster_uri : Optional [str ],
771+ media_io_kwargs : Optional [str ], custom_module_dirs : list [Path ],
772+ chat_template : Optional [str ], grpc : bool ,
773+ served_model_name : Optional [str ],
758774 extra_visual_gen_options : Optional [str ]):
759775 """Running an OpenAI API compatible server
760776
761777 MODEL: model name | HF checkpoint path | TensorRT engine path
762778 """
763779 logger .set_level (log_level )
764780
781+ if "--revision" in sys .argv :
782+ warnings .warn ("--revision is deprecated, use --hf_revision instead." ,
783+ DeprecationWarning ,
784+ stacklevel = 2 )
785+
765786 for custom_module_dir in custom_module_dirs :
766787 try :
767788 import_custom_module_from_dir (custom_module_dir )
@@ -796,7 +817,8 @@ def _serve_llm():
796817 fail_fast_on_attention_window_too_large =
797818 fail_fast_on_attention_window_too_large ,
798819 otlp_traces_endpoint = otlp_traces_endpoint ,
799- enable_chunked_prefill = enable_chunked_prefill )
820+ enable_chunked_prefill = enable_chunked_prefill ,
821+ enable_attention_dp = enable_attention_dp )
800822
801823 llm_args_extra_dict = {}
802824 if extra_llm_api_options is not None :
@@ -923,33 +945,62 @@ def _serve_visual_gen():
923945 default = False ,
924946 help = "Flag for HF transformers." )
925947@click .option (
948+ "--config" ,
926949 "--extra_encoder_options" ,
950+ "extra_encoder_options" ,
927951 type = str ,
928952 default = None ,
929953 help =
930- "Path to a YAML file that overwrites the parameters specified by trtllm-serve."
931- )
954+ "Path to a YAML file that overwrites the parameters specified by trtllm-serve. "
955+ "Prefer --config over --extra_encoder_options." )
956+ @click .option ("--hf_revision" ,
957+ "--revision" ,
958+ "revision" ,
959+ type = str ,
960+ default = None ,
961+ help = "The revision to use for the HuggingFace model "
962+ "(branch name, tag name, or commit id)." )
963+ @click .option ("--free_gpu_memory_fraction" ,
964+ type = float ,
965+ default = 0.9 ,
966+ help = "Free GPU memory fraction reserved for KV Cache, "
967+ "after allocating model weights and buffers." )
968+ @click .option ("--tensor_parallel_size" ,
969+ "--tp_size" ,
970+ type = int ,
971+ default = 1 ,
972+ help = "Tensor parallelism size." )
932973@click .option ("--metadata_server_config_file" ,
933974 type = str ,
934975 default = None ,
935976 help = "Path to metadata server config file" )
936977def serve_encoder (model : str , host : str , port : int , log_level : str ,
937978 max_batch_size : int , max_num_tokens : int ,
938979 gpus_per_node : Optional [int ], trust_remote_code : bool ,
939- extra_encoder_options : Optional [str ],
980+ extra_encoder_options : Optional [str ], revision : Optional [str ],
981+ free_gpu_memory_fraction : float , tensor_parallel_size : int ,
940982 metadata_server_config_file : Optional [str ]):
941983 """Running an OpenAI API compatible server
942984
943985 MODEL: model name | HF checkpoint path | TensorRT engine path
944986 """
945987 logger .set_level (log_level )
946988
947- # TODO: expose more arguments progressively
948- llm_args , _ = get_llm_args (model = model ,
949- max_batch_size = max_batch_size ,
950- max_num_tokens = max_num_tokens ,
951- gpus_per_node = gpus_per_node ,
952- trust_remote_code = trust_remote_code )
989+ if "--extra_encoder_options" in sys .argv :
990+ warnings .warn (
991+ "--extra_encoder_options is deprecated, use --config instead." ,
992+ DeprecationWarning ,
993+ stacklevel = 2 )
994+
995+ llm_args , _ = get_llm_args (
996+ model = model ,
997+ max_batch_size = max_batch_size ,
998+ max_num_tokens = max_num_tokens ,
999+ gpus_per_node = gpus_per_node ,
1000+ trust_remote_code = trust_remote_code ,
1001+ revision = revision ,
1002+ free_gpu_memory_fraction = free_gpu_memory_fraction ,
1003+ tensor_parallel_size = tensor_parallel_size )
9531004
9541005 encoder_args_extra_dict = {}
9551006 if extra_encoder_options is not None :
@@ -966,10 +1017,12 @@ def serve_encoder(model: str, host: str, port: int, log_level: str,
9661017
9671018@click .command ("disaggregated" )
9681019@click .option ("-c" ,
1020+ "--config" ,
9691021 "--config_file" ,
1022+ "config_file" ,
9701023 type = str ,
9711024 default = None ,
972- help = "Specific option for disaggregated mode ." )
1025+ help = "Path to the disaggregated serving configuration YAML file." )
9731026@click .option ("-m" ,
9741027 "--metadata_server_config_file" ,
9751028 type = str ,
@@ -1009,6 +1062,11 @@ def disaggregated(
10091062
10101063 logger .set_level (log_level )
10111064
1065+ if "--config_file" in sys .argv :
1066+ warnings .warn ("--config_file is deprecated, use --config instead." ,
1067+ DeprecationWarning ,
1068+ stacklevel = 2 )
1069+
10121070 disagg_cfg = parse_disagg_config_file (config_file )
10131071
10141072 with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
@@ -1061,10 +1119,12 @@ def set_cuda_device():
10611119
10621120@click .command ("disaggregated_mpi_worker" )
10631121@click .option ("-c" ,
1122+ "--config" ,
10641123 "--config_file" ,
1124+ "config_file" ,
10651125 type = str ,
10661126 default = None ,
1067- help = "Specific option for disaggregated mode ." )
1127+ help = "Path to the disaggregated serving configuration YAML file." )
10681128@click .option ('--log_level' ,
10691129 type = click .Choice (severity_map .keys ()),
10701130 default = 'info' ,
0 commit comments