@@ -85,15 +85,45 @@ def update_weights_from_ipc(
8585
8686class VllmColocateWorkerExtension :
8787 """
88- The class for vLLM's worker to inherit from, in the colocate setting.
89- By defining an extension class, the code can work no matter what is
90- the underlying worker class. This way, the code can be compatible
91- with both vLLM V0 and V1.
92- NOTE: we define this class in a separate module, and the main module
93- should pass the full qualified name as `worker_extension_cls` argument.
88+ Worker extension for vLLM to update weights from checkpoint-engine.
89+
90+ This class provides a worker extension mechanism that allows vLLM workers to receive
91+ and apply weight updates from the checkpoint-engine via IPC (Inter-Process Communication).
92+ The methods in this worker extension will be injected into the vLLM worker class and
93+ are callable from the `collective_rpc` API, enabling seamless weight updates for both
94+ vLLM V0 and V1 versions.
95+
96+ Note:
97+ This class is defined in a separate module. The fully qualified name
98+ `checkpoint_engine.worker.VllmColocateWorkerExtension` should be passed as the
99+ `worker_extension_cls` argument when initializing the vLLM worker.
94100 """
95101
96102 def update_weights_from_ipc (self , zmq_handles : dict [str , str ]):
103+ """
104+ Update model weights from checkpoint-engine via IPC communication.
105+
106+ This method establishes a ZMQ connection to the checkpoint-engine and receives
107+ weight updates through a shared memory buffer. The update process includes:
108+ 1. Receiving IPC handles to reconstruct shared memory tensors
109+ 2. Extracting flattened metadata describing the weight tensors stored in the shared memory buffer
110+ 3. Loading weights into the model
111+ 4. Post-processing weights after loading
112+
113+ Args:
114+ zmq_handles: A dictionary mapping device UUIDs to ZMQ socket handles.
115+ The device UUID is platform-specific:
116+ - For CUDA: UUID from `current_platform.get_device_uuid()`
117+ - For NPU: Format "NPU-{generated_uuid}"
118+
119+ Raises:
120+ ValueError: If the device type is not supported (not CUDA or NPU).
121+ AssertionError: If `self.device` is None when the method is called.
122+
123+ Note:
124+ This method is called by vLLM's collective RPC mechanism. The ZMQ context
125+ is lazily initialized on first call and reused for subsequent updates.
126+ """
97127 from vllm .model_executor .model_loader .utils import process_weights_after_loading
98128 from vllm .platforms import current_platform
99129
@@ -103,10 +133,12 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
103133 assert self .device is not None
104134 if not hasattr (self , "_zmq_ctx" ) or self ._zmq_ctx is None :
105135 self ._zmq_ctx = zmq .Context ()
106- if current_platform .device_type == "gpu " :
136+ if current_platform .device_type == "cuda " :
107137 device_uuid = current_platform .get_device_uuid (self .device .index )
108138 elif current_platform .device_type == "npu" :
109139 device_uuid = f"NPU-{ npu_generate_uuid ()} "
140+ else :
141+ raise ValueError (f"Unsupported device type: { current_platform .device_type } " )
110142 update_weights_from_ipc (
111143 self ._zmq_ctx ,
112144 zmq_handles [device_uuid ],
0 commit comments