@@ -70,29 +70,35 @@ def update_weights_from_ipc(
7070 socket .send_string (msg )
7171 socket .recv () # wait for ack
7272 raise
73+ # State machine:
74+ # + receive tensor_metadata -> update_weights
75+ # + receive Exception -> raise and stop
76+ # + receive None first time -> release resources
77+ # + receive None second time -> call post_hook and stop
7378 try :
7479 released = False
7580 while True :
7681 payload : list [FlattenedTensorMetadata ] | Exception | None = socket .recv_pyobj ()
77- if payload is None : # done signal
78- # TODO: refine stm logic, wrap all messages instead of None and Exception
79- device_manager .device_module .synchronize ()
80- if not released :
81- released = True
82- buffer = None
83- del ipc_handle
84-
85- gc .collect ()
86- device_manager .device_module .ipc_collect ()
87- device_manager .device_module .empty_cache ()
88- device_manager .device_module .synchronize ()
89- socket .send (b"" )
90- continue
82+ if released :
83+ assert payload is None , "Should not receive any payload after released"
9184 if post_hook is not None :
9285 post_hook ()
9386 device_manager .device_module .synchronize ()
9487 socket .send (b"" )
9588 break
89+ if payload is None : # done signal
90+ # TODO: wrap all messages to an object instead of None and Exception
91+ device_manager .device_module .synchronize ()
92+ released = True
93+ buffer = None
94+ del ipc_handle
95+
96+ gc .collect ()
97+ device_manager .device_module .ipc_collect ()
98+ device_manager .device_module .empty_cache ()
99+ device_manager .device_module .synchronize ()
100+ socket .send (b"" )
101+ continue
96102 if isinstance (payload , list ): # still updating weights
97103 try :
98104 run (_extract_weights (payload , buffer ))
0 commit comments