We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2d2c4f3, commit 2217e31. Copy full SHA for 2217e31
src/transformers/modeling_utils.py
@@ -3378,9 +3378,9 @@ def save_pretrained(
3378
gathered_state_dict = {}
3379
for key, value in state_dict.items():
3380
if hasattr(value, "_local_tensor"):
3381
- gathered_state_dict[key] = value.to_local()
+ gathered_state_dict[key] = value.to_local().cpu()
3382
else:
3383
- gathered_state_dict[key] = value
+ gathered_state_dict[key] = value.cpu()
3384
3385
del state_dict
3386
state_dict = gathered_state_dict
0 commit comments