4949# Register custom envs
5050import rl_zoo3 .import_envs # noqa: F401
5151from rl_zoo3 .callbacks import SaveVecNormalizeCallback , TrialEvalCallback
52- from rl_zoo3 .hyperparams_opt import HYPERPARAMS_SAMPLER
52+ from rl_zoo3 .hyperparams_opt import HYPERPARAMS_CONVERTER , HYPERPARAMS_SAMPLER
5353from rl_zoo3 .utils import ALGOS , get_callback_list , get_class_by_name , get_latest_run_id , get_wrapper_class , linear_schedule
5454
5555
@@ -102,6 +102,7 @@ def __init__(
102102 device : Union [th .device , str ] = "auto" ,
103103 config : Optional [str ] = None ,
104104 show_progress : bool = False ,
105+ trial_id : Optional [int ] = None ,
105106 ):
106107 super ().__init__ ()
107108 self .algo = algo
@@ -160,6 +161,8 @@ def __init__(
160161 self .storage = storage
161162 self .study_name = study_name
162163 self .no_optim_plots = no_optim_plots
164+ # For loading hyperparams from a study
165+ self .trial_id = trial_id
163166 # maximum number of trials for finding the best hyperparams
164167 self .n_trials = n_trials
165168 self .max_total_trials = max_total_trials
@@ -334,6 +337,11 @@ def read_hyperparameters(self) -> tuple[dict[str, Any], dict[str, Any]]:
334337 else :
335338 raise ValueError (f"Hyperparameters not found for { self .algo } -{ self .env_name .gym_id } in { self .config } " )
336339
340+ if self .storage and self .study_name and self .trial_id :
341+ print ("Loading from Optuna study..." )
342+ study_hyperparams = self .load_trial (self .storage , self .study_name , self .trial_id )
343+ hyperparams .update (study_hyperparams )
344+
337345 if self .custom_hyperparams is not None :
338346 # Overwrite hyperparams if needed
339347 hyperparams .update (self .custom_hyperparams )
@@ -346,6 +354,24 @@ def read_hyperparameters(self) -> tuple[dict[str, Any], dict[str, Any]]:
346354
347355 return hyperparams , saved_hyperparams
348356
357+ def load_trial (
358+ self , storage : str , study_name : str , trial_id : Optional [int ] = None , convert : bool = True
359+ ) -> dict [str , Any ]:
360+
361+ if storage .endswith (".log" ):
362+ optuna_storage = optuna .storages .JournalStorage (optuna .storages .journal .JournalFileBackend (storage ))
363+ else :
364+ optuna_storage = storage # type: ignore[assignment]
365+ study = optuna .load_study (storage = optuna_storage , study_name = study_name )
366+ if trial_id is not None :
367+ params = study .trials [trial_id ].params
368+ else :
369+ params = study .best_trial .params
370+
371+ if convert :
372+ return HYPERPARAMS_CONVERTER [self .algo ](params )
373+ return params
374+
349375 @staticmethod
350376 def _preprocess_schedules (hyperparams : dict [str , Any ]) -> dict [str , Any ]:
351377 # Create schedules
@@ -470,6 +496,10 @@ def _preprocess_hyperparams( # noqa: C901
470496 def _preprocess_action_noise (
471497 self , hyperparams : dict [str , Any ], saved_hyperparams : dict [str , Any ], env : VecEnv
472498 ) -> dict [str , Any ]:
499+ # Compute n_actions for hyperparameter optim
500+ if isinstance (env .action_space , spaces .Box ):
501+ self .n_actions = env .action_space .shape [0 ]
502+
473503 # Parse noise string
474504 # Note: only off-policy algorithms are supported
475505 if hyperparams .get ("noise_type" ) is not None :
@@ -480,7 +510,6 @@ def _preprocess_action_noise(
480510 assert isinstance (
481511 env .action_space , spaces .Box
482512 ), f"Action noise can only be used with Box action space, not { env .action_space } "
483- self .n_actions = env .action_space .shape [0 ]
484513
485514 if "normal" in noise_type :
486515 hyperparams ["action_noise" ] = NormalActionNoise (
@@ -619,11 +648,9 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
619648 log_dir = None if eval_env or no_log else self .save_path
620649
621650 # Special case for GoalEnvs: log success rate too
622- if (
623- "Neck" in self .env_name .gym_id
624- or self .is_robotics_env (self .env_name .gym_id )
625- or ("parking-v0" in self .env_name .gym_id and len (self .monitor_kwargs ) == 0 ) # do not overwrite custom kwargs
626- ):
651+ if self .is_robotics_env (self .env_name .gym_id ) or (
652+ "parking-v0" in self .env_name .gym_id and len (self .monitor_kwargs ) == 0
653+ ): # do not overwrite custom kwargs
627654 self .monitor_kwargs = dict (info_keywords = ("is_success" ,))
628655
629656 spec = gym .spec (self .env_name .gym_id )
@@ -722,13 +749,10 @@ def _create_sampler(self, sampler_method: str) -> BaseSampler:
722749 sampler : BaseSampler = RandomSampler (seed = self .seed )
723750 elif sampler_method == "tpe" :
724751 sampler = TPESampler (n_startup_trials = self .n_startup_trials , seed = self .seed , multivariate = True )
725- elif sampler_method == "skopt " :
726- from optuna . integration . skopt import SkoptSampler
752+ elif sampler_method == "auto " :
753+ import optunahub
727754
728- # cf https://scikit-optimize.github.io/#skopt.Optimizer
729- # GP: gaussian process
730- # Gradient boosted regression: GBRT
731- sampler = SkoptSampler (skopt_kwargs = {"base_estimator" : "GP" , "acq_func" : "gp_hedge" })
755+ sampler = optunahub .load_module ("samplers/auto_sampler" ).AutoSampler (seed = self .seed )
732756 else :
733757 raise ValueError (f"Unknown sampler: { sampler_method } " )
734758 return sampler
@@ -854,14 +878,22 @@ def hyperparameters_optimization(self) -> None:
854878 # TODO: eval each hyperparams several times to account for noisy evaluation
855879 sampler = self ._create_sampler (self .sampler )
856880 pruner = self ._create_pruner (self .pruner )
881+ # Log file storage
882+ storage = self .storage
883+ if storage is not None and storage .endswith (".log" ):
884+ # Create folder if it doesn't exist
885+ Path (storage ).parent .mkdir (parents = True , exist_ok = True )
886+ storage = optuna .storages .JournalStorage ( # type: ignore[assignment]
887+ optuna .storages .journal .JournalFileBackend (storage ),
888+ )
857889
858890 if self .verbose > 0 :
859891 print (f"Sampler: { self .sampler } - Pruner: { self .pruner } " )
860892
861893 study = optuna .create_study (
862894 sampler = sampler ,
863895 pruner = pruner ,
864- storage = self . storage ,
896+ storage = storage ,
865897 study_name = self .study_name ,
866898 load_if_exists = True ,
867899 direction = "maximize" ,
@@ -903,6 +935,9 @@ def hyperparameters_optimization(self) -> None:
903935 print ("Params: " )
904936 for key , value in trial .params .items ():
905937 print (f" { key } : { value } " )
938+ print ("User Attributes: " )
939+ for key , value in trial .user_attrs .items ():
940+ print (f" { key } : { value } " )
906941
907942 report_name = (
908943 f"report_{ self .env_name } _{ self .n_trials } -trials-{ self .n_timesteps } "
0 commit comments