Skip to content

Commit 3b345fa

Browse files
committed
add .cpu() when gathering tensor
1 parent 2766d1a commit 3b345fa

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/transformers/modeling_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2985,9 +2985,9 @@ def save_pretrained(
29852985
gathered_state_dict = {}
29862986
for key, value in state_dict.items():
29872987
if hasattr(value, "_local_tensor"):
2988-
gathered_state_dict[key] = value.to_local()
2988+
gathered_state_dict[key] = value.to_local().cpu()
29892989
else:
2990-
gathered_state_dict[key] = value
2990+
gathered_state_dict[key] = value.cpu()
29912991

29922992
del state_dict
29932993
state_dict = gathered_state_dict

0 commit comments

Comments
 (0)