Skip to content

Commit 2aea608

Browse files
youzhedianhongchao
authored and committed
fix ps alloc err & avoid mem fragmentation
1 parent 009082d commit 2aea608

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

checkpoint_engine/worker.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,13 @@ def update_weights_from_ipc(
5252
zmq_handle: str,
5353
device_id: int,
5454
*,
55-
run: Callable[[list[tuple[str, torch.Tensor]]], None],
55+
weight_loader: Callable[[list[tuple[str, torch.Tensor]]], None],
56+
pre_hook: Callable[[], None] | None = None,
57+
process_weight_after_loading: Callable[[], None] | None = None,
5658
post_hook: Callable[[], None] | None = None,
5759
):
60+
if pre_hook is not None:
61+
pre_hook()
5862
socket = zmq_ctx.socket(zmq.REP)
5963
socket.connect(zmq_handle)
6064
buffer: torch.Tensor | None = None
@@ -74,14 +78,14 @@ def update_weights_from_ipc(
7478
while True:
7579
payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
7680
if payload is None: # done signal
77-
if post_hook is not None:
78-
post_hook()
81+
if process_weight_after_loading is not None:
82+
process_weight_after_loading()
7983
device_manager.device_module.synchronize()
8084
socket.send(b"")
8185
break
8286
if isinstance(payload, list): # still updating weights
8387
try:
84-
run(_extract_weights(payload, buffer))
88+
weight_loader(_extract_weights(payload, buffer))
8589
device_manager.device_module.synchronize()
8690
socket.send(b"")
8791
except Exception as e: # noqa: BLE001
@@ -102,6 +106,9 @@ def update_weights_from_ipc(
102106
gc.collect()
103107
device_manager.device_module.empty_cache()
104108

109+
if post_hook is not None:
110+
post_hook()
111+
105112

106113
class VllmColocateWorkerExtension:
107114
"""
@@ -177,7 +184,7 @@ def _load_weights(weights: _WEIGHTS_TYPE):
177184
):
178185
self.model_runner.drafter.model.load_weights(weights=weights)
179186

180-
def _post_hook():
187+
def _process_weight_after_loading():
181188
process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
182189
# Also trigger drafter model's post processing if MTP is enabled
183190
if (
@@ -188,10 +195,15 @@ def _post_hook():
188195
self.model_runner.drafter.model, self.model_config, self.device
189196
)
190197

198+
def _pre_hook():
199+
torch.cuda.empty_cache()
200+
191201
update_weights_from_ipc(
192202
self._zmq_ctx,
193203
zmq_handles[self._device_uuid],
194204
device_id=self.device.index,
195-
run=_load_weights,
196-
post_hook=_post_hook,
205+
pre_hook=_pre_hook,
206+
weight_loader=_load_weights,
207+
process_weight_after_loading=_process_weight_after_loading,
208+
post_hook=getattr(self, "_sampler_warmup", None),
197209
)

0 commit comments

Comments (0)