66import time
77import warnings
88from collections import OrderedDict
9+ from collections .abc import Callable
910from pathlib import Path
1011from pprint import pprint
11- from typing import Any , Callable , Optional , Union
12+ from typing import Any
1213
1314import gymnasium as gym
1415import numpy as np
@@ -69,7 +70,7 @@ class ExperimentManager:
6970 """
7071
7172 # For special VecEnv like Brax, IsaacLab, ...
72- default_vec_env_cls : Optional [ type [VecEnv ]] = None
73+ default_vec_env_cls : type [VecEnv ] | None = None
7374
7475 def __init__ (
7576 self ,
@@ -82,19 +83,19 @@ def __init__(
8283 eval_freq : int = 10000 ,
8384 n_eval_episodes : int = 5 ,
8485 save_freq : int = - 1 ,
85- hyperparams : Optional [ dict [str , Any ]] = None ,
86- env_kwargs : Optional [ dict [str , Any ]] = None ,
87- eval_env_kwargs : Optional [ dict [str , Any ]] = None ,
86+ hyperparams : dict [str , Any ] | None = None ,
87+ env_kwargs : dict [str , Any ] | None = None ,
88+ eval_env_kwargs : dict [str , Any ] | None = None ,
8889 trained_agent : str = "" ,
8990 optimize_hyperparameters : bool = False ,
90- storage : Optional [ str ] = None ,
91- study_name : Optional [ str ] = None ,
91+ storage : str | None = None ,
92+ study_name : str | None = None ,
9293 n_trials : int = 1 ,
93- max_total_trials : Optional [ int ] = None ,
94+ max_total_trials : int | None = None ,
9495 n_jobs : int = 1 ,
9596 sampler : str = "tpe" ,
9697 pruner : str = "median" ,
97- optimization_log_path : Optional [ str ] = None ,
98+ optimization_log_path : str | None = None ,
9899 n_startup_trials : int = 0 ,
99100 n_evaluations : int = 1 ,
100101 truncate_last_trajectory : bool = False ,
@@ -106,10 +107,10 @@ def __init__(
106107 vec_env_type : str = "dummy" ,
107108 n_eval_envs : int = 1 ,
108109 no_optim_plots : bool = False ,
109- device : Union [ th .device , str ] = "auto" ,
110- config : Optional [ str ] = None ,
110+ device : th .device | str = "auto" ,
111+ config : str | None = None ,
111112 show_progress : bool = False ,
112- trial_id : Optional [ int ] = None ,
113+ trial_id : int | None = None ,
113114 ):
114115 super ().__init__ ()
115116 self .algo = algo
@@ -128,7 +129,7 @@ def __init__(
128129 self .n_timesteps = n_timesteps
129130 self .normalize = False
130131 self .normalize_kwargs : dict [str , Any ] = {}
131- self .env_wrapper : Optional [ Callable ] = None
132+ self .env_wrapper : Callable | None = None
132133 self .frame_stack = None
133134 self .seed = seed
134135 self .optimization_log_path = optimization_log_path
@@ -138,7 +139,7 @@ def __init__(
138139 if self .default_vec_env_cls is not None :
139140 self .vec_env_class = self .default_vec_env_cls
140141
141- self .vec_env_wrapper : Optional [ Callable ] = None
142+ self .vec_env_wrapper : Callable | None = None
142143
143144 self .vec_env_kwargs : dict [str , Any ] = {}
144145 # self.vec_env_kwargs = {} if vec_env_type == "dummy" else {"start_method": "fork"}
@@ -197,7 +198,7 @@ def __init__(
197198 )
198199 self .params_path = f"{ self .save_path } /{ self .env_name } "
199200
200- def setup_experiment (self ) -> Optional [ tuple [BaseAlgorithm , dict [str , Any ]]] :
201+ def setup_experiment (self ) -> tuple [BaseAlgorithm , dict [str , Any ]] | None :
201202 """
202203 Read hyperparameters, pre-process them (create schedules, wrappers, callbacks, action noise objects)
203204 create the environment and possibly the model.
@@ -361,12 +362,10 @@ def read_hyperparameters(self) -> tuple[dict[str, Any], dict[str, Any]]:
361362
362363 return hyperparams , saved_hyperparams
363364
364- def load_trial (
365- self , storage : str , study_name : str , trial_id : Optional [int ] = None , convert : bool = True
366- ) -> dict [str , Any ]:
365+ def load_trial (self , storage : str , study_name : str , trial_id : int | None = None , convert : bool = True ) -> dict [str , Any ]:
367366
368367 if storage .endswith (".log" ):
369- optuna_storage = optuna .storages .JournalStorage (optuna .storages .journal .JournalFileBackend (storage ))
368+ optuna_storage = optuna .storages .JournalStorage (optuna .storages .journal .JournalFileBackend (storage )) # type: ignore[attr-defined]
370369 else :
371370 optuna_storage = storage # type: ignore[assignment]
372371 study = optuna .load_study (storage = optuna_storage , study_name = study_name )
@@ -386,7 +385,7 @@ def _preprocess_schedules(hyperparams: dict[str, Any]) -> dict[str, Any]:
386385 if key not in hyperparams :
387386 continue
388387 if isinstance (hyperparams [key ], str ):
389- schedule , initial_value = hyperparams [key ].split ("_" )
388+ _schedule , initial_value = hyperparams [key ].split ("_" )
390389 initial_value = float (initial_value )
391390 hyperparams [key ] = SimpleLinearSchedule (initial_value )
392391 elif isinstance (hyperparams [key ], (float , int )):
@@ -424,7 +423,7 @@ def _preprocess_normalization(self, hyperparams: dict[str, Any]) -> dict[str, An
424423
425424 def _preprocess_hyperparams ( # noqa: C901
426425 self , hyperparams : dict [str , Any ]
427- ) -> tuple [dict [str , Any ], Optional [ Callable ] , list [BaseCallback ], Optional [ Callable ] ]:
426+ ) -> tuple [dict [str , Any ], Callable | None , list [BaseCallback ], Callable | None ]:
428427 self .n_envs = hyperparams .get ("n_envs" , 1 )
429428
430429 if self .verbose > 0 :
@@ -891,7 +890,7 @@ def hyperparameters_optimization(self) -> None:
891890 # Create folder if it doesn't exist
892891 Path (storage ).parent .mkdir (parents = True , exist_ok = True )
893892 storage = optuna .storages .JournalStorage ( # type: ignore[assignment]
894- optuna .storages .journal .JournalFileBackend (storage ),
893+ optuna .storages .journal .JournalFileBackend (storage ), # type: ignore[attr-defined]
895894 )
896895
897896 if self .verbose > 0 :
0 commit comments