 import rofunc as rf
 from rofunc.learning.RofuncRL.agents.base_agent import BaseAgent
 from rofunc.learning.RofuncRL.agents.mixline.amp_agent import AMPAgent
+from rofunc.learning.RofuncRL.models.misc_models import ASEDiscEnc
 from rofunc.learning.RofuncRL.models.base_models import BaseMLP
 from rofunc.learning.RofuncRL.utils.memory import Memory
 
@@ -72,6 +73,13 @@ def __init__(self,
         self._enc_reward_weight = cfg.Agent.enc_reward_weight
 
         '''Define ASE specific models except for AMP'''
+        # self.discriminator = ASEDiscEnc(cfg.Model,
+        #                                 input_dim=amp_observation_space.shape[0],
+        #                                 enc_output_dim=self._ase_latent_dim,
+        #                                 disc_output_dim=1,
+        #                                 cfg_name='encoder').to(device)
+        # self.encoder = self.discriminator
+
         self.encoder = BaseMLP(cfg.Model,
                                input_dim=amp_observation_space.shape[0],
                                output_dim=self._ase_latent_dim,
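The commented-out block above introduces ASEDiscEnc, a combined discriminator/encoder imported from rofunc.learning.RofuncRL.models.misc_models, as an alternative to the separate BaseMLP encoder that follows. The sketch below is only a rough guess at the shape of such a shared-trunk, two-head module: the constructor arguments and the get_enc method are inferred from how the agent calls them in the hunks below, while the layer sizes, activations, and class body are illustrative rather than rofunc's actual implementation.

import torch
import torch.nn as nn


class ASEDiscEncSketch(nn.Module):
    """Illustrative shared-trunk discriminator/encoder; not rofunc's actual ASEDiscEnc."""

    def __init__(self, input_dim: int, enc_output_dim: int, disc_output_dim: int = 1, hidden_dim: int = 1024):
        super().__init__()
        # One backbone over AMP observations is shared by both heads.
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.disc_head = nn.Linear(hidden_dim, disc_output_dim)  # real/fake logit for the AMP discriminator
        self.enc_head = nn.Linear(hidden_dim, enc_output_dim)    # predicted ASE latent

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Calling the module directly returns discriminator logits, matching self.discriminator(...).
        return self.disc_head(self.backbone(x))

    def get_enc(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder branch, used by the agent when self.encoder is self.discriminator.
        return self.enc_head(self.backbone(x))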
@@ -95,10 +103,11 @@ def __init__(self,
 
     def _set_up(self):
         super()._set_up()
-        self.optimizer_enc = torch.optim.Adam(self.encoder.parameters(), lr=self._lr_e, eps=self._adam_eps)
-        if self._lr_scheduler is not None:
-            self.scheduler_enc = self._lr_scheduler(self.optimizer_enc, **self._lr_scheduler_kwargs)
-        self.checkpoint_modules["optimizer_enc"] = self.optimizer_enc
+        if self.encoder is not self.discriminator:
+            self.optimizer_enc = torch.optim.Adam(self.encoder.parameters(), lr=self._lr_e, eps=self._adam_eps)
+            if self._lr_scheduler is not None:
+                self.scheduler_enc = self._lr_scheduler(self.optimizer_enc, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer_enc"] = self.optimizer_enc
 
     def act(self, states: torch.Tensor, deterministic: bool = False, ase_latents: torch.Tensor = None):
         if self._current_states is not None:
@@ -173,7 +182,10 @@ def update_net(self):
             style_rewards *= self._discriminator_reward_scale
 
             # Compute encoder reward
-            enc_output = self.encoder(self._amp_state_preprocessor(amp_states))
+            if self.encoder is self.discriminator:
+                enc_output = self.encoder.get_enc(self._amp_state_preprocessor(amp_states))
+            else:
+                enc_output = self.encoder(self._amp_state_preprocessor(amp_states))
             enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
             enc_reward = torch.clamp_min(torch.sum(enc_output * ase_latents, dim=-1, keepdim=True), 0.0)
             enc_reward *= self._enc_reward_scale
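For context, the encoder reward computed here follows the ASE formulation: the encoder output is L2-normalized and compared with the conditioning latent via a dot product, clipped at zero so that misaligned transitions earn nothing, then scaled. A minimal, self-contained sketch of that computation with made-up tensors (the scale value is an assumed stand-in for self._enc_reward_scale):

import torch
import torch.nn.functional as F

# Hypothetical shapes: a batch of 4 transitions, ASE latent dimension 8.
enc_output = torch.randn(4, 8)
ase_latents = F.normalize(torch.randn(4, 8), dim=-1)

enc_reward_scale = 1.0  # assumed stand-in for self._enc_reward_scale

# Normalize, dot with the latent, clip negatives to zero, scale.
enc_output = F.normalize(enc_output, dim=-1)
enc_reward = torch.clamp_min(torch.sum(enc_output * ase_latents, dim=-1, keepdim=True), 0.0)
enc_reward *= enc_reward_scale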
@@ -311,7 +323,10 @@ def update_net(self):
                 discriminator_loss *= self._discriminator_loss_scale
 
                 # encoder loss
-                enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states))
+                if self.encoder is self.discriminator:
+                    enc_output = self.encoder.get_enc(self._amp_state_preprocessor(sampled_amp_states))
+                else:
+                    enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states))
                 enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
                 enc_err = -torch.sum(enc_output * sampled_ase_latents, dim=-1, keepdim=True)
                 enc_loss = torch.mean(enc_err)
@@ -357,17 +372,21 @@ def update_net(self):
 
                 # Update discriminator network
                 self.optimizer_disc.zero_grad()
-                discriminator_loss.backward()
+                if self.encoder is self.discriminator:
+                    (discriminator_loss + enc_loss).backward()
+                else:
+                    discriminator_loss.backward()
                 if self._grad_norm_clip > 0:
                     nn.utils.clip_grad_norm_(self.discriminator.parameters(), self._grad_norm_clip)
                 self.optimizer_disc.step()
 
                 # Update encoder network
-                self.optimizer_enc.zero_grad()
-                enc_loss.backward()
-                if self._grad_norm_clip > 0:
-                    nn.utils.clip_grad_norm_(self.encoder.parameters(), self._grad_norm_clip)
-                self.optimizer_enc.step()
+                if self.encoder is not self.discriminator:
+                    self.optimizer_enc.zero_grad()
+                    enc_loss.backward()
+                    if self._grad_norm_clip > 0:
+                        nn.utils.clip_grad_norm_(self.encoder.parameters(), self._grad_norm_clip)
+                    self.optimizer_enc.step()
 
                 # update cumulative losses
                 cumulative_policy_loss += policy_loss.item()
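The branching above exists because, in the ASEDiscEnc case, encoder and discriminator are the same module: summing the two losses and taking one optimizer_disc step updates the shared weights once, instead of having two separate optimizers step the same parameters. A self-contained toy illustration of that shared-parameter pattern, with stand-in losses (names and values are illustrative only):

import torch
import torch.nn as nn

# Toy stand-in for a shared discriminator/encoder module.
shared = nn.Linear(4, 2)
optimizer = torch.optim.Adam(shared.parameters(), lr=1e-3)

x = torch.randn(8, 4)
out = shared(x)
discriminator_loss = out[:, 0].pow(2).mean()  # stand-in for the discriminator term
enc_loss = -out[:, 1].mean()                  # stand-in for the latent-alignment term

# Both losses touch the same parameters, so sum them and take a single step.
optimizer.zero_grad()
(discriminator_loss + enc_loss).backward()
nn.utils.clip_grad_norm_(shared.parameters(), max_norm=1.0)
optimizer.step()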
@@ -382,7 +401,8 @@ def update_net(self):
             self.scheduler_policy.step()
             self.scheduler_value.step()
             self.scheduler_disc.step()
-            self.scheduler_enc.step()
+            if self.encoder is not self.discriminator:
+                self.scheduler_enc.step()
 
         # update AMP replay buffer
         self.replay_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))
@@ -407,4 +427,5 @@ def update_net(self):
         self.track_data("Learning / Learning rate (policy)", self.scheduler_policy.get_last_lr()[0])
         self.track_data("Learning / Learning rate (value)", self.scheduler_value.get_last_lr()[0])
         self.track_data("Learning / Learning rate (discriminator)", self.scheduler_disc.get_last_lr()[0])
-        self.track_data("Learning / Learning rate (encoder)", self.scheduler_enc.get_last_lr()[0])
+        if self.encoder is not self.discriminator:
+            self.track_data("Learning / Learning rate (encoder)", self.scheduler_enc.get_last_lr()[0])