We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent bc6c907 commit dedaa12Copy full SHA for dedaa12
src/transformers/modeling_utils.py
@@ -2985,9 +2985,9 @@ def save_pretrained(
2985
gathered_state_dict = {}
2986
for key, value in state_dict.items():
2987
if hasattr(value, "_local_tensor"):
2988
- gathered_state_dict[key] = value.to_local()
+ gathered_state_dict[key] = value.to_local().cpu()
2989
else:
2990
- gathered_state_dict[key] = value
+ gathered_state_dict[key] = value.cpu()
2991
2992
del state_dict
2993
state_dict = gathered_state_dict
0 commit comments