Skip to content

Commit 2e78415

Browse files
authored
Merge pull request #369 from nabenabe0928/add-plot-target-over-time
Add plot target over time
2 parents 07e0369 + a1d98df commit 2e78415

5 files changed

Lines changed: 256 additions & 0 deletions

File tree

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2026 Shuhei Watanabe
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
---
2+
author: Shuhei Watanabe
3+
title: Plot Target Over Time
4+
description: With this visualization module, we can plot the best target value over wall-clock time, averaged across multiple studies with standard error bands.
5+
tags: [visualization, benchmarking, runtime]
6+
optuna_versions: [4.8.0]
7+
license: MIT License
8+
---
9+
10+
## Abstract
11+
12+
This visualization module enables users to plot the target value over time with standard error bands.
13+
This module is especially convenient when we use parallel optimization such as [Asynchronous optimization simulation](https://hub.optuna.org/benchmarks/async_opt_simulator).
14+
15+
![Example Using Async Opt](images/async-bench-example.png)
16+
17+
## Class or Function Names
18+
19+
- plot_target_over_time
20+
21+
## Installation
22+
23+
This module requires the following dependencies:
24+
25+
- matplotlib
26+
- numpy
27+
28+
## APIs
29+
30+
# `plot_target_over_time(study_list, *, ax=None, states=None, target=None, target_direction=None, cumtime_func=None, log_time_scale=True, n_steps=100, color=None, **plot_kwargs)`
31+
32+
- `study_list`: A list of `optuna.Study` objects. Each study is treated as one run, and results are averaged across them.
33+
- `ax`: A `matplotlib.axes.Axes` object. If not provided, a new figure and axes will be created.
34+
- `states`: A list of `optuna.trial.TrialState` to include. Defaults to `[TrialState.COMPLETE, TrialState.PRUNED]`.
35+
- `target`: A callable that takes a `FrozenTrial` and returns a float value. If not provided, `trial.value` is used.
36+
- `target_direction`: The direction to optimize the target. Required when `target` is specified. Must be `"minimize"`, `"maximize"`, or the corresponding `StudyDirection` enum.
37+
- `cumtime_func`: A callable that takes a `FrozenTrial` and returns the cumulative time as a float. If not provided, the elapsed time from the first trial start is used.
38+
- `log_time_scale`: Whether to use a logarithmic time scale for interpolation. Defaults to `True`.
39+
- `n_steps`: The number of time steps for interpolation. Defaults to `100`.
40+
- `color`: The color for the plot line and shaded region.
41+
- `**plot_kwargs`: Additional keyword arguments passed to `ax.plot` (e.g., `label`, `linestyle`).
42+
43+
## Example
44+
45+
```python
46+
from __future__ import annotations
47+
48+
import optuna
49+
import optunahub
50+
51+
import matplotlib.pyplot as plt
52+
53+
54+
def objective(trial: optuna.Trial) -> float:
55+
x = trial.suggest_float("x", -5, 5)
56+
y = trial.suggest_float("y", -5, 5)
57+
return x**2 + y**2
58+
59+
60+
plot_target_over_time = optunahub.load_module("visualization/plot_target_over_time").plot_target_over_time
61+
_, ax = plt.subplots()
62+
colors = ["darkred", "black"]
63+
for sampler, color in zip([optuna.samplers.TPESampler(), optuna.samplers.RandomSampler()], colors):
64+
study_list = []
65+
for _ in range(5):
66+
study = optuna.create_study(sampler=sampler)
67+
study.optimize(objective, n_trials=20)
68+
study_list.append(study)
69+
plot_target_over_time(
70+
study_list,
71+
ax=ax,
72+
color=color,
73+
label=sampler.__class__.__name__,
74+
)
75+
76+
ax.legend()
77+
plt.show()
78+
79+
```
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
import optuna
8+
from optuna.study import StudyDirection
9+
from optuna.trial import TrialState
10+
11+
12+
if TYPE_CHECKING:
13+
from collections.abc import Callable
14+
from typing import Any
15+
16+
from matplotlib.axes import Axes
17+
from matplotlib.lines import Line2D
18+
19+
20+
def _get_values_on_fixed_time_steps(
21+
cumtime_list: list[np.ndarray],
22+
target_list: list[np.ndarray],
23+
log_time_scale: bool,
24+
n_steps: int,
25+
) -> tuple[np.ndarray, np.ndarray]:
26+
t_min = np.min(np.stack(cumtime_list))
27+
t_max = np.max(np.stack(cumtime_list))
28+
if log_time_scale:
29+
ts = np.exp(np.linspace(np.log(t_min), np.log(t_max), num=n_steps))
30+
else:
31+
ts = np.linspace(t_min, t_max, num=n_steps)
32+
v_on_grid = []
33+
for ct, v in zip(cumtime_list, target_list):
34+
i_upper = np.minimum(np.searchsorted(ct, ts, side="left"), v.size - 1)
35+
v_on_grid.append(v[i_upper])
36+
return ts, np.array(v_on_grid)
37+
38+
39+
def _validate(
40+
valid_states: tuple[TrialState, ...],
41+
states: tuple[TrialState, ...],
42+
target: Callable[[optuna.trial.FrozenTrial], float] | None,
43+
target_direction: StudyDirection | str | None,
44+
) -> None:
45+
if any(s not in valid_states for s in states):
46+
raise ValueError(f"{states=} must be in {valid_states}.")
47+
if target_direction is None:
48+
if target is not None:
49+
raise ValueError("target was specified, but got target_direction=None.")
50+
else:
51+
if target is None:
52+
raise ValueError("target_direction was provided, but got target=None.")
53+
if target_direction not in [
54+
"minimize",
55+
"maximize",
56+
StudyDirection.MAXIMIZE,
57+
StudyDirection.MINIMIZE,
58+
]:
59+
raise ValueError(
60+
f"target_direction must be either `minimize` or `maximize` but got {target_direction=}"
61+
)
62+
63+
64+
def plot_target_over_time(
65+
study_list: list[optuna.Study],
66+
*,
67+
color: str,
68+
ax: Axes | None = None,
69+
states: tuple[TrialState, ...] | None = None,
70+
target: Callable[[optuna.trial.FrozenTrial], float] | None = None,
71+
target_direction: optuna.study.StudyDirection | str | None = None,
72+
cumtime_func: Callable[[optuna.trial.FrozenTrial], float] | None = None,
73+
log_time_scale: bool = True,
74+
n_steps: int = 100,
75+
**plot_kwargs: Any,
76+
) -> Line2D:
77+
if ax is None:
78+
_, ax = plt.subplots()
79+
80+
valid_states = (TrialState.COMPLETE, TrialState.PRUNED)
81+
states = states or valid_states
82+
assert states is not None, "Mypy Redefinition."
83+
_validate(valid_states, states, target, target_direction)
84+
85+
target_list = []
86+
cumtime_list = []
87+
direction = target_direction or study_list[0].direction
88+
for study in study_list:
89+
trials = study.get_trials(deepcopy=False, states=states)
90+
target_vals = np.array([target(t) if target is not None else t.value for t in trials])
91+
if cumtime_func is not None:
92+
cumtime_list.append(np.array([cumtime_func(t) for t in trials]))
93+
else:
94+
datetime_start = min(t.datetime_start for t in trials if t.datetime_start is not None)
95+
cumtimes = np.array(
96+
[
97+
(t.datetime_complete - datetime_start).total_seconds()
98+
for t in trials
99+
if t.datetime_complete is not None
100+
]
101+
)
102+
cumtime_list.append(cumtimes)
103+
order = np.argsort(cumtime_list[-1])
104+
cumtime_list[-1] = cumtime_list[-1][order]
105+
if direction in ["minimize", StudyDirection.MINIMIZE]:
106+
target_list.append(np.minimum.accumulate(target_vals[order]))
107+
else:
108+
target_list.append(np.maximum.accumulate(target_vals[order]))
109+
110+
ts, vs = _get_values_on_fixed_time_steps(
111+
cumtime_list,
112+
target_list,
113+
log_time_scale,
114+
n_steps,
115+
)
116+
m = np.mean(vs, axis=0)
117+
s = np.std(vs, axis=0) / np.sqrt(len(study_list))
118+
(line,) = ax.plot(ts, m, color=color, **plot_kwargs)
119+
ax.fill_between(ts, m - s, m + s, color=color, alpha=0.2)
120+
return line
121+
122+
123+
__all__ = ["plot_target_over_time"]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations
2+
3+
import matplotlib.pyplot as plt
4+
import optuna
5+
import optunahub
6+
7+
8+
def objective(trial: optuna.Trial) -> float:
9+
x = trial.suggest_float("x", -5, 5)
10+
y = trial.suggest_float("y", -5, 5)
11+
return x**2 + y**2
12+
13+
14+
plot_target_over_time = optunahub.load_module(
15+
"visualization/plot_target_over_time"
16+
).plot_target_over_time
17+
_, ax = plt.subplots()
18+
colors = ["darkred", "black"]
19+
for sampler, color in zip([optuna.samplers.TPESampler(), optuna.samplers.RandomSampler()], colors):
20+
study_list = []
21+
for __ in range(5):
22+
study = optuna.create_study(sampler=sampler)
23+
study.optimize(objective, n_trials=20)
24+
study_list.append(study)
25+
plot_target_over_time(
26+
study_list,
27+
ax=ax,
28+
color=color,
29+
label=sampler.__class__.__name__,
30+
)
31+
32+
ax.legend()
33+
plt.show()
92.1 KB
Loading

0 commit comments

Comments
 (0)