diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 4990575..20f5be6 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -407,7 +407,13 @@ def _unpin(t: torch.Tensor): del self._memory_pool[checkpoint_name] # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018 # this works by using torch>=2.5.0 - torch._C._host_emptyCache() + if self.device_manager.device_type == "cuda": + torch._C._host_emptyCache() + else: + # torch._C._host_emptyCache() is not supported on NPU, so we call gc.collect() to empty host cache. + import gc + + gc.collect() def gather_metas(self, checkpoint_name: str): """