Skip to content

Commit cffc99b

Browse files
committed
feat: release ipc buffers before calling update_weights_from_ipc's post_hook
1 parent fe57396 commit cffc99b

File tree

3 files changed

+61
-1
lines changed

3 files changed

+61
-1
lines changed

checkpoint_engine/ps.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ctypes
2+
import gc
23
import os
34
import threading
45
from collections import defaultdict
@@ -852,6 +853,29 @@ def _update_per_bucket(
852853
gidx += 1
853854

854855
socket.recv()
856+
device_mem = self.device_manager.device_module.mem_get_info()
857+
logger.info(
858+
f"[rank{self._rank}] weights broadcast done, device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, allocated memory: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, reserved memory: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB"
859+
)
860+
# Notify worker to release handle
861+
socket.send_pyobj(None)
862+
socket.recv()
863+
# Drop references in the correct order (views first, then base tensors)
864+
del buffer_b, h2d_buffer, buffer, handle
865+
self.device_manager.device_module.synchronize()
866+
gc.collect()
867+
self.device_manager.device_module.ipc_collect()
868+
self.device_manager.device_module.empty_cache()
869+
self.device_manager.device_module.synchronize()
870+
871+
# Log actual memory usage
872+
device_mem = self.device_manager.device_module.mem_get_info()
873+
logger.info(
874+
f"[rank{self._rank}] post-release: device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, "
875+
f"allocated: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, "
876+
f"reserved: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB"
877+
)
878+
# Notify worker to call post_hook
855879
socket.send_pyobj(None)
856880
socket.recv()
857881
finally:

checkpoint_engine/worker.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,35 @@ def update_weights_from_ipc(
7070
socket.send_string(msg)
7171
socket.recv() # wait for ack
7272
raise
73+
# State machine:
74+
# + receive tensor_metadata -> update_weights
75+
# + receive Exception -> raise and stop
76+
# + receive None first time -> release resources
77+
# + receive None second time -> call post_hook and stop
7378
try:
79+
released = False
7480
while True:
7581
payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
76-
if payload is None: # done signal
82+
if released:
83+
assert payload is None, "Should not receive any payload after released"
7784
if post_hook is not None:
7885
post_hook()
7986
device_manager.device_module.synchronize()
8087
socket.send(b"")
8188
break
89+
if payload is None: # done signal
90+
# TODO: wrap all messages in an object instead of using None and Exception
91+
device_manager.device_module.synchronize()
92+
released = True
93+
buffer = None
94+
del ipc_handle
95+
96+
gc.collect()
97+
device_manager.device_module.ipc_collect()
98+
device_manager.device_module.empty_cache()
99+
device_manager.device_module.synchronize()
100+
socket.send(b"")
101+
continue
82102
if isinstance(payload, list): # still updating weights
83103
try:
84104
run(_extract_weights(payload, buffer))

tests/test_update.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Ten
9191
name: tensor.to(device_manager.device_type) for name, tensor in named_tensors.items()
9292
}
9393
_zmq_ctx = zmq.Context()
94+
mem_info = device_manager.device_module.mem_get_info()
95+
memory_usage = mem_info[1] - mem_info[0]
96+
memory_history: list[int] = [memory_usage]
9497

9598
def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor]]):
9699
for name, weight in weights:
@@ -108,6 +111,11 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str,
108111
run=lambda weights: check(names_to_check, weights),
109112
post_hook=lambda: device_manager.device_module.synchronize(),
110113
)
114+
device_manager.device_module.synchronize()
115+
device_manager.device_module.empty_cache()
116+
mem_info = device_manager.device_module.mem_get_info()
117+
memory_usage = mem_info[1] - mem_info[0]
118+
memory_history.append(memory_usage)
111119
assert all(names_to_check.values())
112120

113121
while True:
@@ -117,6 +125,12 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str,
117125
names_to_check = dict.fromkeys(named_tensors.keys(), False)
118126
check_weights(names_to_check, socket_paths)
119127

128+
mem_info = device_manager.device_module.mem_get_info()
129+
memory_usage = mem_info[1] - mem_info[0]
130+
memory_history.append(memory_usage)
131+
for memory in memory_history[1:]:
132+
print(f"[rank{rank}] Memory change: {memory - memory_history[0]}")
133+
120134

121135
def run(
122136
checker_func: callable,
@@ -318,6 +332,8 @@ def test_update_with_files(test_name: str = "test_with_files"):
318332
rank_list = json.loads(sys.argv[2])
319333
if test_type == "test_no_error":
320334
run(checker_proc, rank_list, need_error=False)
335+
mem_info = device_manager.device_module.mem_get_info()
336+
print(f"Memory usage: {mem_info[1] - mem_info[0]}")
321337
elif test_type == "test_with_remote_error":
322338
run(
323339
checker_proc_with_error,

0 commit comments

Comments
 (0)