Skip to content

Commit 608df37

Browse files
committed
misc: fix pr issues
1 parent 39807ce commit 608df37

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

checkpoint_engine/ps.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -981,11 +981,6 @@ def update(
981981
return
982982
self._update_per_bucket(checkpoint_name, req_func, ranks)
983983

984-
logger.info(
985-
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
986-
f"Current device allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
987-
f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
988-
)
989984
except Exception as e:
990985
logger.exception(
991986
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
@@ -996,6 +991,11 @@ def update(
996991
dist.destroy_process_group()
997992

998993
self.device_manager.device_module.empty_cache()
994+
logger.info(
995+
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
996+
f"Current device allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
997+
f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
998+
)
999999

10001000
def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
10011001
def zmq_handle(device_uuid: str) -> str:
@@ -1225,7 +1225,7 @@ def _update_per_bucket(
12251225
socket.send_pyobj(handle)
12261226

12271227
gidx = 0
1228-
ret_code = torch.tensor(0, device=self.device_manager.device_type)
1228+
ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
12291229
bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
12301230
for i in range(max_len):
12311231
if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
@@ -1268,8 +1268,9 @@ def _update_per_bucket(
12681268
self.device_manager.device_module.synchronize()
12691269
if ret_code.item() != 0:
12701270
# quit early if any rank failed
1271-
socket.send_pyobj(RuntimeError("Failed to update weights due to remote errors"))
1272-
raise RuntimeError("Failed to update weights due to remote errors")
1271+
exception = RuntimeError("Failed to update weights due to remote errors")
1272+
socket.send_pyobj(exception)
1273+
raise exception
12731274
socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
12741275
gidx += 1
12751276

checkpoint_engine/worker.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,8 @@ def update_weights_from_ipc(
6767
post_hook()
6868
device_mananger.device_module.synchronize()
6969
socket.send(b"")
70+
if isinstance(payload, Exception):
71+
raise payload
7072
break
7173
if isinstance(payload, tuple):
7274
# an ipc handle that vLLM can use `func, args = handle`

tests/test_update.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def error_run(weights: list[tuple[str, torch.Tensor]]):
8282
try:
8383
trigger_error(socket_paths)
8484
except RuntimeError as e:
85-
assert str(e) == "Intentional Error for testing."
85+
assert str(e) == "Failed to update weights due to remote errors"
8686

8787

8888
def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue):

0 commit comments

Comments
 (0)