From 1f14a07fa749aa04462be40c8dc5ebc7a819deb4 Mon Sep 17 00:00:00 2001 From: hongchao Date: Tue, 27 Jan 2026 08:36:59 +0000 Subject: [PATCH 1/3] support mtp --- checkpoint_engine/worker.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index c69815c..53d9991 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -10,6 +10,9 @@ from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid +_WEIGHTS_TYPE = list[tuple[str, torch.Tensor]] + + def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor: func, args = handle list_args = list(args) @@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict): offset: int -def _extract_weights( - payload: list[FlattenedTensorMetadata], buffer: torch.Tensor -) -> list[tuple[str, torch.Tensor]]: +def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> _WEIGHTS_TYPE: assert buffer is not None - weights: list[tuple[str, torch.Tensor]] = [] + weights: _WEIGHTS_TYPE = [] for item in payload: shape = item["shape"] if isinstance(shape, list | tuple): @@ -166,12 +167,25 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]): self.device = torch.device(f"npu:{self.local_rank}") assert self.device is not None + def _load_weights(weights: _WEIGHTS_TYPE): + # Load main model weights + self.model_runner.model.load_weights(weights) + # Load drafter model weights if MTP/speculative decoding is enabled + if getattr(self.model_runner, "speculative_config", None) is not None: + self.model_runner.drafter.model.load_weights(weights=weights) + + def _post_hook(): + process_weights_after_loading(self.model_runner.model, self.model_config, self.device) + # Also trigger drafter model's post processing if MTP is enabled + if getattr(self.model_runner, "speculative_config", None) is not None: + process_weights_after_loading( + self.model_runner.drafter.model, self.model_config, self.device + ) + update_weights_from_ipc( self._zmq_ctx, zmq_handles[self._device_uuid], device_id=self.device.index, - run=self.model_runner.model.load_weights, - post_hook=lambda: process_weights_after_loading( - self.model_runner.model, self.model_config, self.device - ), + run=_load_weights, + post_hook=_post_hook, ) From 8c0a54869bee30b1dfb89ef905822c37ba953336 Mon Sep 17 00:00:00 2001 From: hongchao Date: Wed, 28 Jan 2026 09:33:13 +0000 Subject: [PATCH 2/3] fix review --- checkpoint_engine/worker.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index 53d9991..fdf1c0b 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -171,13 +171,17 @@ def _load_weights(weights: _WEIGHTS_TYPE): # Load main model weights self.model_runner.model.load_weights(weights) # Load drafter model weights if MTP/speculative decoding is enabled - if getattr(self.model_runner, "speculative_config", None) is not None: + if hasattr(self.model_runner, "drafter") and hasattr( + self.model_runner.drafter, "model" + ): self.model_runner.drafter.model.load_weights(weights=weights) def _post_hook(): process_weights_after_loading(self.model_runner.model, self.model_config, self.device) # Also trigger drafter model's post processing if MTP is enabled - if getattr(self.model_runner, "speculative_config", None) is not None: + if hasattr(self.model_runner, "drafter") and hasattr( + self.model_runner.drafter, "model" + ): process_weights_after_loading( self.model_runner.drafter.model, self.model_config, self.device ) From 1fa937bb57da263c8f196386508ec5f1c62f634c Mon Sep 17 00:00:00 2001 From: hongchao Date: Wed, 28 Jan 2026 09:35:53 +0000 Subject: [PATCH 3/3] fix review --- checkpoint_engine/worker.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index fdf1c0b..3979cda 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -171,16 +171,18 @@ def _load_weights(weights: _WEIGHTS_TYPE): # Load main model weights self.model_runner.model.load_weights(weights) # Load drafter model weights if MTP/speculative decoding is enabled - if hasattr(self.model_runner, "drafter") and hasattr( - self.model_runner.drafter, "model" + if ( + getattr(self.model_runner, "drafter", None) is not None + and getattr(self.model_runner.drafter, "model", None) is not None ): self.model_runner.drafter.model.load_weights(weights=weights) def _post_hook(): process_weights_after_loading(self.model_runner.model, self.model_config, self.device) # Also trigger drafter model's post processing if MTP is enabled - if hasattr(self.model_runner, "drafter") and hasattr( - self.model_runner.drafter, "model" + if ( + getattr(self.model_runner, "drafter", None) is not None + and getattr(self.model_runner.drafter, "model", None) is not None ): process_weights_after_loading( self.model_runner.drafter.model, self.model_config, self.device