Skip to content

Commit abc1650

Browse files
author
hongchao
committed
support mtp
1 parent f6910d6 commit abc1650

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

checkpoint_engine/worker.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
1111

12+
_WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]
1213

1314
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
1415
func, args = handle
@@ -31,9 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
3132

3233
def _extract_weights(
3334
payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
34-
) -> list[tuple[str, torch.Tensor]]:
35+
) -> _WEIGHTS_TYPE:
3536
assert buffer is not None
36-
weights: list[tuple[str, torch.Tensor]] = []
37+
weights: _WEIGHTS_TYPE = []
3738
for item in payload:
3839
shape = item["shape"]
3940
if isinstance(shape, list | tuple):
@@ -166,11 +167,18 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
166167
self.device = torch.device(f"npu:{self.local_rank}")
167168
assert self.device is not None
168169

170+
def _load_weights(weights: _WEIGHTS_TYPE):
171+
# Load main model weights
172+
self.model_runner.model.load_weights(weights)
173+
# Load drafter model weights if MTP/speculative decoding is enabled
174+
if self.model_runner.get("use_spec_decode", False):
175+
self.model_runner.drafter.model.load_weights(weights=weights)
176+
169177
update_weights_from_ipc(
170178
self._zmq_ctx,
171179
zmq_handles[self._device_uuid],
172180
device_id=self.device.index,
173-
run=self.model_runner.model.load_weights,
181+
run=_load_weights,
174182
post_hook=lambda: process_weights_after_loading(
175183
self.model_runner.model, self.model_config, self.device
176184
),

0 commit comments

Comments
 (0)