Skip to content

Commit d0cd5ac

Browse files
authored
Merge pull request #841 from hpcaitech/oom_fix
[ckpt] mitigate gpu mem peak when loading ckpt
2 parents bc4aa4f + 5730060 commit d0cd5ac

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

opensora/utils/ckpt.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,7 @@ def load_checkpoint(
113113

114114
log_message(f"Loading checkpoint from {path}")
115115
if path.endswith(".safetensors"):
116-
# ckpt = load_file(path, device=str(device_map))
117-
ckpt = load_file(path, device=torch.cuda.current_device())
116+
ckpt = load_file(path, device='cpu')
118117

119118
if rename_keys is not None:
120119
# rename keys in the loaded state_dict with old_key_prefix to with new_key_prefix.

0 commit comments

Comments
 (0)