77from collections import OrderedDict
88from pathlib import Path
99from pprint import pprint
10- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
10+ from typing import Any , Callable , Optional , Union
1111
1212import gymnasium as gym
1313import numpy as np
@@ -71,9 +71,9 @@ def __init__(
7171 eval_freq : int = 10000 ,
7272 n_eval_episodes : int = 5 ,
7373 save_freq : int = - 1 ,
74- hyperparams : Optional [Dict [str , Any ]] = None ,
75- env_kwargs : Optional [Dict [str , Any ]] = None ,
76- eval_env_kwargs : Optional [Dict [str , Any ]] = None ,
74+ hyperparams : Optional [dict [str , Any ]] = None ,
75+ env_kwargs : Optional [dict [str , Any ]] = None ,
76+ eval_env_kwargs : Optional [dict [str , Any ]] = None ,
7777 trained_agent : str = "" ,
7878 optimize_hyperparameters : bool = False ,
7979 storage : Optional [str ] = None ,
@@ -112,10 +112,10 @@ def __init__(
112112 default_path = Path (__file__ ).parent .parent
113113
114114 self .config = config or str (default_path / f"hyperparams/{ self .algo } .yml" )
115- self .env_kwargs : Dict [str , Any ] = env_kwargs or {}
115+ self .env_kwargs : dict [str , Any ] = env_kwargs or {}
116116 self .n_timesteps = n_timesteps
117117 self .normalize = False
118- self .normalize_kwargs : Dict [str , Any ] = {}
118+ self .normalize_kwargs : dict [str , Any ] = {}
119119 self .env_wrapper : Optional [Callable ] = None
120120 self .frame_stack = None
121121 self .seed = seed
@@ -124,23 +124,23 @@ def __init__(
124124 self .vec_env_class = {"dummy" : DummyVecEnv , "subproc" : SubprocVecEnv }[vec_env_type ]
125125 self .vec_env_wrapper : Optional [Callable ] = None
126126
127- self .vec_env_kwargs : Dict [str , Any ] = {}
127+ self .vec_env_kwargs : dict [str , Any ] = {}
128128 # self.vec_env_kwargs = {} if vec_env_type == "dummy" else {"start_method": "fork"}
129129
130130 # Callbacks
131- self .specified_callbacks : List = []
132- self .callbacks : List [BaseCallback ] = []
131+ self .specified_callbacks : list = []
132+ self .callbacks : list [BaseCallback ] = []
133133 # Use env-kwargs if eval_env_kwargs was not specified
134- self .eval_env_kwargs : Dict [str , Any ] = eval_env_kwargs or self .env_kwargs
134+ self .eval_env_kwargs : dict [str , Any ] = eval_env_kwargs or self .env_kwargs
135135 self .save_freq = save_freq
136136 self .eval_freq = eval_freq
137137 self .n_eval_episodes = n_eval_episodes
138138 self .n_eval_envs = n_eval_envs
139139
140140 self .n_envs = 1 # it will be updated when reading hyperparams
141141 self .n_actions = 0 # For DDPG/TD3 action noise objects
142- self ._hyperparams : Dict [str , Any ] = {}
143- self .monitor_kwargs : Dict [str , Any ] = {}
142+ self ._hyperparams : dict [str , Any ] = {}
143+ self .monitor_kwargs : dict [str , Any ] = {}
144144
145145 self .trained_agent = trained_agent
146146 self .continue_training = trained_agent .endswith (".zip" ) and os .path .isfile (trained_agent )
@@ -179,7 +179,7 @@ def __init__(
179179 )
180180 self .params_path = f"{ self .save_path } /{ self .env_name } "
181181
182- def setup_experiment (self ) -> Optional [Tuple [BaseAlgorithm , Dict [str , Any ]]]:
182+ def setup_experiment (self ) -> Optional [tuple [BaseAlgorithm , dict [str , Any ]]]:
183183 """
184184 Read hyperparameters, pre-process them (create schedules, wrappers, callbacks, action noise objects)
185185 create the environment and possibly the model.
@@ -223,7 +223,7 @@ def learn(self, model: BaseAlgorithm) -> None:
223223 """
224224 :param model: an initialized RL model
225225 """
226- kwargs : Dict [str , Any ] = {}
226+ kwargs : dict [str , Any ] = {}
227227 if self .log_interval > - 1 :
228228 kwargs = {"log_interval" : self .log_interval }
229229
@@ -272,7 +272,7 @@ def save_trained_model(self, model: BaseAlgorithm) -> None:
272272 assert vec_normalize is not None
273273 vec_normalize .save (os .path .join (self .params_path , "vecnormalize.pkl" ))
274274
275- def _save_config (self , saved_hyperparams : Dict [str , Any ]) -> None :
275+ def _save_config (self , saved_hyperparams : dict [str , Any ]) -> None :
276276 """
277277 Save unprocessed hyperparameters, this can be use later
278278 to reproduce an experiment.
@@ -290,15 +290,15 @@ def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None:
290290
291291 print (f"Log path: { self .save_path } " )
292292
293- def read_hyperparameters (self ) -> Tuple [ Dict [str , Any ], Dict [str , Any ]]:
293+ def read_hyperparameters (self ) -> tuple [ dict [str , Any ], dict [str , Any ]]:
294294 print (f"Loading hyperparameters from: { self .config } " )
295295
296296 if self .config .endswith (".yml" ) or self .config .endswith (".yaml" ):
297297 # Load hyperparameters from yaml file
298298 with open (self .config ) as f :
299299 hyperparams_dict = yaml .safe_load (f )
300300 elif self .config .endswith (".py" ):
301- global_variables : Dict = {}
301+ global_variables : dict = {}
302302 # Load hyperparameters from python file
303303 exec (Path (self .config ).read_text (), global_variables )
304304 hyperparams_dict = global_variables ["hyperparams" ]
@@ -327,7 +327,7 @@ def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
327327 return hyperparams , saved_hyperparams
328328
329329 @staticmethod
330- def _preprocess_schedules (hyperparams : Dict [str , Any ]) -> Dict [str , Any ]:
330+ def _preprocess_schedules (hyperparams : dict [str , Any ]) -> dict [str , Any ]:
331331 # Create schedules
332332 for key in ["learning_rate" , "clip_range" , "clip_range_vf" , "delta_std" ]:
333333 if key not in hyperparams :
@@ -345,7 +345,7 @@ def _preprocess_schedules(hyperparams: Dict[str, Any]) -> Dict[str, Any]:
345345 raise ValueError (f"Invalid value for { key } : { hyperparams [key ]} " )
346346 return hyperparams
347347
348- def _preprocess_normalization (self , hyperparams : Dict [str , Any ]) -> Dict [str , Any ]:
348+ def _preprocess_normalization (self , hyperparams : dict [str , Any ]) -> dict [str , Any ]:
349349 if "normalize" in hyperparams .keys ():
350350 self .normalize = hyperparams ["normalize" ]
351351
@@ -370,8 +370,8 @@ def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, An
370370 return hyperparams
371371
372372 def _preprocess_hyperparams ( # noqa: C901
373- self , hyperparams : Dict [str , Any ]
374- ) -> Tuple [ Dict [str , Any ], Optional [Callable ], List [BaseCallback ], Optional [Callable ]]:
373+ self , hyperparams : dict [str , Any ]
374+ ) -> tuple [ dict [str , Any ], Optional [Callable ], list [BaseCallback ], Optional [Callable ]]:
375375 self .n_envs = hyperparams .get ("n_envs" , 1 )
376376
377377 if self .verbose > 0 :
@@ -448,8 +448,8 @@ def _preprocess_hyperparams( # noqa: C901
448448 return hyperparams , env_wrapper , callbacks , vec_env_wrapper
449449
450450 def _preprocess_action_noise (
451- self , hyperparams : Dict [str , Any ], saved_hyperparams : Dict [str , Any ], env : VecEnv
452- ) -> Dict [str , Any ]:
451+ self , hyperparams : dict [str , Any ], saved_hyperparams : dict [str , Any ], env : VecEnv
452+ ) -> dict [str , Any ]:
453453 # Parse noise string
454454 # Note: only off-policy algorithms are supported
455455 if hyperparams .get ("noise_type" ) is not None :
@@ -667,7 +667,7 @@ def make_env(**kwargs) -> gym.Env:
667667
668668 return env
669669
670- def _load_pretrained_agent (self , hyperparams : Dict [str , Any ], env : VecEnv ) -> BaseAlgorithm :
670+ def _load_pretrained_agent (self , hyperparams : dict [str , Any ], env : VecEnv ) -> BaseAlgorithm :
671671 # Continue training
672672 print ("Loading pretrained agent" )
673673 # Policy should not be changed
0 commit comments