1010from checkpoint_engine .device_utils import DeviceManager , npu_generate_uuid
1111
1212
13+ _WEIGHTS_TYPE = list [tuple [str , torch .Tensor ]]
14+
15+
1316def _rebuild_ipc (handle : tuple [Callable , tuple ], device_id : int | None = None ) -> torch .Tensor :
1417 func , args = handle
1518 list_args = list (args )
@@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
2932 offset : int
3033
3134
32- def _extract_weights (
33- payload : list [FlattenedTensorMetadata ], buffer : torch .Tensor
34- ) -> list [tuple [str , torch .Tensor ]]:
35+ def _extract_weights (payload : list [FlattenedTensorMetadata ], buffer : torch .Tensor ) -> _WEIGHTS_TYPE :
3536 assert buffer is not None
36- weights : list [ tuple [ str , torch . Tensor ]] = []
37+ weights : _WEIGHTS_TYPE = []
3738 for item in payload :
3839 shape = item ["shape" ]
3940 if isinstance (shape , list | tuple ):
@@ -166,12 +167,25 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
166167 self .device = torch .device (f"npu:{ self .local_rank } " )
167168 assert self .device is not None
168169
170+ def _load_weights (weights : _WEIGHTS_TYPE ):
171+ # Load main model weights
172+ self .model_runner .model .load_weights (weights )
173+ # Load drafter model weights if MTP/speculative decoding is enabled
174+ if getattr (self .model_runner , "speculative_config" , None ) is not None :
175+ self .model_runner .drafter .model .load_weights (weights = weights )
176+
177+ def _post_hook ():
178+ process_weights_after_loading (self .model_runner .model , self .model_config , self .device )
179+ # Also trigger drafter model's post processing if MTP is enabled
180+ if getattr (self .model_runner , "speculative_config" , None ) is not None :
181+ process_weights_after_loading (
182+ self .model_runner .drafter .model , self .model_config , self .device
183+ )
184+
169185 update_weights_from_ipc (
170186 self ._zmq_ctx ,
171187 zmq_handles [self ._device_uuid ],
172188 device_id = self .device .index ,
173- run = self .model_runner .model .load_weights ,
174- post_hook = lambda : process_weights_after_loading (
175- self .model_runner .model , self .model_config , self .device
176- ),
189+ run = _load_weights ,
190+ post_hook = _post_hook ,
177191 )
0 commit comments