Skip to content

Commit 0bcd31f

Browse files
authored
Backend PyTorch: Model.restore supports restoring model to a specified device (#1224)
1 parent 4733e0e commit 0bcd31f

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

deepxde/model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,21 +1014,29 @@ def save(self, save_path, protocol="backend", verbose=0):
10141014
)
10151015
return save_path
10161016

1017-
def restore(self, save_path, verbose=0):
1017+
def restore(self, save_path, device=None, verbose=0):
10181018
"""Restore all variables from a disk file.
10191019
10201020
Args:
10211021
save_path (string): Path where model was previously saved.
1022+
device (string, optional): Device to load the model on (e.g. "cpu","cuda:0"...). By default, the model is loaded on the device it was saved from.
10221023
"""
10231024
# TODO: backend tensorflow
1025+
if device is not None and backend_name != "pytorch":
1026+
print(
1027+
"Warning: device is only supported for backend pytorch. Model will be loaded on the device it was saved from."
1028+
)
10241029
if verbose > 0:
10251030
print("Restoring model from {} ...\n".format(save_path))
10261031
if backend_name == "tensorflow.compat.v1":
10271032
self.saver.restore(self.sess, save_path)
10281033
elif backend_name == "tensorflow":
10291034
self.net.load_weights(save_path)
10301035
elif backend_name == "pytorch":
1031-
checkpoint = torch.load(save_path)
1036+
if device is not None:
1037+
checkpoint = torch.load(save_path, map_location=torch.device(device))
1038+
else:
1039+
checkpoint = torch.load(save_path)
10321040
self.net.load_state_dict(checkpoint["model_state_dict"])
10331041
self.opt.load_state_dict(checkpoint["optimizer_state_dict"])
10341042
elif backend_name == "paddle":

0 commit comments

Comments
 (0)