99
1010from checkpoint_engine .device_utils import DeviceManager , npu_generate_uuid
1111
12+ _WEIGHTS_TYPE = list [tuple [str , torch .Tensor ]]
1213
1314def _rebuild_ipc (handle : tuple [Callable , tuple ], device_id : int | None = None ) -> torch .Tensor :
1415 func , args = handle
@@ -31,9 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
3132
3233def _extract_weights (
3334 payload : list [FlattenedTensorMetadata ], buffer : torch .Tensor
34- ) -> list [ tuple [ str , torch . Tensor ]] :
35+ ) -> _WEIGHTS_TYPE :
3536 assert buffer is not None
36- weights : list [ tuple [ str , torch . Tensor ]] = []
37+ weights : _WEIGHTS_TYPE = []
3738 for item in payload :
3839 shape = item ["shape" ]
3940 if isinstance (shape , list | tuple ):
@@ -166,11 +167,18 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
166167 self .device = torch .device (f"npu:{ self .local_rank } " )
167168 assert self .device is not None
168169
def _load_weights(weights: _WEIGHTS_TYPE):
    """Load a batch of received weights into the main model and, if enabled, the drafter.

    Used as the ``run`` callback handed to ``update_weights_from_ipc``.

    Args:
        weights: ``(name, tensor)`` pairs to load.
    """
    # Load main model weights
    self.model_runner.model.load_weights(weights)
    # Load drafter model weights if MTP/speculative decoding is enabled.
    # The runner is accessed by attribute (``.model``, ``.drafter``), so the flag is
    # read with getattr; a dict-style ``.get()`` on it would raise AttributeError.
    if getattr(self.model_runner, "use_spec_decode", False):
        self.model_runner.drafter.model.load_weights(weights=weights)
176+
169177 update_weights_from_ipc (
170178 self ._zmq_ctx ,
171179 zmq_handles [self ._device_uuid ],
172180 device_id = self .device .index ,
173- run = self . model_runner . model . load_weights ,
181+ run = _load_weights ,
174182 post_hook = lambda : process_weights_after_loading (
175183 self .model_runner .model , self .model_config , self .device
176184 ),
0 commit comments