@@ -734,42 +734,63 @@ def Q_update(self, recurrent=True, monte_carlo=False, policy=True, verbose=False
734
734
self .update_target_network (source = self .Q2_network , target = self .Q2_target , tau = self .polyak )
735
735
self .update_target_network (source = self .policy_network , target = self .policy_target , tau = self .polyak )
736
736
737
- def save_network (self , save_path ):
738
- '''
739
- Saves networks to directory specified by save_path
740
- :param save_path: directory to save networks to
741
- '''
742
-
743
- torch .save (self .policy_network , os .path .join (save_path , "policy_network.pth" ))
744
- torch .save (self .Q1_network , os .path .join (save_path , "Q1_network.pth" ))
745
- torch .save (self .Q2_network , os .path .join (save_path , "Q2_network.pth" ))
746
-
747
- torch .save (self .policy_target , os .path .join (save_path , "policy_target.pth" ))
748
- torch .save (self .Q1_target , os .path .join (save_path , "Q1_target.pth" ))
749
- torch .save (self .Q2_target , os .path .join (save_path , "Q2_target.pth" ))
750
-
751
- def load_network (self , load_path , load_target_networks = False ):
752
- '''
753
- Loads netoworks from directory specified by load_path.
754
- :param load_path: directory to load networks from
755
- :param load_target_networks: whether to load target networks
756
- '''
757
-
758
- self .policy_network = torch .load (os .path .join (load_path , "policy_network.pth" ))
759
- self .policy_network_opt = Adam (self .policy_network .parameters (), lr = self .pol_learning_rate )
760
-
761
- self .Q1_network = torch .load (os .path .join (load_path , "Q1_network.pth" ))
762
- self .Q1_network_opt = Adam (self .Q1_network .parameters (), lr = self .val_learning_rate )
763
-
764
- self .Q2_network = torch .load (os .path .join (load_path , "Q2_network.pth" ))
765
- self .Q2_etwork_opt = Adam (self .Q2_network .parameters (), lr = self .val_learning_rate )
766
-
737
def save_ckpt(self, save_path, additional_info=None):
    '''
    Creates a full checkpoint (networks, optimizers, memory buffers) and saves it to the specified path.

    :param save_path: path to save the checkpoint to
    :param additional_info: additional information to save (Python dictionary)
    '''
    ### state_dicts of every online network, target network and optimizer
    ckpt = {
        name: getattr(self, name).state_dict()
        for name in (
            "policy_network", "Q1_network", "Q2_network",
            "policy_target", "Q1_target", "Q2_target",
            "policy_network_opt", "Q1_network_opt", "Q2_network_opt",
        )
    }

    ### caller-supplied metadata (defaults to an empty dict, never None)
    ckpt["additional_info"] = {} if additional_info is None else additional_info

    ### replay/rollout buffers are stored as plain Python objects
    for name in ("memory", "values", "states", "next_states", "actions", "rewards", "dones",
                 "sequences", "next_sequences", "all_returns"):
        ckpt[name] = getattr(self, name)

    ### serialize everything in a single file
    torch.save(ckpt, save_path)
764
def load_ckpt(self, load_path, load_target_networks=True, map_location=None):
    '''
    Loads a full checkpoint (networks, optimizers, memory buffers) from the specified path.

    :param load_path: path to load the checkpoint from
    :param load_target_networks: whether to load the target networks as well
    :param map_location: optional device remapping forwarded to torch.load
        (e.g. "cpu"); without it, a checkpoint saved on a CUDA machine cannot
        be restored on a CPU-only host. Default None preserves prior behavior.
    :return: the raw checkpoint dictionary, so callers can read "additional_info"
    '''
    ckpt = torch.load(load_path, map_location=map_location)

    ### load online networks
    self.policy_network.load_state_dict(ckpt["policy_network"])
    self.Q1_network.load_state_dict(ckpt["Q1_network"])
    self.Q2_network.load_state_dict(ckpt["Q2_network"])

    ### load target networks (optionally skipped, e.g. to re-initialize targets)
    if load_target_networks:
        self.policy_target.load_state_dict(ckpt["policy_target"])
        self.Q1_target.load_state_dict(ckpt["Q1_target"])
        self.Q2_target.load_state_dict(ckpt["Q2_target"])

    ### load optimizers
    self.policy_network_opt.load_state_dict(ckpt["policy_network_opt"])
    self.Q1_network_opt.load_state_dict(ckpt["Q1_network_opt"])
    self.Q2_network_opt.load_state_dict(ckpt["Q2_network_opt"])

    ### restore buffers; a KeyError here means the checkpoint predates this format
    for buffer in ("memory", "values", "states", "next_states", "actions", "rewards", "dones",
                   "sequences", "next_sequences", "all_returns"):
        setattr(self, buffer, ckpt[buffer])

    return ckpt
773
794
774
795
def reset_weights (self , policy = True ):
775
796
'''
0 commit comments