Skip to content

Commit 1f14a07

Browse files
author
hongchao
committed
support mtp
1 parent f6910d6 commit 1f14a07

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

checkpoint_engine/worker.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
1111

1212

13+
_WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]
14+
15+
1316
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
1417
func, args = handle
1518
list_args = list(args)
@@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
2932
offset: int
3033

3134

32-
def _extract_weights(
33-
payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
34-
) -> list[tuple[str, torch.Tensor]]:
35+
def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> _WEIGHTS_TYPE:
3536
assert buffer is not None
36-
weights: list[tuple[str, torch.Tensor]] = []
37+
weights: _WEIGHTS_TYPE = []
3738
for item in payload:
3839
shape = item["shape"]
3940
if isinstance(shape, list | tuple):
@@ -166,12 +167,25 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
166167
self.device = torch.device(f"npu:{self.local_rank}")
167168
assert self.device is not None
168169

170+
def _load_weights(weights: _WEIGHTS_TYPE):
171+
# Load main model weights
172+
self.model_runner.model.load_weights(weights)
173+
# Load drafter model weights if MTP/speculative decoding is enabled
174+
if getattr(self.model_runner, "speculative_config", None) is not None:
175+
self.model_runner.drafter.model.load_weights(weights=weights)
176+
177+
def _post_hook():
178+
process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
179+
# Also trigger drafter model's post processing if MTP is enabled
180+
if getattr(self.model_runner, "speculative_config", None) is not None:
181+
process_weights_after_loading(
182+
self.model_runner.drafter.model, self.model_config, self.device
183+
)
184+
169185
update_weights_from_ipc(
170186
self._zmq_ctx,
171187
zmq_handles[self._device_uuid],
172188
device_id=self.device.index,
173-
run=self.model_runner.model.load_weights,
174-
post_hook=lambda: process_weights_after_loading(
175-
self.model_runner.model, self.model_config, self.device
176-
),
189+
run=_load_weights,
190+
post_hook=_post_hook,
177191
)

0 commit comments

Comments
 (0)