Skip to content

Commit 45866d4

Browse files
authored
add .cpu() when gathering tensor
1 parent 4460137 commit 45866d4

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/transformers/modeling_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2944,9 +2944,9 @@ def save_pretrained(
29442944
gathered_state_dict = {}
29452945
for key, value in state_dict.items():
29462946
if hasattr(value, "_local_tensor"):
2947-
gathered_state_dict[key] = value.to_local()
2947+
gathered_state_dict[key] = value.to_local().cpu()
29482948
else:
2949-
gathered_state_dict[key] = value
2949+
gathered_state_dict[key] = value.cpu()
29502950

29512951
del state_dict
29522952
state_dict = gathered_state_dict

0 commit comments

Comments
 (0)