Skip to content

Commit 4460137

Browse files
committed
fix bugs in gathering state dict
1 parent 648d821 commit 4460137

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/transformers/modeling_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2945,7 +2945,8 @@ def save_pretrained(
29452945
for key, value in state_dict.items():
29462946
if hasattr(value, "_local_tensor"):
29472947
gathered_state_dict[key] = value.to_local()
2948-
gathered_state_dict[key] = value
2948+
else:
2949+
gathered_state_dict[key] = value
29492950

29502951
del state_dict
29512952
state_dict = gathered_state_dict

0 commit comments

Comments
 (0)