Skip to content

Commit 34eb34b

Browse files
committed
Fix wrong device_type check, refine docstrings in worker.py
1 parent 22874d4 commit 34eb34b

File tree

1 file changed

+39
-7
lines changed

1 file changed

+39
-7
lines changed

checkpoint_engine/worker.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,45 @@ def update_weights_from_ipc(
8585

8686
class VllmColocateWorkerExtension:
8787
"""
88-
The class for vLLM's worker to inherit from, in the colocate setting.
89-
By defining an extension class, the code can work no matter what is
90-
the underlying worker class. This way, the code can be compatible
91-
with both vLLM V0 and V1.
92-
NOTE: we define this class in a separate module, and the main module
93-
should pass the full qualified name as `worker_extension_cls` argument.
88+
Worker extension for vLLM to update weights from checkpoint-engine.
89+
90+
This class provides a worker extension mechanism that allows vLLM workers to receive
91+
and apply weight updates from the checkpoint-engine via IPC (Inter-Process Communication).
92+
The methods in this worker extension will be injected into the vLLM worker class and
93+
are callable from the `collective_rpc` API, enabling seamless weight updates for both
94+
vLLM V0 and V1 versions.
95+
96+
Note:
97+
This class is defined in a separate module. The fully qualified name
98+
`checkpoint_engine.worker.VllmColocateWorkerExtension` should be passed as the
99+
`worker_extension_cls` argument when initializing the vLLM worker.
94100
"""
95101

96102
def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
103+
"""
104+
Update model weights from checkpoint-engine via IPC communication.
105+
106+
This method establishes a ZMQ connection to the checkpoint-engine and receives
107+
weight updates through a shared memory buffer. The update process includes:
108+
1. Receiving IPC handles to reconstruct shared memory tensors
109+
2. Extracting flattened metadata describing tensor weights in the shared memory tensor
110+
3. Loading weights into the model
111+
4. Post-processing weights after loading
112+
113+
Args:
114+
zmq_handles: A dictionary mapping device UUIDs to ZMQ socket handles.
115+
The device UUID is platform-specific:
116+
- For CUDA: UUID from `current_platform.get_device_uuid(self.device.index)`
117+
- For NPU: Format "NPU-{generated_uuid}"
118+
119+
Raises:
120+
ValueError: If the device type is not supported (not CUDA or NPU).
121+
AssertionError: If `self.device` is `None`.
122+
123+
Note:
124+
This method is called by vLLM's collective RPC mechanism. The ZMQ context
125+
is lazily initialized on first call and reused for subsequent updates.
126+
"""
97127
from vllm.model_executor.model_loader.utils import process_weights_after_loading
98128
from vllm.platforms import current_platform
99129

@@ -103,10 +133,12 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
103133
assert self.device is not None
104134
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
105135
self._zmq_ctx = zmq.Context()
106-
if current_platform.device_type == "gpu":
136+
if current_platform.device_type == "cuda":
107137
device_uuid = current_platform.get_device_uuid(self.device.index)
108138
elif current_platform.device_type == "npu":
109139
device_uuid = f"NPU-{npu_generate_uuid()}"
140+
else:
141+
raise ValueError(f"Unsupported device type: {current_platform.device_type}")
110142
update_weights_from_ipc(
111143
self._zmq_ctx,
112144
zmq_handles[device_uuid],

0 commit comments

Comments
 (0)