We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4a73109 commit 02a68ddCopy full SHA for 02a68dd
checkpoint_engine/ps.py
@@ -407,7 +407,13 @@ def _unpin(t: torch.Tensor):
407
del self._memory_pool[checkpoint_name]
408
# see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
409
# this works by using torch>=2.5.0
410
- torch._C._host_emptyCache()
+ if self.device_manager.device_type == "cuda":
411
+ torch._C._host_emptyCache()
412
+ else:
413
+ # torch._C._host_emptyCache() is not supported on NPU, so we call gc.collect() to empty host cache.
414
+ import gc
415
+
416
+ gc.collect()
417
418
def gather_metas(self, checkpoint_name: str):
419
"""
0 commit comments