
Commit c1080b1

Refactor hyperparameter optimization (#485)
* Start refactoring hyperparam optim
* Remove hardcoded env for success recording
* Refactor all algos
* Bug fixes
* Update doc and add test
1 parent 06ab062 commit c1080b1

File tree

8 files changed: +410 -361 lines changed


CHANGELOG.md

Lines changed: 18 additions & 1 deletion
@@ -1,7 +1,23 @@
-## Release 2.6.0a2 (WIP)
+## Release 2.6.0a3 (WIP)

 ### Breaking Changes
 - Upgraded to SB3 >= 2.6.0
+- Refactored hyperparameter optimization. The Optuna [Journal storage backend](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.storages.JournalStorage.html) is now supported (recommended default) and you can easily load tuned hyperparameters via the new `--trial-id` argument of `train.py`.
+
+For example, optimize using the journal storage:
+```bash
+python train.py --algo ppo --env Pendulum-v1 -n 40000 --study-name demo --storage logs/demo.log --sampler tpe --n-evaluations 2 --optimize --no-optim-plots
+```
+Visualize live using [optuna-dashboard](https://optuna-dashboard.readthedocs.io/en/latest/getting-started.html):
+```bash
+optuna-dashboard logs/demo.log
+```
+
+Load hyperparameters from trial number 21 and train an agent with it:
+```bash
+python train.py --algo ppo --env Pendulum-v1 --study-name demo --storage logs/demo.log --trial-id 21
+```
+

 ### New Features
 - Save the exact command line used to launch a training
@@ -15,6 +31,7 @@
 ### Documentation

 ### Other
+- `scripts/parse_study.py` is now deprecated because of the new hyperparameter optimization scripts

 ## Release 2.5.0 (2025-01-27)

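For reference, the journal file written by `--storage logs/demo.log` is a regular Optuna storage, so the study can also be inspected directly from Python. A minimal sketch using only the Optuna calls shown in this commit (the study name and file path match the example commands above):

```python
import optuna

# Open the journal file created by `train.py --storage logs/demo.log`
storage = optuna.storages.JournalStorage(
    optuna.storages.journal.JournalFileBackend("logs/demo.log")
)
study = optuna.load_study(study_name="demo", storage=storage)

# Show the best trial and its sampled hyperparameters
print(f"Best trial: {study.best_trial.number} (value={study.best_trial.value})")
for name, value in study.best_trial.params.items():
    print(f"  {name}: {value}")
```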

docs/guide/tuning.rst

Lines changed: 20 additions & 11 deletions
@@ -4,8 +4,12 @@
 Hyperparameter Tuning
 =====================

-Hyperparameter Tuning
----------------------
+Automated hyperparameter optimization
+-------------------------------------
+
+Blog post: `Automatic Hyperparameter Tuning - A Visual Guide <https://araffin.github.io/post/hyperparam-tuning/>`_
+
+Video: https://www.youtube.com/watch?v=AidFTOdGNFQ

 We use `Optuna <https://optuna.org/>`__ for optimizing the
 hyperparameters. Not all hyperparameters are tuned, and tuning enforces
@@ -35,20 +39,29 @@ documentation <https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/

 ::

-  python train.py --algo ppo --env MountainCar-v0 -optimize --study-name test --storage sqlite:///example.db
+  python train.py --algo ppo --env MountainCar-v0 -optimize --study-name test --storage logs/demo.log

-Print and save best hyperparameters of an Optuna study:

-::

-  python scripts/parse_study.py -i path/to/study.pkl --print-n-best-trials 10 --save-n-best-hyperparameters 10
+Visualize live using `optuna-dashboard <https://optuna-dashboard.readthedocs.io/en/latest/getting-started.html>`__
+
+.. code:: bash
+
+  optuna-dashboard logs/demo.log
+
+Load hyperparameters from trial number 21 and train an agent with it:
+
+.. code:: bash
+
+  python train.py --algo ppo --env MountainCar-v0 --study-name test --storage logs/demo.log --trial-id 21
+

 The default budget for hyperparameter tuning is 500 trials and there is
 one intermediate evaluation for pruning/early stopping per 100k time
 steps.

 Hyperparameters search space
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+----------------------------

 Note that the default hyperparameters used in the zoo when tuning are
 not always the same as the defaults provided in
@@ -65,7 +78,3 @@ example:
 - Non-episodic rollout in TD3 and DDPG assumes
   ``gradient_steps = train_freq`` and so tunes only ``train_freq`` to
   reduce the search space.
-
-When working with continuous actions, we recommend to enable
-`gSDE <https://arxiv.org/abs/2005.05719>`__ by uncommenting lines in
-`rl_zoo3/hyperparams_opt.py <https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/hyperparams_opt.py>`__.
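Since `scripts/parse_study.py` is deprecated by this commit, ranking the best trials of a study can be done with a few lines of plain Optuna. A sketch, assuming a single-objective study named `test` stored in `logs/demo.log` as in the docs above:

```python
import optuna

storage = optuna.storages.JournalStorage(
    optuna.storages.journal.JournalFileBackend("logs/demo.log")
)
study = optuna.load_study(study_name="test", storage=storage)

# Keep completed trials only and sort by objective value (maximized in the zoo)
completed = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
for trial in sorted(completed, key=lambda t: t.value, reverse=True)[:10]:
    print(f"Trial {trial.number}: value={trial.value:.2f} params={trial.params}")
```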

requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ box2d-py==2.3.8
 pybullet_envs_gymnasium>=0.6.0
 # minigrid
 cloudpickle>=2.2.1
+# Optuna auto
+optunahub>=0.2.0
 # optuna plots:
 plotly
 # need to upgrade to gymnasium:

rl_zoo3/exp_manager.py

Lines changed: 49 additions & 14 deletions
@@ -49,7 +49,7 @@
 # Register custom envs
 import rl_zoo3.import_envs  # noqa: F401
 from rl_zoo3.callbacks import SaveVecNormalizeCallback, TrialEvalCallback
-from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER
+from rl_zoo3.hyperparams_opt import HYPERPARAMS_CONVERTER, HYPERPARAMS_SAMPLER
 from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule


@@ -102,6 +102,7 @@ def __init__(
         device: Union[th.device, str] = "auto",
         config: Optional[str] = None,
         show_progress: bool = False,
+        trial_id: Optional[int] = None,
     ):
         super().__init__()
         self.algo = algo
@@ -160,6 +161,8 @@ def __init__(
         self.storage = storage
         self.study_name = study_name
         self.no_optim_plots = no_optim_plots
+        # For loading hyperparams from a study
+        self.trial_id = trial_id
         # maximum number of trials for finding the best hyperparams
         self.n_trials = n_trials
         self.max_total_trials = max_total_trials
@@ -334,6 +337,11 @@ def read_hyperparameters(self) -> tuple[dict[str, Any], dict[str, Any]]:
         else:
             raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id} in {self.config}")

+        if self.storage and self.study_name and self.trial_id:
+            print("Loading from Optuna study...")
+            study_hyperparams = self.load_trial(self.storage, self.study_name, self.trial_id)
+            hyperparams.update(study_hyperparams)
+
         if self.custom_hyperparams is not None:
             # Overwrite hyperparams if needed
             hyperparams.update(self.custom_hyperparams)
@@ -346,6 +354,24 @@ def read_hyperparameters(self) -> tuple[dict[str, Any], dict[str, Any]]:

         return hyperparams, saved_hyperparams

+    def load_trial(
+        self, storage: str, study_name: str, trial_id: Optional[int] = None, convert: bool = True
+    ) -> dict[str, Any]:
+
+        if storage.endswith(".log"):
+            optuna_storage = optuna.storages.JournalStorage(optuna.storages.journal.JournalFileBackend(storage))
+        else:
+            optuna_storage = storage  # type: ignore[assignment]
+        study = optuna.load_study(storage=optuna_storage, study_name=study_name)
+        if trial_id is not None:
+            params = study.trials[trial_id].params
+        else:
+            params = study.best_trial.params
+
+        if convert:
+            return HYPERPARAMS_CONVERTER[self.algo](params)
+        return params
+
     @staticmethod
     def _preprocess_schedules(hyperparams: dict[str, Any]) -> dict[str, Any]:
         # Create schedules
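`load_trial` is what backs the new `--trial-id` flow: with no `trial_id` it falls back to the study's best trial, and `convert=True` passes the raw sampled params through `HYPERPARAMS_CONVERTER` so they match the keyword arguments expected by the algorithm. A standalone equivalent for use outside the manager (a sketch; the helper name `params_from_study` is ours, the Optuna calls are the ones from the hunk above):

```python
from typing import Any, Optional

import optuna


def params_from_study(storage_path: str, study_name: str, trial_id: Optional[int] = None) -> dict[str, Any]:
    # Same storage resolution as ExperimentManager.load_trial:
    # ".log" files are journal storage, anything else is a storage URL
    if storage_path.endswith(".log"):
        storage = optuna.storages.JournalStorage(
            optuna.storages.journal.JournalFileBackend(storage_path)
        )
    else:
        storage = storage_path
    study = optuna.load_study(storage=storage, study_name=study_name)
    # A specific trial number, or the study's best trial by default
    trial = study.trials[trial_id] if trial_id is not None else study.best_trial
    return trial.params
```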
@@ -470,6 +496,10 @@ def _preprocess_hyperparams(  # noqa: C901
     def _preprocess_action_noise(
         self, hyperparams: dict[str, Any], saved_hyperparams: dict[str, Any], env: VecEnv
     ) -> dict[str, Any]:
+        # Compute n_actions for hyperparameter optim
+        if isinstance(env.action_space, spaces.Box):
+            self.n_actions = env.action_space.shape[0]
+
         # Parse noise string
         # Note: only off-policy algorithms are supported
         if hyperparams.get("noise_type") is not None:
@@ -480,7 +510,6 @@ def _preprocess_action_noise(
             assert isinstance(
                 env.action_space, spaces.Box
             ), f"Action noise can only be used with Box action space, not {env.action_space}"
-            self.n_actions = env.action_space.shape[0]

             if "normal" in noise_type:
                 hyperparams["action_noise"] = NormalActionNoise(
619648
log_dir = None if eval_env or no_log else self.save_path
620649

621650
# Special case for GoalEnvs: log success rate too
622-
if (
623-
"Neck" in self.env_name.gym_id
624-
or self.is_robotics_env(self.env_name.gym_id)
625-
or ("parking-v0" in self.env_name.gym_id and len(self.monitor_kwargs) == 0) # do not overwrite custom kwargs
626-
):
651+
if self.is_robotics_env(self.env_name.gym_id) or (
652+
"parking-v0" in self.env_name.gym_id and len(self.monitor_kwargs) == 0
653+
): # do not overwrite custom kwargs
627654
self.monitor_kwargs = dict(info_keywords=("is_success",))
628655

629656
spec = gym.spec(self.env_name.gym_id)
@@ -722,13 +749,10 @@ def _create_sampler(self, sampler_method: str) -> BaseSampler:
             sampler: BaseSampler = RandomSampler(seed=self.seed)
         elif sampler_method == "tpe":
             sampler = TPESampler(n_startup_trials=self.n_startup_trials, seed=self.seed, multivariate=True)
-        elif sampler_method == "skopt":
-            from optuna.integration.skopt import SkoptSampler
+        elif sampler_method == "auto":
+            import optunahub

-            # cf https://scikit-optimize.github.io/#skopt.Optimizer
-            # GP: gaussian process
-            # Gradient boosted regression: GBRT
-            sampler = SkoptSampler(skopt_kwargs={"base_estimator": "GP", "acq_func": "gp_hedge"})
+            sampler = optunahub.load_module("samplers/auto_sampler").AutoSampler(seed=self.seed)
         else:
             raise ValueError(f"Unknown sampler: {sampler_method}")
         return sampler
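The new `auto` sampler replaces the removed `skopt` option and delegates the sampler choice to `AutoSampler` from optunahub (hence the new `optunahub>=0.2.0` requirement above). A minimal standalone sketch of the same call, assuming optunahub and its optional sampler dependencies are installed:

```python
import optuna
import optunahub

# AutoSampler automatically picks a suitable sampler (e.g. TPE, GP, CMA-ES)
# based on the search space and the trials collected so far
sampler = optunahub.load_module("samplers/auto_sampler").AutoSampler(seed=42)
study = optuna.create_study(sampler=sampler, direction="maximize")
# Toy objective, just to exercise the sampler
study.optimize(lambda trial: -(trial.suggest_float("x", -10, 10) ** 2), n_trials=20)
print(study.best_trial.params)
```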
@@ -854,14 +878,22 @@ def hyperparameters_optimization(self) -> None:
         # TODO: eval each hyperparams several times to account for noisy evaluation
         sampler = self._create_sampler(self.sampler)
         pruner = self._create_pruner(self.pruner)
+        # Log file storage
+        storage = self.storage
+        if storage is not None and storage.endswith(".log"):
+            # Create folder if it doesn't exist
+            Path(storage).parent.mkdir(parents=True, exist_ok=True)
+            storage = optuna.storages.JournalStorage(  # type: ignore[assignment]
+                optuna.storages.journal.JournalFileBackend(storage),
+            )

         if self.verbose > 0:
             print(f"Sampler: {self.sampler} - Pruner: {self.pruner}")

         study = optuna.create_study(
             sampler=sampler,
             pruner=pruner,
-            storage=self.storage,
+            storage=storage,
             study_name=self.study_name,
             load_if_exists=True,
             direction="maximize",
@@ -903,6 +935,9 @@ def hyperparameters_optimization(self) -> None:
             print("Params: ")
             for key, value in trial.params.items():
                 print(f"    {key}: {value}")
+            print("User Attributes: ")
+            for key, value in trial.user_attrs.items():
+                print(f"    {key}: {value}")

             report_name = (
                 f"report_{self.env_name}_{self.n_trials}-trials-{self.n_timesteps}"
