Skip to content

Commit 61d47cc

Browse files
authored
Merge pull request #75 from y0z/feature/benchmarks
Introduce `ConstrainedMixin`
2 parents e0bfffd + 1cbaf7f commit 61d47cc

File tree

4 files changed

+110
-12
lines changed

4 files changed

+110
-12
lines changed

docs/source/benchmarks.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ optunahub.benchmarks
88
:nosignatures:
99
:template: custom_summary.rst
1010

11-
optunahub.benchmarks.BaseProblem
11+
optunahub.benchmarks.BaseProblem
12+
optunahub.benchmarks.ConstrainedMixin

optunahub/benchmarks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from ._base_problem import BaseProblem
2+
from ._constrained_mixin import ConstrainedMixin
23

34

45
__all__ = [
56
"BaseProblem",
7+
"ConstrainedMixin",
68
]
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Sequence
4+
from typing import Any
5+
6+
import optuna
7+
8+
9+
class ConstrainedMixin:
10+
"""Mixin class for constrained optimization problems.
11+
12+
Example:
13+
You can define a constrained optimization problem by inheriting this class and implementing
14+
the :meth:`evaluate_constraints` method as follows.
15+
16+
::
17+
18+
import optuna
19+
import optunahub
20+
21+
class BinAndKorn(optunahub.benchmarks.ConstrainedMixin, optunahub.benchmarks.BaseProblem):
22+
def evaluate(self, params: dict[str, float]) -> tuple[float]:
23+
x = params["x"]
24+
y = params["y"]
25+
26+
v0 = 4 * x**2 + 4 * y**2
27+
v1 = (x - 5)**2 + (y - 5)**2
28+
29+
return v0, v1
30+
31+
def evaluate_constraints(self, params: dict[str, float]) -> tuple[float]:
32+
x = params["x"]
33+
y = params["y"]
34+
35+
# Constraints which are considered feasible if less than or equal to zero.
36+
# The feasible region is basically the intersection of a circle centered at (x=5, y=0)
37+
# and the complement to a circle centered at (x=8, y=-3).
38+
c0 = (x - 5)**2 + y**2 - 25
39+
c1 = -((x - 8)**2) - (y + 3)**2 + 7.7
40+
41+
return c0, c1
42+
43+
@property
44+
def search_space(self) -> dict[str, optuna.distributions.BaseDistribution]:
45+
return {
46+
"x": optuna.distributions.FloatDistribution(low=-15, high=30),
47+
"y": optuna.distributions.FloatDistribution(low=-15, high=30)
48+
}
49+
50+
@property
51+
def directions(self) -> list[optuna.study.StudyDirection]:
52+
return [optuna.study.StudyDirection.MINIMIZE, optuna.study.StudyDirection.MINIMIZE]
53+
54+
problem = BinAndKorn()
55+
sampler = optuna.samplers.TPESampler(constraints_func=problem.constraints_func)
56+
study = optuna.create_study(sampler=sampler, directions=problem.directions)
57+
study.optimize(problem, n_trials=20)
58+
"""
59+
60+
def constraints_func(self, trial: optuna.trial.FrozenTrial) -> Sequence[float]:
61+
"""Evaluate the constraint functions.
62+
63+
Args:
64+
trial: Optuna trial object.
65+
Returns:
66+
List of the constraint values.
67+
"""
68+
return self.evaluate_constraints(trial.params.copy())
69+
70+
def evaluate_constraints(self, params: dict[str, Any]) -> Sequence[float]:
71+
"""Evaluate the constraint functions.
72+
73+
Args:
74+
params: Dictionary of input parameters.
75+
Returns:
76+
List of the constraint values.
77+
"""
78+
raise NotImplementedError

tests/test_benchmarks.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,41 @@
11
from __future__ import annotations
22

33
import optuna
4+
from optuna.samplers._base import _CONSTRAINTS_KEY
45

56
import optunahub
67

78

8-
def test_base_problem() -> None:
9-
class TestProblem(optunahub.benchmarks.BaseProblem):
10-
def evaluate(self, params: dict[str, float]) -> float:
11-
x = params["x"]
12-
return x**2
9+
class TestProblem(optunahub.benchmarks.BaseProblem):
10+
def evaluate(self, params: dict[str, float]) -> float:
11+
x = params["x"]
12+
return x**2
13+
14+
@property
15+
def search_space(self) -> dict[str, optuna.distributions.BaseDistribution]:
16+
return {"x": optuna.distributions.FloatDistribution(low=-1, high=1)}
1317

14-
@property
15-
def search_space(self) -> dict[str, optuna.distributions.BaseDistribution]:
16-
return {"x": optuna.distributions.FloatDistribution(low=-1, high=1)}
18+
@property
19+
def directions(self) -> list[optuna.study.StudyDirection]:
20+
return [optuna.study.StudyDirection.MINIMIZE]
1721

18-
@property
19-
def directions(self) -> list[optuna.study.StudyDirection]:
20-
return [optuna.study.StudyDirection.MINIMIZE]
2122

23+
def test_base_problem() -> None:
2224
problem = TestProblem()
2325
study = optuna.create_study(directions=problem.directions)
2426
study.optimize(problem, n_trials=20) # verify no error occurs
27+
28+
29+
def test_constrained_mixin() -> None:
30+
class ConstrainedTestProblem(optunahub.benchmarks.ConstrainedMixin, TestProblem):
31+
def evaluate_constraints(self, params: dict[str, float]) -> list[float]:
32+
return [params["x"]]
33+
34+
problem = ConstrainedTestProblem()
35+
sampler = optuna.samplers.TPESampler(constraints_func=problem.constraints_func)
36+
study = optuna.create_study(sampler=sampler, directions=problem.directions)
37+
study.optimize(problem, n_trials=20) # verify no error occurs
38+
39+
# Check if constraints are stored in trials
40+
for t in study.trials:
41+
assert _CONSTRAINTS_KEY in study._storage.get_trial_system_attrs(t._trial_id)

0 commit comments

Comments
 (0)