@@ -1,4 +1,4 @@
-# Copyright 2026 Bytedance Ltd. and/or its affiliates
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,11 +25,13 @@
 
 from verl.single_controller.ray import RayClassWithInitArgs, SubRayResourcePool
 from verl.utils.config import omega_conf_to_dataclass
+from verl.utils.device import is_cuda_available
 from verl.utils.net_utils import is_valid_ipv6_address
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
 from verl.workers.rollout.trtllm_rollout.trtllm_rollout import ServerAdapter
-from verl.workers.rollout.utils import get_max_position_embeddings, run_unvicorn
+from verl.workers.rollout.utils import run_unvicorn
+from verl.utils.tokenizer import hf_processor
 
 logger = logging.getLogger(__file__)
 logger.setLevel(logging.INFO)
@@ -42,7 +44,6 @@ class TRTLLMHttpServer:
     Args:
         config (DictConfig): full config.
         model_config (HFModelConfig): model config.
-        is_reward_model (bool): whether this is a reward model.
        rollout_mode (RolloutMode): rollout mode.
         workers (list[ActorHandle]): list of rollout workers.
         replica_rank (int): replica rank, a replica may contain multiple nodes.
@@ -55,7 +56,6 @@ def __init__(
         self,
         config: RolloutConfig,
         model_config: HFModelConfig,
-        is_reward_model: bool,
         rollout_mode: RolloutMode,
         workers: list[ActorHandle],
         replica_rank: int,
@@ -64,20 +64,10 @@ def __init__(
         bundle_indices: list[list[int]] = None,
     ):
         os.environ["TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL"] = "1"
-        assert torch.cuda.is_available(), "TRTLLM http server should run on GPU node"
-
+        os.system("nvcc --version")
         self.config: RolloutConfig = omega_conf_to_dataclass(config)
         self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)
-        self.is_reward_model = is_reward_model
-        max_position_embeddings = get_max_position_embeddings(self.model_config.hf_config)
-        if self.config.max_model_len is None:
-            self.config.max_model_len = max_position_embeddings
-        else:
-            if self.config.max_model_len > max_position_embeddings:
-                raise ValueError(
-                    f"max_model_len ({self.config.max_model_len}) should be less than or equal to "
-                    f"max_position_embeddings ({max_position_embeddings})"
-                )
+        self.config.max_model_len = self.config.prompt_length + self.config.response_length
         self.rollout_mode = rollout_mode
         self.workers = workers
         self.replica_rank = replica_rank
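
The hunk above replaces the `max_position_embeddings` clamp with a simple token budget. A minimal sketch of the new arithmetic, with illustrative lengths; the commented-out guard is an assumption, since the PR deletes that check entirely:

```python
# Illustrative budget: the engine must hold the full prompt plus the longest response.
prompt_length = 4096      # hypothetical rollout prompt budget
response_length = 1024    # hypothetical generation budget
max_model_len = prompt_length + response_length  # 5120 tokens per sequence

# Callers that still want the old safety check would re-validate against the HF config:
# assert max_model_len <= hf_config.max_position_embeddings
```
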
@@ -88,12 +78,15 @@ def __init__(
         if self.rollout_mode != RolloutMode.HYBRID and self.config.load_format == "dummy":
             logger.warning(f"rollout mode is {self.rollout_mode}, load_format is dummy, set to auto")
             self.config.load_format = "auto"
-
+        self.is_vlm_model = (
+            self.model_config.hf_config is not None
+            and hasattr(self.model_config.hf_config, "vision_config")
+        )
         # used for http server
         self._server_address = ray.util.get_node_ip_address().strip("[]")
         self._server_port = None
 
-        logger.info(f"TRTLLMHttpServer, replica_rank: {self.replica_rank}")
+        logger.info(f"TRTLLMHttpServer, replica_rank: {self.replica_rank}, ")
 
         self.sampling_args = {
             "detokenize": False,
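
The new `is_vlm_model` flag uses the presence of a `vision_config` sub-config as the vision-language heuristic. A self-contained sketch of the same check against a plain HF config, assuming the standard `transformers` API:

```python
from transformers import AutoConfig

def is_vlm(model_path: str, trust_remote_code: bool = False) -> bool:
    """Same heuristic as the diff: treat any checkpoint whose HF config
    exposes a `vision_config` sub-config as a vision-language model."""
    hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
    return hf_config is not None and hasattr(hf_config, "vision_config")

# Typical results: True for VLMs such as Qwen2-VL, False for text-only checkpoints.
```
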
@@ -118,15 +111,32 @@ async def launch_server(self):
             enable_block_reuse=True,
             free_gpu_memory_fraction=self.config.gpu_memory_utilization,
         )
+        cuda_graph_config = CudaGraphConfig(
+            enable_padding=True,
+            batch_sizes=self.config.cudagraph_capture_sizes,
+            max_batch_size=0 if self.config.cudagraph_capture_sizes else self.config.max_num_seqs,
+        )
 
         per_worker_gpu_share = 1.0 / self.max_colocate_count
 
+        # Prepare the CUDA environment variables to pass to the Ray workers (including RayWorkerWrapper)
+        import os
+        cuda_env_vars = {}
+        cuda_env_var_names = [
+            "CUDA_HOME", "LD_LIBRARY_PATH", "CUDA_LIB_PATH",
+            "CUDA_VERSION", "CUDA_ROOT", "PATH"
+        ]
+        for env_var in cuda_env_var_names:
+            if env_var in os.environ:
+                cuda_env_vars[env_var] = os.environ[env_var]
+
         llm_kwargs = {
             "model": self.model_config.local_path,
             "backend": "pytorch",
             "orchestrator_type": "ray",
             "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
             "kv_cache_config": kv_cache_config,
+            "cuda_graph_config": cuda_graph_config,
             "max_seq_len": self.config.max_model_len,
             "max_batch_size": self.config.max_num_seqs,
             "max_num_tokens": self.config.max_num_batched_tokens,
@@ -141,26 +151,9 @@ async def launch_server(self):
             **engine_kwargs,
         }
 
-        if self.is_reward_model:
-            llm_kwargs.update(
-                {
-                    "cuda_graph_config": None,
-                    "disable_overlap_scheduler": True,
-                }
-            )
-        else:
-            llm_kwargs.update(
-                {
-                    "cuda_graph_config": CudaGraphConfig(
-                        enable_padding=True,
-                        batch_sizes=self.config.cudagraph_capture_sizes,
-                        max_batch_size=0 if self.config.cudagraph_capture_sizes else self.config.max_num_seqs,
-                    )
-                }
-            )
-
         self.llm = await AsyncLLM(**llm_kwargs)
-
+        if self.is_vlm_model:
+            self.visual_processor = hf_processor(self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code)
         trtllm_server = OpenAIServer(
             llm=self.llm,
             model=self.model_config.local_path,
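
The hunk collects `cuda_env_vars` but the visible diff does not show where they are consumed. A sketch of one plausible hand-off, passing the collected variables to a Ray worker via `runtime_env`; the `EchoWorker` actor and the hand-off itself are assumptions, not the PR's code:

```python
import os
import ray

def collect_cuda_env() -> dict:
    """Copy the CUDA-related variables from the diff off the driver process."""
    names = ["CUDA_HOME", "LD_LIBRARY_PATH", "CUDA_LIB_PATH", "CUDA_VERSION", "CUDA_ROOT", "PATH"]
    return {name: os.environ[name] for name in names if name in os.environ}

@ray.remote
class EchoWorker:
    def cuda_home(self):
        return os.environ.get("CUDA_HOME")

ray.init()
worker = EchoWorker.options(runtime_env={"env_vars": collect_cuda_env()}).remote()
print(ray.get(worker.cuda_home.remote()))  # worker inherits the driver's CUDA_HOME, if set
```
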
@@ -192,10 +185,39 @@ async def generate(
         sampling_params.update(self.sampling_args)
 
         trt_llm_sampling_params = SamplingParams(**sampling_params)
-        outputs = await self.llm.generate_async(
-            inputs=prompt_ids,
-            sampling_params=trt_llm_sampling_params,
-        )
+        if self.is_vlm_model:
+            multi_modal_inputs = self.visual_processor(
+                text=[""],  # placeholder, not actually used
+                images=image_data if image_data else None,
+                videos=video_data if video_data else None,
+                return_tensors="pt"
+            )
+
+            # Extract the multimodal-related tensors
+            mm_processor_kwargs = {
+                k: v for k, v in multi_modal_inputs.items()
+                if k not in ["input_ids", "attention_mask"]
+            }
+
+            tokens_prompt = {
+                "prompt_token_ids": prompt_ids,  # reuse the already-tokenized prompt
+                "multi_modal_data": {
+                    "mm_processor_kwargs": mm_processor_kwargs,  # tensors produced by the multimodal processor
+                },
+            }
+            if image_data:
+                tokens_prompt["multi_modal_data"]["image"] = image_data
+            if video_data:
+                tokens_prompt["multi_modal_data"]["video"] = video_data
+            outputs = await self.llm.generate_async(
+                inputs=tokens_prompt,
+                sampling_params=trt_llm_sampling_params,
+            )
+        else:
+            outputs = await self.llm.generate_async(
+                inputs=prompt_ids,
+                sampling_params=trt_llm_sampling_params,
+            )
 
         token_ids = outputs.outputs[0].token_ids
         log_probs = None
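
The filtering step above drops the processor's text fields because the prompt is already tokenized upstream. A standalone sketch of that split using a plain HF processor; the model name is illustrative, and the exact tensor keys vary by model family:

```python
from PIL import Image
from transformers import AutoProcessor

# Illustrative checkpoint; any HF vision-language processor behaves the same way.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
inputs = processor(text=[""], images=[Image.new("RGB", (224, 224))], return_tensors="pt")

# Keep only the vision tensors; input_ids/attention_mask come from the placeholder text.
mm_processor_kwargs = {k: v for k, v in inputs.items() if k not in ("input_ids", "attention_mask")}
print(sorted(mm_processor_kwargs))  # e.g. ['image_grid_thw', 'pixel_values'] for Qwen2-VL
```
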
@@ -205,19 +227,16 @@
 
     async def wake_up(self):
         if self.rollout_mode == RolloutMode.HYBRID:
-            # In hybrid mode, rollout is wake up in `update_weights`
-            raise ValueError(f"wake_up not support rollout_mode {self.rollout_mode}")
-        if self.rollout_mode == RolloutMode.COLOCATED:
+            # Call all workers to switch between trainer mode and rollout mode.
+            await asyncio.gather(*[worker.wake_up.remote() for worker in self.workers])
+        elif self.rollout_mode == RolloutMode.COLOCATED:
             await self.llm.resume(tags=ServerAdapter.get_full_tags())
         elif self.rollout_mode == RolloutMode.STANDALONE:
             logger.info("skip wake_up in standalone mode")
 
     async def sleep(self):
-        if not self.config.free_cache_engine:
-            return
-
         if self.rollout_mode == RolloutMode.HYBRID:
-            await self.llm.release(tags=ServerAdapter.get_full_tags())
+            await asyncio.gather(*[worker.sleep.remote() for worker in self.workers])
         elif self.rollout_mode == RolloutMode.COLOCATED:
             await self.llm.release(tags=ServerAdapter.get_full_tags())
         elif self.rollout_mode == RolloutMode.STANDALONE:
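
Both HYBRID branches now fan out one `.remote()` call per worker and await them together. A minimal runnable sketch of that pattern with a stand-in actor; Ray `ObjectRef`s are awaitable, so `asyncio.gather` drives all remote calls concurrently:

```python
import asyncio
import ray

@ray.remote
class RolloutWorker:
    """Stand-in for the real rollout workers; only the call pattern matters."""
    def wake_up(self):
        return "awake"

async def wake_all(workers):
    # One remote call per worker, all in flight at once.
    return await asyncio.gather(*[w.wake_up.remote() for w in workers])

ray.init()
workers = [RolloutWorker.remote() for _ in range(4)]
print(asyncio.run(wake_all(workers)))  # ['awake', 'awake', 'awake', 'awake']
```
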
@@ -329,17 +348,19 @@ async def launch_servers(self):
             else f"trtllm_server_reward_{self.replica_rank}"
         )
 
+        runtime_env_vars = {
+            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
+            "TLLM_NUMA_AWARE_WORKER_AFFINITY": "0"
+        }
         server = TRTLLMHttpServer.options(
             scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                 node_id=node_id,
                 soft=False,
             ),
-            runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
-            name=name,
+            runtime_env={"env_vars": runtime_env_vars},
         ).remote(
             config=self.config,
             model_config=self.model_config,
-            is_reward_model=self.is_reward_model,
             rollout_mode=self.rollout_mode,
             workers=self.workers,
             replica_rank=self.replica_rank,
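
The server actor is pinned to a specific node with a hard (non-soft) `NodeAffinitySchedulingStrategy` while the env vars travel in `runtime_env`. A self-contained sketch of the same options, assuming a toy actor and pinning to the driver's own node for illustration:

```python
import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

@ray.remote
class PinnedServer:
    def ping(self):
        return "ok"

ray.init()
node_id = ray.get_runtime_context().get_node_id()  # pin to this node, for illustration
server = PinnedServer.options(
    scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=node_id, soft=False),
    runtime_env={"env_vars": {
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
        "TLLM_NUMA_AWARE_WORKER_AFFINITY": "0",
    }},
).remote()
print(ray.get(server.ping.remote()))  # "ok", scheduled on the pinned node
```

With `soft=False` the actor fails to schedule rather than falling back to another node, which is the behavior a per-node HTTP server wants.
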