Skip to content

Commit 2217e31

Browse files
committed
add .cpu() when gathering tensor
1 parent 2d2c4f3 commit 2217e31

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/transformers/modeling_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3378,9 +3378,9 @@ def save_pretrained(
33783378
gathered_state_dict = {}
33793379
for key, value in state_dict.items():
33803380
if hasattr(value, "_local_tensor"):
3381-
gathered_state_dict[key] = value.to_local()
3381+
gathered_state_dict[key] = value.to_local().cpu()
33823382
else:
3383-
gathered_state_dict[key] = value
3383+
gathered_state_dict[key] = value.cpu()
33843384

33853385
del state_dict
33863386
state_dict = gathered_state_dict

0 commit comments

Comments
 (0)