Skip to content

Commit ee271a0

Browse files
committed
add .cpu() when gathering tensor
1 parent e846b1c commit ee271a0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/transformers/modeling_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3666,9 +3666,9 @@ def save_pretrained(
36663666
gathered_state_dict = {}
36673667
for key, value in state_dict.items():
36683668
if hasattr(value, "_local_tensor"):
3669-
gathered_state_dict[key] = value.to_local()
3669+
gathered_state_dict[key] = value.to_local().cpu()
36703670
else:
3671-
gathered_state_dict[key] = value
3671+
gathered_state_dict[key] = value.cpu()
36723672

36733673
del state_dict
36743674
state_dict = gathered_state_dict

0 commit comments

Comments
 (0)