Skip to content

Commit 0a62449

Browse files
specture724Copilot
authored andcommitted
hotfix: add a switch to disable inplace pinning of tensors (#68)
* feat: add a switch to disable inplace pinning of tensors * Update checkpoint_engine/ps.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Anjie Hou <149605198+specture724@users.noreply.github.com> * doc --------- Signed-off-by: Anjie Hou <149605198+specture724@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 089d185 commit 0a62449

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

checkpoint_engine/ps.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -622,19 +622,25 @@ def _register_checkpoint(
622622
named_tensors: dict[str, torch.Tensor],
623623
rank: int | None = None,
624624
shared_pin_memory: list[MemoryBuffer] | None = None,
625+
inplace_pin: bool = False,
625626
) -> list[MemoryBuffer]:
626627
logger.info(
627628
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
628629
)
629630
if not files and not named_tensors:
630631
return []
631632
memory_buffers: list[MemoryBuffer] = []
632-
files_to_inplace_pin = [
633-
file
634-
for file in files
635-
if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
636-
]
637-
files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
633+
if inplace_pin:
634+
logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
635+
files_to_inplace_pin = [
636+
file
637+
for file in files
638+
if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
639+
]
640+
files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
641+
else:
642+
files_to_normal_pin = files
643+
files_to_inplace_pin = []
638644
if files_to_normal_pin or named_tensors:
639645
memory_buffers.extend(
640646
_normal_pin_memory(
@@ -973,10 +979,11 @@ def register_checkpoint(
973979
files: list[str] | None = None,
974980
named_tensors: dict[str, torch.Tensor] | None = None,
975981
use_shared_memory_pool: bool = False,
982+
use_inplace_pin_memory: bool = False,
976983
) -> None:
977984
"""
978985
Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
979-
Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
986+
Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
980987
Please make sure to copy the files to disks if you need to keep them.
981988
982989
Args:
@@ -988,6 +995,8 @@ def register_checkpoint(
988995
cannot accommodate checkpoints with different memory requirements.
989996
To free the actual memory of the shared pool or to modify its shape,
990997
please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
998+
use_inplace_pin_memory: If True, allows inplace pin memory for /dev/shm/ safetensors files. This option is ignored when ``use_shared_memory_pool`` is True.
999+
Currently, this feature is experimental and may crash.
9911000
"""
9921001
try:
9931002
if use_shared_memory_pool:
@@ -1016,7 +1025,10 @@ def register_checkpoint(
10161025
f"checkpoint {checkpoint_name} already registered"
10171026
)
10181027
self._memory_pool[checkpoint_name] = _register_checkpoint(
1019-
files=files or [], named_tensors=named_tensors or {}, rank=self._rank
1028+
files=files or [],
1029+
named_tensors=named_tensors or {},
1030+
rank=self._rank,
1031+
inplace_pin=use_inplace_pin_memory,
10201032
)
10211033
if self._p2p_store is not None:
10221034
self._register_parameters_to_p2p_store(checkpoint_name)

tests/test_update.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,8 @@ def run_with_files(
218218
if rank == 0:
219219
import shutil
220220

221-
os.removedirs(dev_shm_dir)
221+
# this test should be run under use_inplace_pin_memory=False. Otherwise, the files in /dev/shm/ will be deleted.
222+
shutil.rmtree(dev_shm_dir)
222223
shutil.rmtree(disk_dir)
223224
assert proc.exitcode == 0
224225

0 commit comments

Comments
 (0)