 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.distributions import DiagGaussianDistribution
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
-from stable_baselines3.common.type_aliases import (GymEnv, MaybeCallback,
-                                                   Schedule)
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
 from stable_baselines3.common.utils import (explained_variance, get_linear_fn,
                                             obs_as_tensor, safe_mean, update_learning_rate, get_schedule_fn)
 from stable_baselines3.common.vec_env import VecEnv
@@ -103,15 +102,17 @@ def __init__(
         self.normalize_advantage = normalize_advantage
         self.max_policy_grad_norm = max_policy_grad_norm
         self.max_value_grad_norm = max_value_grad_norm
-        num_rollouts = total_n_steps / (n_steps * env.num_envs)
+        num_rollouts = total_n_steps / (n_steps * env.num_envs)
         total_num_updates = num_rollouts
         assert 'tree_optimizer' in policy_kwargs, "tree_optimizer must be a dictionary within policy_kwargs"
-        assert 'gbrl_params' in policy_kwargs['tree_optimizer'], "gbrl_params must be a dictionary within policy_kwargs['tree_optimizer']"
+        assert 'params' in policy_kwargs['tree_optimizer'], \
+            "params must be a dictionary within policy_kwargs['tree_optimizer']"
         policy_kwargs['tree_optimizer']['policy_optimizer']['T'] = int(total_num_updates)
         policy_kwargs['tree_optimizer']['value_optimizer']['T'] = int(total_num_updates)
         policy_kwargs['tree_optimizer']['device'] = device
         self.fixed_std = fixed_std
-        is_categorical = (hasattr(env, 'is_mixed') and env.is_mixed) or (hasattr(env, 'is_categorical') and env.is_categorical)
+        is_categorical = (hasattr(env, 'is_mixed') and env.is_mixed) or (hasattr(env, 'is_categorical') and
+                                                                         env.is_categorical)
         is_mixed = (hasattr(env, 'is_mixed') and env.is_mixed)
         if is_categorical:
             policy_kwargs['is_categorical'] = True
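
For context, the `T` values set above just tell the GBRL policy and value tree optimizers how many gradient-boosting updates to expect, one per collected rollout. A quick sanity check of that arithmetic, using made-up budget numbers (not values from this commit):

    total_n_steps = 1_000_000   # illustrative total environment-step budget
    n_steps = 16                # steps collected per environment per rollout
    num_envs = 8                # parallel environments in the VecEnv
    num_rollouts = total_n_steps / (n_steps * num_envs)
    print(int(num_rollouts))    # 7812 boosting iterations passed as 'T'
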
@@ -120,14 +121,14 @@ def __init__(

         if isinstance(log_std_lr, str):
             if 'lin_' in log_std_lr:
-                log_std_lr = get_linear_fn(float(log_std_lr.replace('lin_', '')), min_log_std_lr, 1)
+                log_std_lr = get_linear_fn(float(log_std_lr.replace('lin_', '')), min_log_std_lr, 1)
             else:
                 log_std_lr = float(log_std_lr)
             policy_kwargs['log_std_schedule'] = get_schedule_fn(log_std_lr)
         super().__init__(
             ActorCriticPolicy,
             env,
-            learning_rate=learning_rate, # does nothing for categorical output spaces
+            learning_rate=learning_rate,  # does nothing for categorical output spaces
             n_steps=n_steps,
             gamma=gamma,
             gae_lambda=gae_lambda,
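
For context, the `lin_` branch above relies on SB3's `get_linear_fn(start, end, end_fraction)`, which returns a schedule evaluated on the remaining training progress (1.0 at the start of training, 0.0 at the end); `get_schedule_fn` wraps constants or callables into that same interface. A minimal sketch, assuming an example string `'lin_0.01'` and `min_log_std_lr=1e-5`:

    from stable_baselines3.common.utils import get_linear_fn, get_schedule_fn

    # 'lin_0.01' with min_log_std_lr=1e-5: anneal linearly from 0.01 to 1e-5.
    schedule = get_schedule_fn(get_linear_fn(0.01, 1e-5, 1))
    for progress_remaining in (1.0, 0.5, 0.0):
        print(progress_remaining, schedule(progress_remaining))
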
@@ -153,15 +154,15 @@ def __init__(
         # Update optimizer inside the policy if we want to use RMSProp
         # (original implementation) rather than Adam

-        self.rollout_buffer_class = RolloutBuffer
+        self.rollout_buffer_class = RolloutBuffer
         self.rollout_buffer_kwargs = {}
         if is_categorical or is_mixed:
-            self.rollout_buffer_class = CategoricalRolloutBuffer
+            self.rollout_buffer_class = CategoricalRolloutBuffer
             self.rollout_buffer_kwargs['is_mixed'] = is_mixed
-
+
         if _init_setup_model:
             self.a2c_setup_model()
-
+
     def a2c_setup_model(self) -> None:
         self._setup_lr_schedule()
         self.set_random_seed(self.seed)
@@ -184,8 +185,6 @@ def a2c_setup_model(self) -> None:
         # pytype:enable=not-instantiable
         self.policy = self.policy.to(self.device)

-
-
     def train(self) -> None:
         """
         Update policy using the currently gathered
@@ -196,7 +195,8 @@ def train(self) -> None:

         # Update optimizer learning rate
         if isinstance(self.policy.action_dist, DiagGaussianDistribution):
-            update_learning_rate(self.policy.log_std_optimizer, self.policy.log_std_schedule(self._current_progress_remaining))
+            update_learning_rate(self.policy.log_std_optimizer,
+                                 self.policy.log_std_schedule(self._current_progress_remaining))

         policy_losses, value_losses, entropy_losses = [], [], []
         log_std_s = []
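
For context, `update_learning_rate` is the SB3 helper that writes a new learning rate into every parameter group of a torch optimizer; here it is fed by the log-std schedule evaluated at `_current_progress_remaining`. A standalone sketch with a hypothetical parameter and learning rates:

    import torch as th
    from stable_baselines3.common.utils import update_learning_rate

    log_std = th.nn.Parameter(th.zeros(4))   # stand-in for policy.log_std
    optimizer = th.optim.Adam([log_std], lr=3e-4)
    update_learning_rate(optimizer, 1e-4)    # overwrites lr in all param groups
    print(optimizer.param_groups[0]["lr"])   # 0.0001
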
@@ -245,7 +245,8 @@ def train(self) -> None:
         theta = params[0]
         if isinstance(self.policy.action_dist, DiagGaussianDistribution) and not self.fixed_std:
             if self.max_policy_grad_norm is not None and self.max_policy_grad_norm > 0.0:
-                th.nn.utils.clip_grad_norm(self.policy.log_std, max_norm=self.max_policy_grad_norm, error_if_nonfinite=True)
+                th.nn.utils.clip_grad_norm_(self.policy.log_std, max_norm=self.max_policy_grad_norm,
+                                            error_if_nonfinite=True)
             self.policy.log_std_optimizer.step()
             log_std_grad = self.policy.log_std.grad.clone().detach().cpu().numpy()
             self.policy.log_std_optimizer.zero_grad()
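
The change above replaces the deprecated `torch.nn.utils.clip_grad_norm` with the in-place `clip_grad_norm_`, which rescales the gradient so its total norm does not exceed `max_norm`; with `error_if_nonfinite=True` it raises if the gradient contains NaN or inf. A minimal illustration with made-up numbers:

    import torch as th

    p = th.nn.Parameter(th.zeros(3))
    p.grad = th.tensor([3.0, 4.0, 0.0])   # gradient norm is 5
    th.nn.utils.clip_grad_norm_(p, max_norm=1.0, error_if_nonfinite=True)
    print(p.grad)                         # rescaled in place to norm 1
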
@@ -264,14 +265,15 @@ def train(self) -> None:
             theta_grad_maxs.append(theta_grad.max().item())
             theta_grad_mins.append(theta_grad.min().item())

-        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
+        explained_var = explained_variance(self.rollout_buffer.values.flatten(),
+                                           self.rollout_buffer.returns.flatten())

         self._n_updates += 1
-
+
         iteration = self.policy.model.get_iteration()
         num_trees = self.policy.model.get_num_trees()
         value_iteration = 0
-
+
         if isinstance(iteration, tuple):
             iteration, value_iteration = iteration
         value_num_trees = 0
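
For context, SB3's `explained_variance(values, returns)` computes 1 - Var(returns - values) / Var(returns): values near 1 mean the value function tracks the empirical returns closely, 0 means it is no better than a constant predictor, and negative values mean it is worse. A toy example with invented numbers:

    import numpy as np
    from stable_baselines3.common.utils import explained_variance

    returns = np.array([1.0, 2.0, 3.0, 4.0])
    values = np.array([1.1, 1.9, 3.2, 3.8])
    print(explained_variance(values, returns))   # ~0.98, a good value fit
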
@@ -379,7 +381,8 @@ def collect_rollouts(
                 and infos[idx].get("terminal_observation") is not None
                 and infos[idx].get("TimeLimit.truncated", False)
             ):
-                terminal_obs = infos[idx]["terminal_observation"] if self.is_categorical else self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
+                terminal_obs = infos[idx]["terminal_observation"] if self.is_categorical else \
+                    self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                 with th.no_grad():
                     terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
                 rewards[idx] += self.gamma * terminal_value
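
This block applies the usual time-limit bootstrap: when an episode ends because of `TimeLimit.truncated` rather than a true terminal state, the last reward is augmented with gamma times the predicted value of the terminal observation, so the advantage estimate does not treat the truncation as a real end. The arithmetic, sketched with placeholder numbers:

    gamma = 0.99
    reward = 1.0            # reward actually received at the truncated step
    terminal_value = 5.0    # hypothetical V(s) for the terminal observation
    reward += gamma * terminal_value
    print(reward)           # 5.95: the cut-off future is bootstrapped back in
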
@@ -397,14 +400,14 @@ def collect_rollouts(

         with th.no_grad():
             # Compute value for the last timestep
-            values = self.policy.predict_values(new_obs, requires_grad=False)  # type: ignore[arg-type] if self.is_categorical else self.policy.predict_values(obs_as_tensor(new_obs, self.device))  # type: ignore[arg-type]
+            values = self.policy.predict_values(new_obs, requires_grad=False)

         rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

         callback.on_rollout_end()

         return True
-
+
     def learn(
         self: SelfA2C,
         total_timesteps: int,
@@ -414,7 +417,7 @@ def learn(
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfA2C:
-
+
         iteration = 0
         total_timesteps, callback = self._setup_learn(
             total_timesteps,
@@ -426,7 +429,8 @@ def learn(
         callback.on_training_start(locals(), globals())
         assert self.env is not None
         while self.num_timesteps < total_timesteps:
-            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
+            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer,
+                                                      n_rollout_steps=self.n_steps)

             if continue_training is False:
                 break
@@ -441,8 +445,10 @@ def learn(
                 fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
                 self.logger.record("time/iterations", iteration, exclude="tensorboard")
                 if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
-                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
-                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
+                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in
+                                                                         self.ep_info_buffer]))
+                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in
+                                                                         self.ep_info_buffer]))
                 self.logger.record("time/fps", fps)
                 self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
                 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
@@ -453,10 +459,10 @@ def learn(
         callback.on_training_end()

         return self
-
+
     def save(self,
-             path: Union[str, pathlib.Path, io.BufferedIOBase],
-             exclude: Optional[Iterable[str]] = None,
-             include: Optional[Iterable[str]] = None,
-             ) -> None:
+             path: Union[str, pathlib.Path, io.BufferedIOBase],
+             exclude: Optional[Iterable[str]] = None,
+             include: Optional[Iterable[str]] = None,
+             ) -> None:
         self.policy.model.save_model(path)