Skip to content

Commit dcaacfe

Browse files
committed
Fix partial-load problem; add VLM support for the TRT-LLM rollout
1 parent 15e28de commit dcaacfe

File tree

3 files changed

+168
-28
lines changed

3 files changed

+168
-28
lines changed

verl/workers/rollout/trtllm_rollout/trtllm_async_server.py

Lines changed: 63 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def launch_server(self):
125125
"model": self.model_config.local_path,
126126
"backend": "pytorch",
127127
"orchestrator_type": "ray",
128-
"ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
128+
"ray_worker_extension_cls": "verl.workers.rollout.trtllm_rollout.trtllm_worker_extension.WorkerExtension",
129129
"kv_cache_config": kv_cache_config,
130130
"max_seq_len": self.config.max_model_len,
131131
"max_batch_size": self.config.max_num_seqs,
@@ -159,18 +159,45 @@ async def launch_server(self):
159159
}
160160
)
161161

162-
self.llm = await AsyncLLM(**llm_kwargs)
163-
164-
trtllm_server = OpenAIServer(
165-
llm=self.llm,
166-
model=self.model_config.local_path,
167-
tool_parser=None,
168-
server_role=None,
169-
metadata_server_cfg=None,
170-
)
162+
if self.is_vlm_model:
163+
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
164+
multimodal_config = MultimodalServerConfig(
165+
media_io_kwargs={
166+
"image": {
167+
"format": "pil",
168+
"device": "cpu",
169+
},
170+
"video": {
171+
"num_frames": 8,
172+
"fps": 30,
173+
"format": "pil",
174+
"device": "cpu",
175+
},
176+
}
177+
)
178+
self.llm = await AsyncLLM(**llm_kwargs)
179+
trtllm_server = OpenAIServer(
180+
llm=self.llm,
181+
model=self.model_config.local_path,
182+
tool_parser=None,
183+
server_role=None,
184+
metadata_server_cfg=None,
185+
multimodal_server_config=multimodal_config,
186+
)
187+
else:
188+
self.llm = await AsyncLLM(**llm_kwargs)
189+
trtllm_server = OpenAIServer(
190+
llm=self.llm,
191+
model=self.model_config.local_path,
192+
tool_parser=None,
193+
server_role=None,
194+
metadata_server_cfg=None,
195+
)
196+
171197
app = trtllm_server.app
172198
self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address)
173199

200+
@resume_on_abort
174201
async def generate(
175202
self,
176203
prompt_ids: list[int],
@@ -179,11 +206,7 @@ async def generate(
179206
image_data: Optional[list[Any]] = None,
180207
video_data: Optional[list[Any]] = None,
181208
) -> TokenOutput:
182-
"""Generate sequence with token-in-token-out."""
183-
assert image_data is None and video_data is None, "Multimodality is not yet supported in TRTLLMHttpServer."
184-
185209
from tensorrt_llm.llmapi import SamplingParams
186-
187210
max_tokens = min(self.config.response_length, self.config.max_model_len - len(prompt_ids))
188211
sampling_params["max_tokens"] = max_tokens
189212
sampling_params["logprobs"] = 1 if sampling_params.pop("logprobs", False) else None
@@ -192,15 +215,34 @@ async def generate(
192215
sampling_params.update(self.sampling_args)
193216

194217
trt_llm_sampling_params = SamplingParams(**sampling_params)
195-
outputs = await self.llm.generate_async(
196-
inputs=prompt_ids,
197-
sampling_params=trt_llm_sampling_params,
198-
)
199-
218+
if self.is_vlm_model:
219+
if image_data or video_data:
220+
input_dict = {
221+
"prompt_token_ids": prompt_ids,
222+
"multi_modal_data": {},
223+
}
224+
if image_data:
225+
input_dict["multi_modal_data"]["image"] = image_data
226+
if video_data:
227+
input_dict["multi_modal_data"]["video"] = video_data
228+
outputs = await self.llm.generate_async(
229+
inputs=input_dict,
230+
sampling_params=trt_llm_sampling_params,
231+
)
232+
else:
233+
outputs = await self.llm.generate_async(
234+
inputs=prompt_ids,
235+
sampling_params=trt_llm_sampling_params,
236+
)
237+
else:
238+
outputs = await self.llm.generate_async(
239+
inputs=prompt_ids,
240+
sampling_params=trt_llm_sampling_params,
241+
)
200242
token_ids = outputs.outputs[0].token_ids
201243
log_probs = None
202-
if trt_llm_sampling_params.logprobs is not None:
203-
log_probs = [list(d.values())[0].logprob for d in outputs.outputs[0].logprobs]
244+
if outputs.outputs[0].logprobs is not None:
245+
log_probs = [logprobs[token_ids[i]].logprob for i, logprobs in enumerate(outputs.outputs[0].logprobs)]
204246
return TokenOutput(token_ids=token_ids, log_probs=log_probs)
205247

206248
async def wake_up(self):

verl/workers/rollout/trtllm_rollout/trtllm_rollout.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ def __init__(
281281
self.is_leader_rank = None
282282
self.replica_rank = None
283283
self.is_dp_rank = None
284+
self._supports_partial_loading = None
284285

285286
# hybrid mode
286287
if self.device_mesh is not None:
@@ -312,6 +313,21 @@ def __init__(
312313

313314
self.node_ip = ray.util.get_node_ip_address().strip("[]")
314315

316+
async def get_supports_partial_loading(self) -> bool:
    """Query and cache whether the model supports partial weight loading.

    The answer is fetched once from the rollout server actor and memoized in
    ``self._supports_partial_loading``; any failure to reach the actor is
    treated as "not supported" so weight updates fall back to full loading.
    """
    cached = self._supports_partial_loading
    if cached is not None:
        # Already resolved on a previous call — skip the remote round-trip.
        return cached

    await self._init_server_adapter()
    try:
        supported = await self.server_actor.supports_partial_loading.remote()
    except Exception as e:
        # Best-effort probe: default to the conservative answer on any error.
        logger.warning(f"Failed to query partial loading support: {e}, defaulting to False")
        supported = False
    self._supports_partial_loading = supported

    logger.info(f"Model supports partial loading: {self._supports_partial_loading}")
    return self._supports_partial_loading
330+
315331
async def _init_server_adapter(self):
316332
if self._adapter is not None:
317333
return
@@ -405,16 +421,21 @@ async def flush():
405421
await self.update_weights_from_ipc_handles(serialized_device_handles)
406422
cur_available_bytes = total_available_bytes
407423
cur_handles = []
424+
425+
# Query if model supports partial loading
426+
supports_partial_loading = await self.get_supports_partial_loading()
408427

409428
for name, param in weights:
410-
size_in_bytes = param.element_size() * param.numel()
411-
if size_in_bytes > cur_available_bytes:
412-
await flush()
429+
if supports_partial_loading:
430+
size_in_bytes = param.element_size() * param.numel()
431+
if size_in_bytes > cur_available_bytes:
432+
await flush()
433+
434+
assert cur_available_bytes >= size_in_bytes, (
435+
f"cur_available_bytes: {cur_available_bytes:,} size_in_bytes: {size_in_bytes:,} name: {name}"
436+
)
437+
cur_available_bytes -= size_in_bytes
413438

414-
assert cur_available_bytes >= size_in_bytes, (
415-
f"cur_available_bytes: {cur_available_bytes:,} size_in_bytes: {size_in_bytes:,} name: {name}"
416-
)
417-
cur_available_bytes -= size_in_bytes
418439
handle = reduce_tensor(param.detach())
419440
cur_handles.append((name, handle))
420441

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import base64
2+
import inspect
3+
import pickle
4+
from typing import Optional
5+
6+
from tensorrt_llm._ray_utils import control_action_decorator
7+
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer
8+
from tensorrt_llm._torch.utils import get_device_uuid
9+
from tensorrt_llm.logger import logger
10+
11+
12+
class WorkerExtension:
    """Ray worker-side hooks for refitting a running TRT-LLM engine.

    Installed on each TRT-LLM Ray worker via the ``ray_worker_extension_cls``
    LLM argument so the trainer process can push updated weights into the
    live engine through CUDA IPC handles and query engine capabilities.
    """

    def __init__(self):
        pass

    def _check_partial_loading(self) -> bool:
        """Return True iff ``model.load_weights`` accepts ``allow_partial_loading``.

        Raises whatever the engine/model attribute access raises; callers decide
        whether to swallow the error.
        """
        model = self.engine.model_engine.model
        return "allow_partial_loading" in inspect.getfullargspec(model.load_weights).args

    @control_action_decorator
    def supports_partial_loading(self) -> bool:
        """Check if the model supports partial weight loading."""
        try:
            return self._check_partial_loading()
        except Exception as e:
            # Capability probe is best-effort: report "unsupported" on any failure.
            logger.warning(f"Failed to check partial loading support: {e}")
            return False

    @control_action_decorator
    def update_weights(self, ipc_handles: Optional[dict] = None):
        """Load new weights into the engine, or finalize a reload cycle.

        Args:
            ipc_handles: Mapping of device UUID -> base64-encoded pickled
                weights for that device. When ``None``, this call instead
                finalizes the reload (post-load hooks, MoE load balancer,
                prefix-cache reset).

        Raises:
            Exception: Re-raises any error encountered, after logging it.
        """
        try:
            model = self.engine.model_engine.model

            # Strip old weights exactly once per reload cycle; the sentinel
            # attribute is cleared again in the finalize branch below.
            if not hasattr(model, "first_pre_reload_weights"):
                for module in model.modules():
                    if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False):
                        module.pre_reload_weights()
                setattr(model, "first_pre_reload_weights", True)

            if ipc_handles is not None:
                device_uuid = get_device_uuid()
                handles = ipc_handles.get(device_uuid, None)
                if handles is not None:
                    # NOTE(security): pickle.loads on IPC payload — the sender is
                    # the co-located trainer (trusted), but never route untrusted
                    # data through this path.
                    weights = pickle.loads(base64.b64decode(handles))
                    # Single reload call; partial loading is enabled only when the
                    # model's load_weights signature advertises support for it.
                    self.engine.model_engine.model_loader.reload(
                        model, weights, allow_partial_loading=self._check_partial_loading()
                    )
            else:
                # Finalize: run post-load hooks on every module that still has weights.
                for module in model.modules():
                    if hasattr(module, "process_weights_after_loading") and not getattr(
                        module, "_weights_removed", False
                    ):
                        module.process_weights_after_loading()
                    if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False):
                        module.post_load_weights()
                moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None)
                if isinstance(moe_load_balancer, MoeLoadBalancer):
                    moe_load_balancer.register_weight_slots_after_to_cuda()
                    logger.info("moe_load_balancer finalizing model...")
                    moe_load_balancer.finalize_model()
                    logger.info("moe_load_balancer finalize model done")
                # New weights invalidate any cached KV prefixes.
                self.engine.reset_prefix_cache()
                delattr(model, "first_pre_reload_weights")

        except Exception:
            # logger.exception records the traceback; bare raise preserves it.
            logger.exception("Encountered an error in update_weights")
            raise

0 commit comments

Comments
 (0)