Skip to content

Commit e846b1c

Browse files
committed
fix bugs in gathering state dict
1 parent ae5cb30 commit e846b1c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/transformers/modeling_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3667,7 +3667,8 @@ def save_pretrained(
36673667
for key, value in state_dict.items():
36683668
if hasattr(value, "_local_tensor"):
36693669
gathered_state_dict[key] = value.to_local()
3670-
gathered_state_dict[key] = value
3670+
else:
3671+
gathered_state_dict[key] = value
36713672

36723673
del state_dict
36733674
state_dict = gathered_state_dict

0 commit comments

Comments
 (0)