Skip to content

Commit 66a0045

Browse files
authored
Merge pull request #73 from y0z/feature/benchmarks
Introduce the `optunahub.benchmarks` module
2 parents a3f85ea + acec775 commit 66a0045

File tree

9 files changed

+135
-2
lines changed

9 files changed

+135
-2
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{{ fullname | escape | underline }}
2+
3+
.. autoclass:: {{ fullname }}
4+
:members:
5+
:special-members: __init__, __call__
6+
:inherited-members:
7+
:undoc-members:

docs/source/benchmarks.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
.. module:: optunahub.benchmarks
2+
3+
optunahub.benchmarks
4+
====================
5+
6+
.. autosummary::
7+
:toctree: generated/
8+
:nosignatures:
9+
:template: custom_summary.rst
10+
11+
optunahub.benchmarks.BaseProblem

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"sphinx.ext.autodoc",
2020
"sphinx.ext.autosummary",
2121
"sphinx.ext.napoleon",
22+
"sphinx.ext.viewcode",
2223
]
2324

2425
templates_path = ["_templates"]

docs/source/reference.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ Reference
55
:maxdepth: 1
66

77
optunahub
8-
samplers
8+
samplers
9+
benchmarks

docs/source/samplers.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ optunahub.samplers
66
.. autosummary::
77
:toctree: generated/
88
:nosignatures:
9+
:template: custom_summary.rst
910

1011
optunahub.samplers.SimpleBaseSampler

optunahub/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
from optunahub import benchmarks
12
from optunahub import samplers
23
from optunahub.hub import load_local_module
34
from optunahub.hub import load_module
45
from optunahub.version import __version__
56

67

7-
__all__ = ["load_module", "load_local_module", "__version__", "samplers"]
8+
__all__ = ["__version__", "benchmarks", "load_local_module", "load_module", "samplers"]

optunahub/benchmarks/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from ._base_problem import BaseProblem
2+
3+
4+
__all__ = [
5+
"BaseProblem",
6+
]
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import annotations
2+
3+
from abc import ABCMeta
4+
from abc import abstractmethod
5+
from typing import Any
6+
from typing import Sequence
7+
8+
import optuna
9+
10+
11+
class BaseProblem(metaclass=ABCMeta):
12+
"""Base class for optimization problems."""
13+
14+
def __call__(self, trial: optuna.Trial) -> float | Sequence[float]:
15+
"""Objective function for Optuna. By default, this method calls :meth:`evaluate` with the parameters defined in :attr:`search_space`.
16+
17+
Args:
18+
trial: Optuna trial object.
19+
Returns:
20+
The objective value or a sequence of the objective values for multi-objective optimization.
21+
"""
22+
params = {}
23+
for name, dist in self.search_space.items():
24+
params[name] = trial._suggest(name, dist)
25+
trial._check_distribution(name, dist)
26+
return self.evaluate(params)
27+
28+
def evaluate(self, params: dict[str, Any]) -> float | Sequence[float]:
29+
"""Evaluate the objective function.
30+
31+
Args:
32+
params: Dictionary of input parameters.
33+
34+
Returns:
35+
The objective value or a sequence of the objective values for multi-objective optimization.
36+
37+
Example:
38+
::
39+
40+
def evaluate(self, params: dict[str, Any]) -> float:
41+
x = params["x"]
42+
y = params["y"]
43+
return x ** 2 + y
44+
"""
45+
raise NotImplementedError
46+
47+
@property
48+
def search_space(self) -> dict[str, optuna.distributions.BaseDistribution]:
49+
"""Return the search space.
50+
51+
Returns:
52+
Dictionary of search space. Each dictionary element consists of the parameter name and distribution (see `optuna.distributions <https://optuna.readthedocs.io/en/stable/reference/distributions.html>`__).
53+
54+
Example:
55+
::
56+
57+
@property
58+
def search_space(self) -> dict[str, optuna.distributions.BaseDistribution]:
59+
return {
60+
"x": optuna.distributions.FloatDistribution(low=0, high=1),
61+
"y": optuna.distributions.CategoricalDistribution(choices=[0, 1, 2]),
62+
}
63+
"""
64+
raise NotImplementedError
65+
66+
@property
67+
@abstractmethod
68+
def directions(self) -> list[optuna.study.StudyDirection]:
69+
"""Return the optimization directions.
70+
71+
Returns:
72+
List of `optuna.study.direction <https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.StudyDirection.html>`__.
73+
74+
Example:
75+
::
76+
77+
@property
78+
def directions(self) -> list[optuna.study.StudyDirection]:
79+
return [optuna.study.StudyDirection.MINIMIZE]
80+
"""
81+
...

tests/test_benchmarks.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
import optuna
4+
5+
import optunahub
6+
7+
8+
def test_base_problem() -> None:
9+
class TestProblem(optunahub.benchmarks.BaseProblem):
10+
def evaluate(self, params: dict[str, float]) -> float:
11+
x = params["x"]
12+
return x**2
13+
14+
@property
15+
def search_space(self) -> dict[str, optuna.distributions.BaseDistribution]:
16+
return {"x": optuna.distributions.FloatDistribution(low=-1, high=1)}
17+
18+
@property
19+
def directions(self) -> list[optuna.study.StudyDirection]:
20+
return [optuna.study.StudyDirection.MINIMIZE]
21+
22+
problem = TestProblem()
23+
study = optuna.create_study(directions=problem.directions)
24+
study.optimize(problem, n_trials=20) # verify no error occurs

0 commit comments

Comments
 (0)