Commit 02a68dd

fix: npu free host cache (#78)
1 parent 4a73109 commit 02a68dd

1 file changed: +7 -1 lines changed
checkpoint_engine/ps.py

Lines changed: 7 additions & 1 deletion
@@ -407,7 +407,13 @@ def _unpin(t: torch.Tensor):
         del self._memory_pool[checkpoint_name]
         # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
         # this works by using torch>=2.5.0
-        torch._C._host_emptyCache()
+        if self.device_manager.device_type == "cuda":
+            torch._C._host_emptyCache()
+        else:
+            # torch._C._host_emptyCache() is not supported on NPU, so we call gc.collect() to empty host cache.
+            import gc
+
+            gc.collect()

     def gather_metas(self, checkpoint_name: str):
         """
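
The branch added in this commit can be read as a small dispatch on device type. The following is a minimal standalone sketch of that idea, not the code in `checkpoint_engine/ps.py`: the function name `release_host_cache` and the `empty_host_cache` parameter are hypothetical, and falling back to `gc.collect()` when no hook is supplied is an assumption that goes beyond what the commit does.

```python
import gc
from typing import Callable, Optional


def release_host_cache(
    device_type: str,
    empty_host_cache: Optional[Callable[[], None]] = None,
) -> str:
    """Release cached host memory and report which path was taken.

    `empty_host_cache` is a hypothetical injection point for a
    backend-specific hook such as torch._C._host_emptyCache.
    """
    if device_type == "cuda" and empty_host_cache is not None:
        # CUDA path: ask the backend to drop its cached host blocks.
        empty_host_cache()
        return "host_empty_cache"
    # Non-CUDA path (e.g. NPU), or no hook available (assumption):
    # run the garbage collector, as the commit does for non-CUDA devices.
    gc.collect()
    return "gc_collect"
```

With PyTorch installed, a caller could pass `getattr(torch._C, "_host_emptyCache", None)` as the hook, so that builds lacking the private function take the `gc.collect()` path instead of raising `AttributeError`. Injecting the hook also keeps the sketch importable and testable without torch.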