@@ -52,9 +52,13 @@ def update_weights_from_ipc(
5252 zmq_handle : str ,
5353 device_id : int ,
5454 * ,
55- run : Callable [[list [tuple [str , torch .Tensor ]]], None ],
55+ weight_loader : Callable [[list [tuple [str , torch .Tensor ]]], None ],
56+ pre_hook : Callable [[], None ] | None = None ,
57+ process_weight_after_loading : Callable [[], None ] | None = None ,
5658 post_hook : Callable [[], None ] | None = None ,
5759):
60+ if pre_hook is not None :
61+ pre_hook ()
5862 socket = zmq_ctx .socket (zmq .REP )
5963 socket .connect (zmq_handle )
6064 buffer : torch .Tensor | None = None
@@ -74,14 +78,14 @@ def update_weights_from_ipc(
7478 while True :
7579 payload : list [FlattenedTensorMetadata ] | Exception | None = socket .recv_pyobj ()
7680 if payload is None : # done signal
77- if post_hook is not None :
78- post_hook ()
81+ if process_weight_after_loading is not None :
82+ process_weight_after_loading ()
7983 device_manager .device_module .synchronize ()
8084 socket .send (b"" )
8185 break
8286 if isinstance (payload , list ): # still updating weights
8387 try :
84- run (_extract_weights (payload , buffer ))
88+ weight_loader (_extract_weights (payload , buffer ))
8589 device_manager .device_module .synchronize ()
8690 socket .send (b"" )
8791 except Exception as e : # noqa: BLE001
@@ -102,6 +106,9 @@ def update_weights_from_ipc(
102106 gc .collect ()
103107 device_manager .device_module .empty_cache ()
104108
109+ if post_hook is not None :
110+ post_hook ()
111+
105112
106113class VllmColocateWorkerExtension :
107114 """
@@ -177,7 +184,7 @@ def _load_weights(weights: _WEIGHTS_TYPE):
177184 ):
178185 self .model_runner .drafter .model .load_weights (weights = weights )
179186
180- def _post_hook ():
187+ def _process_weight_after_loading ():
181188 process_weights_after_loading (self .model_runner .model , self .model_config , self .device )
182189 # Also trigger drafter model's post processing if MTP is enabled
183190 if (
@@ -188,10 +195,15 @@ def _post_hook():
188195 self .model_runner .drafter .model , self .model_config , self .device
189196 )
190197
198+ def _pre_hook ():  # handed to update_weights_from_ipc as its pre_hook callback
199+ torch .cuda .empty_cache ()  # NOTE(review): CUDA-only call, whereas update_weights_from_ipc clears its cache via device_manager.device_module.empty_cache() -- confirm non-CUDA backends are not expected here
200+
191201 update_weights_from_ipc (
192202 self ._zmq_ctx ,
193203 zmq_handles [self ._device_uuid ],
194204 device_id = self .device .index ,
195- run = _load_weights ,
196- post_hook = _post_hook ,
205+ pre_hook = _pre_hook ,  # called once on entry, before the ZMQ socket is created
206+ weight_loader = _load_weights ,  # called for every list payload received, with the extracted weights
207+ process_weight_after_loading = _process_weight_after_loading ,  # called once when the None "done" payload arrives, before the final synchronize and reply
208+ post_hook = getattr (self , "_sampler_warmup" , None ),  # None when this object has no _sampler_warmup attribute; when set, it is invoked after the gc.collect()/empty_cache() cleanup
197209 )
0 commit comments