Skip to content

Commit a5b5550

Browse files
committed
Update
1 parent b63fe7f commit a5b5550

File tree

1 file changed

+72
-51
lines changed

1 file changed

+72
-51
lines changed

verl/workers/rollout/trtllm_rollout/trtllm_async_server.py

Lines changed: 72 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2026 Bytedance Ltd. and/or its affiliates
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -25,11 +25,13 @@
2525

2626
from verl.single_controller.ray import RayClassWithInitArgs, SubRayResourcePool
2727
from verl.utils.config import omega_conf_to_dataclass
28+
from verl.utils.device import is_cuda_available
2829
from verl.utils.net_utils import is_valid_ipv6_address
2930
from verl.workers.config import HFModelConfig, RolloutConfig
3031
from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
3132
from verl.workers.rollout.trtllm_rollout.trtllm_rollout import ServerAdapter
32-
from verl.workers.rollout.utils import get_max_position_embeddings, run_unvicorn
33+
from verl.workers.rollout.utils import run_unvicorn
34+
from verl.utils.tokenizer import hf_processor
3335

3436
logger = logging.getLogger(__file__)
3537
logger.setLevel(logging.INFO)
@@ -42,7 +44,6 @@ class TRTLLMHttpServer:
4244
Args:
4345
config (DictConfig): full config.
4446
model_config (HFModelConfig): model config.
45-
is_reward_model (bool): whether this is a reward model.
4647
rollout_mode (RolloutMode): rollout mode.
4748
workers (list[ActorHandle]): list of rollout workers.
4849
replica_rank (int): replica rank, a replica may contain multiple nodes.
@@ -55,7 +56,6 @@ def __init__(
5556
self,
5657
config: RolloutConfig,
5758
model_config: HFModelConfig,
58-
is_reward_model: bool,
5959
rollout_mode: RolloutMode,
6060
workers: list[ActorHandle],
6161
replica_rank: int,
@@ -64,20 +64,10 @@ def __init__(
6464
bundle_indices: list[list[int]] = None,
6565
):
6666
os.environ["TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL"] = "1"
67-
assert torch.cuda.is_available(), "TRTLLM http server should run on GPU node"
68-
67+
os.system("nvcc --version")
6968
self.config: RolloutConfig = omega_conf_to_dataclass(config)
7069
self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)
71-
self.is_reward_model = is_reward_model
72-
max_position_embeddings = get_max_position_embeddings(self.model_config.hf_config)
73-
if self.config.max_model_len is None:
74-
self.config.max_model_len = max_position_embeddings
75-
else:
76-
if self.config.max_model_len > max_position_embeddings:
77-
raise ValueError(
78-
f"max_model_len ({self.config.max_model_len}) should be less than or equal to "
79-
f"max_position_embeddings ({max_position_embeddings})"
80-
)
70+
self.config.max_model_len = self.config.prompt_length + self.config.response_length
8171
self.rollout_mode = rollout_mode
8272
self.workers = workers
8373
self.replica_rank = replica_rank
@@ -88,12 +78,15 @@ def __init__(
8878
if self.rollout_mode != RolloutMode.HYBRID and self.config.load_format == "dummy":
8979
logger.warning(f"rollout mode is {self.rollout_mode}, load_format is dummy, set to auto")
9080
self.config.load_format = "auto"
91-
81+
self.is_vlm_model = (
82+
self.model_config.hf_config is not None
83+
and hasattr(self.model_config.hf_config, "vision_config")
84+
)
9285
# used for http server
9386
self._server_address = ray.util.get_node_ip_address().strip("[]")
9487
self._server_port = None
9588

96-
logger.info(f"TRTLLMHttpServer, replica_rank: {self.replica_rank}")
89+
logger.info(f"TRTLLMHttpServer, replica_rank: {self.replica_rank}, ")
9790

9891
self.sampling_args = {
9992
"detokenize": False,
@@ -118,15 +111,32 @@ async def launch_server(self):
118111
enable_block_reuse=True,
119112
free_gpu_memory_fraction=self.config.gpu_memory_utilization,
120113
)
114+
cuda_graph_config = CudaGraphConfig(
115+
enable_padding=True,
116+
batch_sizes=self.config.cudagraph_capture_sizes,
117+
max_batch_size=0 if self.config.cudagraph_capture_sizes else self.config.max_num_seqs,
118+
)
121119

122120
per_worker_gpu_share = 1.0 / self.max_colocate_count
123121

122+
# 准备 CUDA 环境变量,传递给 Ray workers(包括 RayWorkerWrapper)
123+
import os
124+
cuda_env_vars = {}
125+
cuda_env_var_names = [
126+
"CUDA_HOME", "LD_LIBRARY_PATH", "CUDA_LIB_PATH",
127+
"CUDA_VERSION", "CUDA_ROOT", "PATH"
128+
]
129+
for env_var in cuda_env_var_names:
130+
if env_var in os.environ:
131+
cuda_env_vars[env_var] = os.environ[env_var]
132+
124133
llm_kwargs = {
125134
"model": self.model_config.local_path,
126135
"backend": "pytorch",
127136
"orchestrator_type": "ray",
128137
"ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
129138
"kv_cache_config": kv_cache_config,
139+
"cuda_graph_config": cuda_graph_config,
130140
"max_seq_len": self.config.max_model_len,
131141
"max_batch_size": self.config.max_num_seqs,
132142
"max_num_tokens": self.config.max_num_batched_tokens,
@@ -141,26 +151,9 @@ async def launch_server(self):
141151
**engine_kwargs,
142152
}
143153

144-
if self.is_reward_model:
145-
llm_kwargs.update(
146-
{
147-
"cuda_graph_config": None,
148-
"disable_overlap_scheduler": True,
149-
}
150-
)
151-
else:
152-
llm_kwargs.update(
153-
{
154-
"cuda_graph_config": CudaGraphConfig(
155-
enable_padding=True,
156-
batch_sizes=self.config.cudagraph_capture_sizes,
157-
max_batch_size=0 if self.config.cudagraph_capture_sizes else self.config.max_num_seqs,
158-
)
159-
}
160-
)
161-
162154
self.llm = await AsyncLLM(**llm_kwargs)
163-
155+
if self.is_vlm_model:
156+
self.visual_processor = hf_processor(self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code)
164157
trtllm_server = OpenAIServer(
165158
llm=self.llm,
166159
model=self.model_config.local_path,
@@ -192,10 +185,39 @@ async def generate(
192185
sampling_params.update(self.sampling_args)
193186

194187
trt_llm_sampling_params = SamplingParams(**sampling_params)
195-
outputs = await self.llm.generate_async(
196-
inputs=prompt_ids,
197-
sampling_params=trt_llm_sampling_params,
198-
)
188+
if self.is_vlm_model:
189+
multi_modal_inputs = self.visual_processor(
190+
text=[""], # 占位符,实际不使用
191+
images=image_data if image_data else None,
192+
videos=video_data if video_data else None,
193+
return_tensors="pt"
194+
)
195+
196+
# 提取多模态相关的 tensor
197+
mm_processor_kwargs = {
198+
k: v for k, v in multi_modal_inputs.items()
199+
if k not in ["input_ids", "attention_mask"]
200+
}
201+
202+
tokens_prompt = {
203+
"prompt_token_ids": prompt_ids, # 使用已有的 prompt_ids
204+
"multi_modal_data": {
205+
"mm_processor_kwargs": mm_processor_kwargs, # 多模态处理后的 tensor
206+
},
207+
}
208+
if image_data:
209+
tokens_prompt["multi_modal_data"]["image"] = image_data
210+
if video_data:
211+
tokens_prompt["multi_modal_data"]["video"] = video_data
212+
outputs = await self.llm.generate_async(
213+
inputs=tokens_prompt,
214+
sampling_params=trt_llm_sampling_params,
215+
)
216+
else:
217+
outputs = await self.llm.generate_async(
218+
inputs=prompt_ids,
219+
sampling_params=trt_llm_sampling_params,
220+
)
199221

200222
token_ids = outputs.outputs[0].token_ids
201223
log_probs = None
@@ -205,19 +227,16 @@ async def generate(
205227

206228
async def wake_up(self):
207229
if self.rollout_mode == RolloutMode.HYBRID:
208-
# In hybrid mode, rollout is wake up in `update_weights`
209-
raise ValueError(f"wake_up not support rollout_mode {self.rollout_mode}")
210-
if self.rollout_mode == RolloutMode.COLOCATED:
230+
# Call all workers to switch between trainer mode and rollout mode.
231+
await asyncio.gather(*[worker.wake_up.remote() for worker in self.workers])
232+
elif self.rollout_mode == RolloutMode.COLOCATED:
211233
await self.llm.resume(tags=ServerAdapter.get_full_tags())
212234
elif self.rollout_mode == RolloutMode.STANDALONE:
213235
logger.info("skip wake_up in standalone mode")
214236

215237
async def sleep(self):
216-
if not self.config.free_cache_engine:
217-
return
218-
219238
if self.rollout_mode == RolloutMode.HYBRID:
220-
await self.llm.release(tags=ServerAdapter.get_full_tags())
239+
await asyncio.gather(*[worker.sleep.remote() for worker in self.workers])
221240
elif self.rollout_mode == RolloutMode.COLOCATED:
222241
await self.llm.release(tags=ServerAdapter.get_full_tags())
223242
elif self.rollout_mode == RolloutMode.STANDALONE:
@@ -329,17 +348,19 @@ async def launch_servers(self):
329348
else f"trtllm_server_reward_{self.replica_rank}"
330349
)
331350

351+
runtime_env_vars = {
352+
"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
353+
"TLLM_NUMA_AWARE_WORKER_AFFINITY": "0"
354+
}
332355
server = TRTLLMHttpServer.options(
333356
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
334357
node_id=node_id,
335358
soft=False,
336359
),
337-
runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
338-
name=name,
360+
runtime_env={"env_vars": runtime_env_vars},
339361
).remote(
340362
config=self.config,
341363
model_config=self.model_config,
342-
is_reward_model=self.is_reward_model,
343364
rollout_mode=self.rollout_mode,
344365
workers=self.workers,
345366
replica_rank=self.replica_rank,

0 commit comments

Comments
 (0)