Skip to content

Commit 9e353ea

Browse files
authored
Merge pull request #155 from ryota717/113-add-bandit-sampler
Add multi-armed bandit sampler
2 parents 83871a9 + 0bc6cb7 commit 9e353ea

File tree

5 files changed

+140
-0
lines changed

5 files changed

+140
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2024 <Ryota Nishijima>
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
---
2+
author: Ryota Nishijima
3+
title: MAB Epsilon-Greedy Sampler
4+
description: Sampler based on multi-armed bandit algorithm with epsilon-greedy arm selection.
5+
tags: [sampler, multi-armed bandit]
6+
optuna_versions: [4.0.0]
7+
license: MIT License
8+
---
9+
10+
## Class or Function Names
11+
12+
- MABEpsilonGreedySampler
13+
14+
## Example
15+
16+
```python
17+
mod = optunahub.load_module("samplers/mab_epsilon_greedy")
18+
sampler = mod.MABEpsilonGreedySampler()
19+
```
20+
21+
See [`example.py`](https://github.com/optuna/optunahub-registry/blob/main/package/samplers/mab_epsilon_greedy/example.py) for more details.
22+
23+
## Others
24+
25+
This package provides a sampler based on the multi-armed bandit algorithm with epsilon-greedy arm selection.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .mab_epsilon_greedy import MABEpsilonGreedySampler
2+
3+
4+
__all__ = ["MABEpsilonGreedySampler"]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import optuna
import optunahub


def main() -> None:
    """Run a small demo study using the epsilon-greedy MAB sampler."""
    # Fetch the sampler implementation from the optunahub registry.
    sampler_module = optunahub.load_module(
        package="samplers/mab_epsilon_greedy",
    )
    sampler = sampler_module.MABEpsilonGreedySampler()

    def objective(trial: optuna.Trial) -> float:
        # Two independent categorical parameters; each choice is one "arm".
        first = trial.suggest_categorical("arm_1", [1, 2, 3])
        second = trial.suggest_categorical("arm_2", [1, 2])
        return first + second

    study = optuna.create_study(sampler=sampler)
    study.optimize(objective, n_trials=20)
    print(study.best_trial.value, study.best_trial.params)


if __name__ == "__main__":
    main()
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from collections import defaultdict
2+
from typing import Any
3+
from typing import Optional
4+
5+
from optuna.distributions import BaseDistribution
6+
from optuna.samplers import RandomSampler
7+
from optuna.study import Study
8+
from optuna.study._study_direction import StudyDirection
9+
from optuna.trial import FrozenTrial
10+
from optuna.trial import TrialState
11+
12+
13+
class MABEpsilonGreedySampler(RandomSampler):
    """Sampler based on the multi-armed bandit (MAB) algorithm.

    Every categorical choice is treated as an arm. With probability
    ``epsilon`` an arm is selected uniformly at random (exploration);
    otherwise the arm with the best mean observed objective value is
    selected (exploitation).

    Args:
        epsilon (float):
            Parameter for the epsilon-greedy algorithm.
            ``epsilon`` is the probability of selecting an arm randomly.
        seed (int | None):
            Seed for random number generator and arm selection.

    """

    def __init__(
        self,
        epsilon: float = 0.7,
        seed: Optional[int] = None,
    ) -> None:
        super().__init__(seed)
        self._epsilon = epsilon

    def sample_independent(
        self,
        study: Study,
        trial: FrozenTrial,
        param_name: str,
        param_distribution: BaseDistribution,
    ) -> Any:
        states = (TrialState.COMPLETE, TrialState.PRUNED)
        trials = study._get_trials(deepcopy=False, states=states, use_cache=True)

        # Accumulate the total reward and the pull count per arm.
        rewards_by_choice: defaultdict = defaultdict(float)
        cnt_by_choice: defaultdict = defaultdict(int)
        for t in trials:
            # Skip trials that did not sample this parameter (e.g. a
            # conditional search space) and trials without an objective
            # value (pruned trials may report ``None``); previously such
            # trials raised KeyError/TypeError here.
            if param_name not in t.params or t.value is None:
                continue
            rewards_by_choice[t.params[param_name]] += t.value
            cnt_by_choice[t.params[param_name]] += 1

        # Use never selected arm for initialization like UCB1 algorithm.
        # ref. https://github.com/optuna/optunahub-registry/pull/155#discussion_r1780446062
        never_selected = [
            arm for arm in param_distribution.choices if arm not in rewards_by_choice
        ]
        if never_selected:
            return self._rng.rng.choice(never_selected)

        # If all arms are selected at least once, select arm by epsilon-greedy.
        if self._rng.rng.rand() < self._epsilon:
            # Explore: pick any arm uniformly at random.
            return self._rng.rng.choice(param_distribution.choices)

        def _mean_reward(arm: Any) -> float:
            # max(..., 1) is a defensive guard against division by zero.
            return rewards_by_choice[arm] / max(cnt_by_choice[arm], 1)

        # Exploit: pick the arm with the best mean reward so far, honoring
        # the study's optimization direction.
        if study.direction == StudyDirection.MINIMIZE:
            return min(param_distribution.choices, key=_mean_reward)
        else:
            return max(param_distribution.choices, key=_mean_reward)

0 commit comments

Comments
 (0)