@@ -1014,21 +1014,29 @@ def save(self, save_path, protocol="backend", verbose=0):
10141014 )
10151015 return save_path
10161016
1017- def restore (self , save_path , verbose = 0 ):
1017+ def restore (self , save_path , device = None , verbose = 0 ):
10181018 """Restore all variables from a disk file.
10191019
10201020 Args:
10211021 save_path (string): Path where model was previously saved.
1022+ device (string, optional): Device to load the model on (e.g. "cpu","cuda:0"...). By default, the model is loaded on the device it was saved from.
10221023 """
10231024 # TODO: backend tensorflow
1025+ if device is not None and backend_name != "pytorch" :
1026+ print (
1027+ "Warning: device is only supported for backend pytorch. Model will be loaded on the device it was saved from."
1028+ )
10241029 if verbose > 0 :
10251030 print ("Restoring model from {} ...\n " .format (save_path ))
10261031 if backend_name == "tensorflow.compat.v1" :
10271032 self .saver .restore (self .sess , save_path )
10281033 elif backend_name == "tensorflow" :
10291034 self .net .load_weights (save_path )
10301035 elif backend_name == "pytorch" :
1031- checkpoint = torch .load (save_path )
1036+ if device is not None :
1037+ checkpoint = torch .load (save_path , map_location = torch .device (device ))
1038+ else :
1039+ checkpoint = torch .load (save_path )
10321040 self .net .load_state_dict (checkpoint ["model_state_dict" ])
10331041 self .opt .load_state_dict (checkpoint ["optimizer_state_dict" ])
10341042 elif backend_name == "paddle" :
0 commit comments