Skip to content

Commit 9c31402

Browse files
committed
add .cpu() when gathering tensor
1 parent 2fc7dff commit 9c31402

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/transformers/modeling_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3409,9 +3409,9 @@ def save_pretrained(
34093409
gathered_state_dict = {}
34103410
for key, value in state_dict.items():
34113411
if hasattr(value, "_local_tensor"):
3412-
gathered_state_dict[key] = value.to_local()
3412+
gathered_state_dict[key] = value.to_local().cpu()
34133413
else:
3414-
gathered_state_dict[key] = value
3414+
gathered_state_dict[key] = value.cpu()
34153415

34163416
del state_dict
34173417
state_dict = gathered_state_dict

0 commit comments

Comments
 (0)