Skip to content

Commit 5ad2d95

Browse files
committed
feat: quit checkpoint worker process when error occurs
1 parent bbc83db commit 5ad2d95

File tree

3 files changed

+117
-8
lines changed

3 files changed

+117
-8
lines changed

checkpoint_engine/ps.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -886,10 +886,6 @@ def update(
886886
return
887887
self.init_process_group_for_ranks(ranks)
888888
self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
889-
if self._auto_pg:
890-
dist.destroy_process_group()
891-
892-
torch.cuda.empty_cache()
893889

894890
logger.info(
895891
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
@@ -901,6 +897,11 @@ def update(
901897
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
902898
)
903899
raise
900+
finally:
901+
if self._auto_pg:
902+
dist.destroy_process_group()
903+
904+
torch.cuda.empty_cache()
904905

905906
def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
906907
def zmq_handle(device_uuid: str) -> str:
@@ -1191,7 +1192,13 @@ def _update_per_bucket(
11911192
else:
11921193
buffer_b.data.copy_(h2d_buffer[: bucket.size])
11931194
dist.broadcast(buffer_b, src=owner_rank)
1194-
socket.recv()
1195+
resp_list: list[bytes] = [b""] * dist.get_world_size()
1196+
resp = socket.recv()
1197+
dist.all_gather_object(resp_list, resp)
1198+
torch.cuda.synchronize()
1199+
if any(r != b"" for r in resp_list):
1200+
# quit early if any rank failed
1201+
raise RuntimeError("failed to update weights due to remote error")
11951202
dist.barrier()
11961203
socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
11971204
gidx += 1

checkpoint_engine/worker.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,13 @@ def update_weights_from_ipc(
7070
socket.send(b"")
7171
continue
7272
assert isinstance(payload, list)
73-
run(_extract_weights(payload, buffer))
74-
torch.cuda.synchronize()
75-
socket.send(b"")
73+
try:
74+
run(_extract_weights(payload, buffer))
75+
torch.cuda.synchronize()
76+
socket.send(b"")
77+
except Exception as e:
78+
socket.send_pyobj(e)
79+
raise
7680

7781
socket.close()
7882
del buffer

tests/test_error_quit.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import os
2+
import random
3+
import time
4+
5+
import torch
6+
import zmq
7+
from torch.multiprocessing import Queue, get_context
8+
9+
from checkpoint_engine.ps import ParameterServer, _get_physical_gpu_id
10+
from checkpoint_engine.worker import update_weights_from_ipc
11+
12+
13+
def gen_test_tensors(rank: int) -> list[tuple[str, torch.Tensor]]:
14+
tensors = []
15+
for layer in range(random.randint(10, 50)):
16+
for num in range(random.randint(50, 100)):
17+
r = random.randint(0, 16)
18+
if r < 4:
19+
dtype = torch.bfloat16
20+
elif r < 10:
21+
dtype = torch.float16
22+
elif r < 14:
23+
dtype = torch.float8_e4m3fn
24+
else:
25+
dtype = torch.float
26+
tensors.append(
27+
(
28+
f"rank{rank}.layer{layer}.num{num}",
29+
torch.randn([random.randint(100, 500), random.randint(500, 1000)]).to(dtype),
30+
)
31+
)
32+
return tensors
33+
34+
35+
def receiver_proc_with_error(
    rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue
):
    """Worker-side receiver that randomly fails while applying weights.

    Consumes socket-path lists from *queue* (a ``None`` item shuts the loop
    down) and feeds each one to ``update_weights_from_ipc`` with a run
    callback that sleeps briefly and raises ~60% of the time, so the parent
    can observe error propagation. Any raised error is logged and re-raised.
    """
    torch.cuda.set_device(rank)
    named_tensors = {k: v.cuda() for k, v in named_tensors.items()}
    zmq_ctx = zmq.Context()

    def flaky_run(weights: list[tuple[str, torch.Tensor]]):
        # The payload itself is irrelevant here; only simulate work + failure.
        del weights
        time.sleep(random.uniform(0.1, 0.5))
        if random.random() < 0.6:
            raise RuntimeError("Intentional Error for testing.")

    def consume(socket_paths: list[tuple[str, str]]):
        path_by_uuid = dict(socket_paths)
        update_weights_from_ipc(
            zmq_ctx,
            path_by_uuid[device_uuid],
            device_id=rank,
            run=flaky_run,
            post_hook=lambda: torch.cuda.synchronize(),
        )

    while (socket_paths := queue.get()) is not None:
        try:
            consume(socket_paths)
        except BaseException:
            # Re-raise so the parent process sees the injected failure.
            print(f"[rank{rank}] successfully triggered error.")
            raise
67+
68+
69+
def run():
    """Driver: spin up a flaky receiver process and verify that an error
    raised in the worker aborts the update and surfaces as a RuntimeError
    in the parameter server.

    Requires the RANK environment variable and one CUDA device per rank.
    """
    rank = int(os.getenv("RANK"))
    ctx = get_context("spawn")
    queue = ctx.Queue()
    _device_uuid = _get_physical_gpu_id(rank)
    ps = ParameterServer(auto_pg=True)
    named_tensors = dict(gen_test_tensors(rank))
    checkpoint_name = "test"
    proc = ctx.Process(
        target=receiver_proc_with_error, args=(rank, _device_uuid, named_tensors, queue)
    )
    proc.daemon = True
    proc.start()
    try:
        ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
        ps.gather_metas(checkpoint_name)
        ps.update(checkpoint_name, queue.put, ranks=[])
        # sleep 3s to wait until the process group is destroyed
        time.sleep(3)
    except RuntimeError as e:
        # Expected path: the worker's injected failure propagates here.
        print(f"[rank{rank}] Caught exception from worker process: {e}")
    finally:
        ps.unregister_checkpoint(checkpoint_name)
        queue.put(None)  # signal the receiver loop to exit
        proc.join(timeout=10)  # don't leave the child process dangling


if __name__ == "__main__":
    run()

0 commit comments

Comments
 (0)