Skip to content

Commit bbe5596

Browse files
committed
feat: release ipc buffers before calling update_weights_from_ipc's post_hook
1 parent 009082d commit bbe5596

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

checkpoint_engine/ps.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ctypes
2+
import gc
23
import os
34
import threading
45
from collections import defaultdict
@@ -848,6 +849,29 @@ def _update_per_bucket(
848849
gidx += 1
849850

850851
socket.recv()
852+
device_mem = self.device_manager.device_module.mem_get_info()
853+
logger.info(
854+
f"[rank{self._rank}] weights broadcast done, device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, allocated memory: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, reserved memory: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB"
855+
)
856+
# Notify worker to release handle
857+
socket.send_pyobj(None)
858+
socket.recv()
859+
# Set to None in correct order (views first, then base tensors)
860+
del buffer_b, h2d_buffer, buffer, handle
861+
self.device_manager.device_module.synchronize()
862+
gc.collect()
863+
self.device_manager.device_module.ipc_collect()
864+
self.device_manager.device_module.empty_cache()
865+
self.device_manager.device_module.synchronize()
866+
867+
# Log actual memory usage
868+
device_mem = self.device_manager.device_module.mem_get_info()
869+
logger.info(
870+
f"[rank{self._rank}] post-release: device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, "
871+
f"allocated: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, "
872+
f"reserved: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB"
873+
)
874+
# Notify worker to call post_hook
851875
socket.send_pyobj(None)
852876
socket.recv()
853877
finally:

checkpoint_engine/worker.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import gc
2+
import os
23
import traceback
34
from collections.abc import Callable
45
from functools import cached_property
@@ -71,9 +72,22 @@ def update_weights_from_ipc(
7172
socket.recv() # wait for ack
7273
raise
7374
try:
75+
released = False
7476
while True:
7577
payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
7678
if payload is None: # done signal
79+
# TODO: refine stm logic, wrap all messages instead of None and Exception
80+
device_manager.device_module.synchronize()
81+
if not released:
82+
released = True
83+
del buffer, ipc_handle
84+
85+
gc.collect()
86+
device_manager.device_module.ipc_collect()
87+
device_manager.device_module.empty_cache()
88+
device_manager.device_module.synchronize()
89+
socket.send(b"")
90+
continue
7791
if post_hook is not None:
7892
post_hook()
7993
device_manager.device_module.synchronize()
@@ -98,7 +112,8 @@ def update_weights_from_ipc(
98112

99113
finally:
100114
socket.close()
101-
del buffer
115+
if "buffer" in locals():
116+
del buffer
102117
gc.collect()
103118
device_manager.device_module.empty_cache()
104119

tests/test_update.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Ten
9191
name: tensor.to(device_manager.device_type) for name, tensor in named_tensors.items()
9292
}
9393
_zmq_ctx = zmq.Context()
94+
mem_info = device_manager.device_module.mem_get_info()
95+
memory_usage = mem_info[1] - mem_info[0]
96+
memory_history: list[int] = [memory_usage]
9497

9598
def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor]]):
9699
for name, weight in weights:
@@ -108,6 +111,11 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str,
108111
run=lambda weights: check(names_to_check, weights),
109112
post_hook=lambda: device_manager.device_module.synchronize(),
110113
)
114+
device_manager.device_module.synchronize()
115+
device_manager.device_module.empty_cache()
116+
mem_info = device_manager.device_module.mem_get_info()
117+
memory_usage = mem_info[1] - mem_info[0]
118+
memory_history.append(memory_usage)
111119
assert all(names_to_check.values())
112120

113121
while True:
@@ -118,6 +126,13 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str,
118126
check_weights(names_to_check, socket_paths)
119127

120128

129+
mem_info = device_manager.device_module.mem_get_info()
130+
memory_usage = mem_info[1] - mem_info[0]
131+
memory_history.append(memory_usage)
132+
for memory in memory_history[1:]:
133+
print(f"[rank{rank}] Memory change: {memory - memory_history[0]}")
134+
135+
121136
def run(
122137
checker_func: callable,
123138
rank_list: list[list[int]],
@@ -318,6 +333,8 @@ def test_update_with_files(test_name: str = "test_with_files"):
318333
rank_list = json.loads(sys.argv[2])
319334
if test_type == "test_no_error":
320335
run(checker_proc, rank_list, need_error=False)
336+
mem_info = device_manager.device_module.mem_get_info()
337+
print(f"Memory usage: {mem_info[1] - mem_info[0]}")
321338
elif test_type == "test_with_remote_error":
322339
run(
323340
checker_proc_with_error,

0 commit comments

Comments
 (0)