Skip to content

Commit f6910d6

Browse files
authored
bugfix: skip empty safetensors file in inplace_pin_memory (#79)
* misc: translate cuda error code to string when pin and unpin * bugfix: skip empty safetensors file when inplace pin memory
1 parent 02a68dd commit f6910d6

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

checkpoint_engine/pin_memory.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ def _pin(t: torch.Tensor):
209209
torch.cuda.set_device(device_index)
210210
cudart = torch.cuda.cudart()
211211
r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
212-
assert r == 0, f"pin memory error, error code: {r}"
212+
if r != 0:
213+
error_msg = cudart.cudaGetErrorString(r)
214+
raise RuntimeError(f"pin memory error, error code: {r}, error message: {error_msg}")
213215

214216
# TODO: should only support /dev/shm? but we found files in disk also work?
215217
size = os.stat(file_path).st_size
@@ -254,6 +256,12 @@ def _pin(t: torch.Tensor):
254256
# Remove the file after successfully loading. This will avoid doubling the memory usage.
255257
# We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
256258
os.remove(file_path)
259+
if not metas:
260+
# TODO: should we still return this buffer?
261+
assert buffer.nbytes == 0, f"buffer nbytes {buffer.nbytes} should be 0"
262+
logger.warning(f"[rank{rank}] no metas found in {file_path}, skip pin memory")
263+
return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=[], manually_pinned=False)
264+
257265
_pin(buffer)
258266
logger.info(
259267
f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"

checkpoint_engine/ps.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,11 @@ def _unpin(t: torch.Tensor):
391391
)
392392
cudart = torch.cuda.cudart()
393393
r = cudart.cudaHostUnregister(t.data_ptr())
394-
assert r == 0, f"unpin memory error, error code: {r}"
394+
if r != 0:
395+
error_msg = cudart.cudaGetErrorString(r)
396+
raise RuntimeError(
397+
f"unpin memory error, error code: {r}, error message: {error_msg}"
398+
)
395399

396400
# if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
397401
try:

0 commit comments

Comments
 (0)