Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,12 @@ def _unpin(t: torch.Tensor):
del self._memory_pool[checkpoint_name]
# see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
# this works by using torch>=2.5.0
torch._C._host_emptyCache()
if self.device_manager.device_type == "cuda":
torch._C._host_emptyCache()
else:
import gc

gc.collect()

def gather_metas(self, checkpoint_name: str):
"""
Expand Down