Skip to content

Commit 63b782d

Browse files
kip-cxjkip-cxj
andauthored
cache device uuid (#74)
Co-authored-by: kip-cxj <cuixiaojin@huawei.com>
1 parent 322f684 commit 63b782d

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

checkpoint_engine/worker.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,15 +151,18 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
151151
assert self.device is not None
152152
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
153153
self._zmq_ctx = zmq.Context()
154-
if current_platform.device_type == "cuda":
155-
device_uuid = current_platform.get_device_uuid(self.device.index)
156-
elif current_platform.device_type == "npu":
157-
device_uuid = f"NPU-{npu_generate_uuid()}"
158-
else:
159-
raise ValueError(f"Unsupported device type: {current_platform.device_type}")
154+
155+
if not hasattr(self, "_device_uuid") or self._device_uuid is None:
156+
if current_platform.device_type == "cuda":
157+
self._device_uuid = current_platform.get_device_uuid(self.device.index)
158+
elif current_platform.device_type == "npu":
159+
self._device_uuid = f"NPU-{npu_generate_uuid()}"
160+
else:
161+
raise ValueError(f"Unsupported device type: {current_platform.device_type}")
162+
160163
update_weights_from_ipc(
161164
self._zmq_ctx,
162-
zmq_handles[device_uuid],
165+
zmq_handles[self._device_uuid],
163166
device_id=self.device.index,
164167
run=self.model_runner.model.load_weights,
165168
post_hook=lambda: process_weights_after_loading(

0 commit comments

Comments
 (0)