We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e846b1c commit ee271a0Copy full SHA for ee271a0
src/transformers/modeling_utils.py
@@ -3666,9 +3666,9 @@ def save_pretrained(
3666
gathered_state_dict = {}
3667
for key, value in state_dict.items():
3668
if hasattr(value, "_local_tensor"):
3669
- gathered_state_dict[key] = value.to_local()
+ gathered_state_dict[key] = value.to_local().cpu()
3670
else:
3671
- gathered_state_dict[key] = value
+ gathered_state_dict[key] = value.cpu()
3672
3673
del state_dict
3674
state_dict = gathered_state_dict
0 commit comments