File tree Expand file tree Collapse file tree 1 file changed +7
-2
lines changed
Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Original file line number Diff line number Diff line change @@ -177,7 +177,7 @@ def _load_weights(weights: _WEIGHTS_TYPE):
177177 ):
178178 self .model_runner .drafter .model .load_weights (weights = weights )
179179
180- def _post_hook ():
180+ def _process_weight_after_loading ():
181181 process_weights_after_loading (self .model_runner .model , self .model_config , self .device )
182182 # Also trigger drafter model's post processing if MTP is enabled
183183 if (
@@ -188,10 +188,15 @@ def _post_hook():
188188 self .model_runner .drafter .model , self .model_config , self .device
189189 )
190190
191+ torch .cuda .empty_cache ()
192+
191193 update_weights_from_ipc (
192194 self ._zmq_ctx ,
193195 zmq_handles [self ._device_uuid ],
194196 device_id = self .device .index ,
195197 run = _load_weights ,
196- post_hook = _post_hook ,
198+ post_hook = _process_weight_after_loading ,
197199 )
200+
201+ if getattr (self , "_sampler_warmup" , None ) is not None :
202+ self ._sampler_warmup ()
You can’t perform that action at this time.
0 commit comments