Skip to content

Commit 1ac11c6

Browse files
committed
Refine statemachine
1 parent fa2dbe3 commit 1ac11c6

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

checkpoint_engine/worker.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,29 +70,35 @@ def update_weights_from_ipc(
7070
socket.send_string(msg)
7171
socket.recv() # wait for ack
7272
raise
73+
# State machine:
74+
# + receive tensor_metadata -> update_weights
75+
# + receive Exception -> raise and stop
76+
# + receive None the first time -> release resources, ack, and keep waiting
77+
# + receive None the second time -> call post_hook (if set), ack, and stop
7378
try:
7479
released = False
7580
while True:
7681
payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
77-
if payload is None: # done signal
78-
# TODO: refine stm logic, wrap all messages instead of None and Exception
79-
device_manager.device_module.synchronize()
80-
if not released:
81-
released = True
82-
buffer = None
83-
del ipc_handle
84-
85-
gc.collect()
86-
device_manager.device_module.ipc_collect()
87-
device_manager.device_module.empty_cache()
88-
device_manager.device_module.synchronize()
89-
socket.send(b"")
90-
continue
82+
if released:
83+
assert payload is None, "Should not receive any payload after released"
9184
if post_hook is not None:
9285
post_hook()
9386
device_manager.device_module.synchronize()
9487
socket.send(b"")
9588
break
89+
if payload is None: # done signal
90+
# TODO: wrap all messages in a message object instead of using bare None and Exception
91+
device_manager.device_module.synchronize()
92+
released = True
93+
buffer = None
94+
del ipc_handle
95+
96+
gc.collect()
97+
device_manager.device_module.ipc_collect()
98+
device_manager.device_module.empty_cache()
99+
device_manager.device_module.synchronize()
100+
socket.send(b"")
101+
continue
96102
if isinstance(payload, list): # still updating weights
97103
try:
98104
run(_extract_weights(payload, buffer))

0 commit comments

Comments
 (0)