Skip to content

Commit 2d2c4f3

Browse files
committed
fix bugs in gathering state dict
1 parent 5ac8f76 commit 2d2c4f3

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/transformers/modeling_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3379,7 +3379,8 @@ def save_pretrained(
33793379
for key, value in state_dict.items():
33803380
if hasattr(value, "_local_tensor"):
33813381
gathered_state_dict[key] = value.to_local()
3382-
gathered_state_dict[key] = value
3382+
else:
3383+
gathered_state_dict[key] = value
33833384

33843385
del state_dict
33853386
state_dict = gathered_state_dict

0 commit comments

Comments
 (0)