We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4460137 commit 45866d4Copy full SHA for 45866d4
src/transformers/modeling_utils.py
@@ -2944,9 +2944,9 @@ def save_pretrained(
2944
gathered_state_dict = {}
2945
for key, value in state_dict.items():
2946
if hasattr(value, "_local_tensor"):
2947
- gathered_state_dict[key] = value.to_local()
+ gathered_state_dict[key] = value.to_local().cpu()
2948
else:
2949
- gathered_state_dict[key] = value
+ gathered_state_dict[key] = value.cpu()
2950
2951
del state_dict
2952
state_dict = gathered_state_dict
0 commit comments