We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2fc7dff commit 9c31402Copy full SHA for 9c31402
src/transformers/modeling_utils.py
@@ -3409,9 +3409,9 @@ def save_pretrained(
3409
gathered_state_dict = {}
3410
for key, value in state_dict.items():
3411
if hasattr(value, "_local_tensor"):
3412
- gathered_state_dict[key] = value.to_local()
+ gathered_state_dict[key] = value.to_local().cpu()
3413
else:
3414
- gathered_state_dict[key] = value
+ gathered_state_dict[key] = value.cpu()
3415
3416
del state_dict
3417
state_dict = gathered_state_dict
0 commit comments