                       mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import KvCacheConnectorConfig, PybindMirror, TorchLlmArgs
+from ..llmapi.llm_args import (BaseLlmArgs, KvCacheConnectorConfig,
+                               PybindMirror, TorchLlmArgs)
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@@ -64,7 +65,7 @@ def __init__(
         kv_connector_config: Optional[KvCacheConnectorConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -107,40 +108,55 @@ def _get_comm_ranks_device_id():
             device_ids = mpi_comm().allgather(device_id)
             return comm_ranks, device_ids

-        def _create_py_executor(executor_config):
-            assert executor_config is None, "expect an empty executor_config is _create_py_executor"
-            executor_config = llm_args.get_executor_config(
-                hf_model_dir, tokenizer)
-            # Persist so downstream code (e.g., default max_tokens deduction) has access
-            self._executor_config = executor_config
-            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-                processor_batched=batched_logits_processor, replicate=False)
-            comm_ranks, device_ids = _get_comm_ranks_device_id()
-            executor_config.parallel_config = tllm.ParallelConfig(
-                participant_ids=comm_ranks, device_ids=device_ids)
-            args = {
-                "executor_config": executor_config,
-                "checkpoint_dir": executor_config.hf_model_dir,
-            }
+        def _create_py_executor():
+            args = {}
             assert hasattr(
-                executor_config, "backend"
-            ), "executor_config should be with backend in _create_py_executor"
-            if executor_config.backend == "pytorch":
+                self.llm_args, "backend"
+            ), "llm_args should be with backend in _create_py_executor"
+            if self.llm_args.backend == "pytorch":
                 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                     create_py_executor
                 create_executor = create_py_executor
+                args["llm_args"] = self.llm_args
+                args["checkpoint_dir"] = hf_model_dir
+                args["tokenizer"] = tokenizer
                 args["lora_config"] = lora_config
-                args[
-                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
                 args["kv_connector_config"] = kv_connector_config
-            elif executor_config.backend == "_autodeploy":
+                args[
+                    "logits_post_processor_config"] = tllm.LogitsPostProcessorConfig(
+                        processor_batched=batched_logits_processor,
+                        replicate=False)
+                comm_ranks, device_ids = _get_comm_ranks_device_id()
+                args["parallel_config"] = tllm.ParallelConfig(
+                    participant_ids=comm_ranks, device_ids=device_ids)
+            elif self.llm_args.backend == "_autodeploy":
+                from tensorrt_llm._torch.auto_deploy.llm_args import \
+                    LlmArgs as ADLlmArgs
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
                 create_executor = create_autodeploy_executor
+                assert isinstance(self.llm_args, ADLlmArgs)
+                args["ad_config"] = self.llm_args.get_pytorch_backend_config()
             else:
                 raise ValueError(
-                    f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
+                    f"Unsupported backend config: {self.llm_args.backend}")
+
+            # Define additional attributes that can be used later, such as in _deduce_max_tokens
+            self.mapping = self.llm_args.parallel_config.to_mapping()
+            self.checkpoint_loader = None
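+            # Only the PyTorch backend constructs a checkpoint loader here; shutdown()
+            # is responsible for cleaning it up when llm_args are in use.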
+            if self.llm_args.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.config import \
+                    _construct_checkpoint_loader
+                self.checkpoint_loader = _construct_checkpoint_loader(
+                    self.llm_args.backend, self.llm_args.checkpoint_loader,
+                    self.llm_args.checkpoint_format)
+
+            _executor = create_executor(**args)
+            self.max_seq_len = self.llm_args.max_seq_len
+            if _executor.max_seq_len is not None:
+                # max_seq_len might be updated by model engine as in create_py_executor
+                self.max_seq_len = _executor.max_seq_len
+            return _executor

         def _create_engine(executor_config):
             if executor_config is None:
@@ -164,8 +180,7 @@ def _create_engine(executor_config):
                 executor_config)

         self.engine = _create_py_executor(
-            executor_config) if llm_args is not None else _create_engine(
-                executor_config)
+        ) if self.llm_args is not None else _create_engine(executor_config)

         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -188,8 +203,9 @@ def _create_engine(executor_config):
             if engine_config.build_config.max_prompt_embedding_table_size > 0:
                 self._prompt_adapter_manager = PromptAdapterManager()

-        if getattr(self._executor_config, "backend",
-                   "") == "pytorch" and lora_config is not None:
+        if self.llm_args and getattr(
+                self.llm_args, "backend",
+                "") == "pytorch" and lora_config is not None:
             from tensorrt_llm._torch.pyexecutor.resource_manager import \
                 ResourceManagerType
             peft_cache_manager = self.engine.resource_manager.resource_managers.get(
@@ -471,26 +487,43 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
         assert request.id is not None

         def _deduce_max_tokens(request: GenerationRequest,
-                               executor_config: tllm.ExecutorConfig) -> int:
+                               executor_config: tllm.ExecutorConfig,
+                               llm_args: Optional[BaseLlmArgs] = None) -> int:
             # deduce max_tokens when it's not set by user
             max_tokens = request.sampling_params.max_tokens
             query_token_len = len(
                 request.query_token_ids) if request.query_token_ids else 0
-            cp_size = 1 if (not hasattr(executor_config, "mapping")
-                            or executor_config.mapping.cp_size
-                            is None) else executor_config.mapping.cp_size
-            if not hasattr(executor_config, "max_seq_len"):
+
+            cp_size = 1
+            max_seq_len = None
+            if llm_args is not None:
+                # deduce max_tokens by llm args
+                assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
+                if hasattr(self,
+                           "mapping") and self.mapping.cp_size is not None:
+                    cp_size = self.mapping.cp_size
+                max_seq_len = getattr(self, "max_seq_len", None)
+            else:
+                # deduce max_tokens by executor config
+                if hasattr(executor_config, "mapping"
+                           ) and executor_config.mapping.cp_size is not None:
+                    cp_size = executor_config.mapping.cp_size
+                max_seq_len = getattr(executor_config, "max_seq_len", None)
+            if max_seq_len is None:
                 logger.warning("`default_max_tokens` cannot be deduced")
                 if max_tokens is None:
                     raise ValueError(
                         "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                     )
+                else:
+                    # use max_tokens if can't deduce default_max_tokens
+                    return max_tokens
             splited_prompt_len = int(len(prompt_token_ids) / cp_size)
-            default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
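+            # Default budget: whatever remains of max_seq_len after the prompt
+            # (split across cp_size ranks) and any query tokens.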
+            default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
             if default_max_tokens <= 0:
                 logger.warning(
                     f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
-                    f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({executor_config.max_seq_len})"
+                    f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})"
                     f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
                 )
             if max_tokens is None:
@@ -512,7 +545,8 @@ def _deduce_max_tokens(request: GenerationRequest,
         executor_request = tllm.Request(
             client_id=request.id,
             input_token_ids=prompt_token_ids,
-            max_tokens=_deduce_max_tokens(request, self._executor_config),
+            max_tokens=_deduce_max_tokens(request, self._executor_config,
+                                          self.llm_args),
             streaming=request.streaming,
             sampling_config=request.sampling_params._get_sampling_config(),
             end_id=-1 if request.sampling_params.ignore_eos else
@@ -638,11 +672,19 @@ def shutdown(self):
             self.engine.shutdown()
             self.engine = None

-            if hasattr(
-                    self._executor_config, "checkpoint_loader"
-            ) and self._executor_config.checkpoint_loader is not None:
-                self._executor_config.checkpoint_loader.cleanup()
-                self._executor_config.checkpoint_loader = None
+            if self.llm_args is not None:
+                assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
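+                # The checkpoint loader is only created on the PyTorch backend path in
+                # _create_py_executor, so that is the only case that needs cleanup here.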
+                if (self.llm_args.backend == "pytorch"
+                        and hasattr(self, "checkpoint_loader")
+                        and self.checkpoint_loader is not None):
+                    self.checkpoint_loader.cleanup()
+                    self.checkpoint_loader = None
+            else:
+                if hasattr(
+                        self._executor_config, "checkpoint_loader"
+                ) and self._executor_config.checkpoint_loader is not None:
+                    self._executor_config.checkpoint_loader.cleanup()
+                    self._executor_config.checkpoint_loader = None

         # Check if there are any errors from the threads before shutdown.
         self._handle_background_error()
@@ -689,7 +731,7 @@ def worker_main(
         kv_connector_config: Optional[KvCacheConnectorConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-    llm_args: Optional[TorchLlmArgs] = None,
+    llm_args: Optional[BaseLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",