-
Notifications
You must be signed in to change notification settings - Fork 51
Add BisectSampler
#271
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add BisectSampler
#271
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| MIT License | ||
|
|
||
| Copyright (c) 2025 Shuhei Watanabe | ||
|
|
||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| of this software and associated documentation files (the "Software"), to deal | ||
| in the Software without restriction, including without limitation the rights | ||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| copies of the Software, and to permit persons to whom the Software is | ||
| furnished to do so, subject to the following conditions: | ||
|
|
||
| The above copyright notice and this permission notice shall be included in all | ||
| copies or substantial portions of the Software. | ||
|
|
||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
| SOFTWARE. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| --- | ||
| author: Shuhei Watanabe | ||
| title: A Sampler Using Parameter-Wise Bisection, aka Binary, Search | ||
| description: This sampler allows users to apply binary search to each parameter. | ||
| tags: [sampler, binary search, bisection search] | ||
| optuna_versions: [4.3.0] | ||
| license: MIT License | ||
| --- | ||
|
|
||
| ## Abstract | ||
|
|
||
| This sampler allows users to apply binary search to each parameter. | ||
| Please see the example for the usage. | ||
| Note that this sampler is not supposed to be used in the distributed optimization setup. | ||
|
|
||
| ## APIs | ||
|
|
||
| - `BisectSampler(*, rtol: float = 1e-5, atol: float = 1e-8)` | ||
| - `rtol`: The relative tolerance parameter to be used to judge whether all the parameters are converged. Default to that in `np.isclose`, i.e., `1e-5`. | ||
| - `atol`: The absolute tolerance parameter to be used to judge whether all the parameters are converged. Default to that in `np.isclose`, i.e., `1e-8`. | ||
|
|
||
| By calling `BisectSampler.get_best_param(study)`, we can obtain the best parameter of the study. | ||
|
|
||
| ## Installation | ||
|
|
||
| This sampler does not have any dependencies on top of `optunahub`. | ||
|
|
||
| ## Example | ||
|
|
||
| For each parameter, please set `XXX_is_too_high` to the trial `user_attrs`. | ||
| Please see below for a concrete example. | ||
|
|
||
| ```python | ||
| from collections.abc import Callable | ||
|
|
||
| import optuna | ||
| import optunahub | ||
|
|
||
|
|
||
| BisectSampler = optunahub.load_module("samplers/bisect").BisectSampler | ||
|
|
||
|
|
||
| def objective(trial: optuna.Trial, score_func: Callable[[optuna.Trial], float]) -> float: | ||
| x = trial.suggest_float("x", -1, 1) | ||
| # For each param, e.g., `ZZZ`, please set `ZZZ_is_too_high`. | ||
| trial.set_user_attr("x_is_too_high", x > 0.5) | ||
| y = trial.suggest_float("y", -1, 1, step=0.2) | ||
| trial.set_user_attr("y_is_too_high", y > 0.2) | ||
| # Please use `BisectSampler.score_func`. | ||
| return BisectSampler.score_func(trial) | ||
|
|
||
|
|
||
| sampler = BisectSampler() | ||
| study = optuna.create_study(sampler=sampler) | ||
| study.optimize(objective, n_trials=20) | ||
| ``` | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from ._sampler import BisectSampler | ||
|
|
||
|
|
||
| __all__ = ["BisectSampler"] |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,199 @@ | ||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| import math | ||||||||||||||||||||||||
| from typing import TYPE_CHECKING | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||
| import optuna | ||||||||||||||||||||||||
| from optuna.trial import TrialState | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||||||||||
| from collections.abc import Sequence | ||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| from optuna import Study | ||||||||||||||||||||||||
| from optuna import Trial | ||||||||||||||||||||||||
| from optuna.distributions import BaseDistribution | ||||||||||||||||||||||||
| from optuna.distributions import FloatDistribution | ||||||||||||||||||||||||
| from optuna.distributions import IntDistribution | ||||||||||||||||||||||||
| from optuna.trial import FrozenTrial | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| PREFIX_LEFT = "bisect:left_" | ||||||||||||||||||||||||
| PREFIX_RIGHT = "bisect:right_" | ||||||||||||||||||||||||
|
Comment on lines
+23
to
+24
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class BisectSampler(optuna.samplers.BaseSampler): | ||||||||||||||||||||||||
| """Sampler using bisect (binary search) indepedently for each parameter. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||
| rtol: | ||||||||||||||||||||||||
| The relative tolerance parameter to be used to judge whether all the parameters are | ||||||||||||||||||||||||
| converged. Default to that in `np.isclose`, i.e., 1e-5. | ||||||||||||||||||||||||
| atol: | ||||||||||||||||||||||||
| The absolute tolerance parameter to be used to judge whether all the parameters are | ||||||||||||||||||||||||
| converged. Default to that in `np.isclose`, i.e., 1e-8. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def __init__(self, *, rtol: float = 1e-5, atol: float = 1e-8) -> None: | ||||||||||||||||||||||||
| self._atol = atol | ||||||||||||||||||||||||
| self._rtol = rtol | ||||||||||||||||||||||||
| self._search_space: dict[str, IntDistribution | FloatDistribution] = {} | ||||||||||||||||||||||||
| self._stop_flag = False | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def infer_relative_search_space( | ||||||||||||||||||||||||
| self, study: optuna.Study, trial: optuna.trial.FrozenTrial | ||||||||||||||||||||||||
| ) -> dict[str, optuna.distributions.BaseDistribution]: | ||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The addition of
Suggested change
|
||||||||||||||||||||||||
| return {} | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def sample_relative( | ||||||||||||||||||||||||
| self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution] | ||||||||||||||||||||||||
| ) -> dict[str, Any]: | ||||||||||||||||||||||||
| return {} | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def _get_current_left_and_right( | ||||||||||||||||||||||||
| self, study: Study, param_name: str | ||||||||||||||||||||||||
| ) -> tuple[int | float, int | float]: | ||||||||||||||||||||||||
| left_key = f"{PREFIX_LEFT}{param_name}" | ||||||||||||||||||||||||
| right_key = f"{PREFIX_RIGHT}{param_name}" | ||||||||||||||||||||||||
| system_attrs = study._storage.get_study_system_attrs(study._study_id) | ||||||||||||||||||||||||
| left = system_attrs[left_key] | ||||||||||||||||||||||||
| right = system_attrs[right_key] | ||||||||||||||||||||||||
| assert isinstance(left, (int, float)) and isinstance(right, (int, float)) | ||||||||||||||||||||||||
| return left, right | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def _set_left_and_right( | ||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||
| study: Study, | ||||||||||||||||||||||||
| param_name: str, | ||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||
| left: int | float | None = None, | ||||||||||||||||||||||||
| right: int | float | None = None, | ||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||
| left_key = f"{PREFIX_LEFT}{param_name}" | ||||||||||||||||||||||||
| right_key = f"{PREFIX_RIGHT}{param_name}" | ||||||||||||||||||||||||
| if left is not None: | ||||||||||||||||||||||||
| study._storage.set_study_system_attr(study._study_id, left_key, left) | ||||||||||||||||||||||||
| if right is not None: | ||||||||||||||||||||||||
| study._storage.set_study_system_attr(study._study_id, right_key, right) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def sample_independent( | ||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||
| study: Study, | ||||||||||||||||||||||||
| trial: FrozenTrial, | ||||||||||||||||||||||||
| param_name: str, | ||||||||||||||||||||||||
| param_distribution: BaseDistribution, | ||||||||||||||||||||||||
| ) -> Any: | ||||||||||||||||||||||||
| if isinstance(param_distribution, optuna.distributions.CategoricalDistribution): | ||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. Note that we have to add
Suggested change
|
||||||||||||||||||||||||
| raise ValueError("CategoricalDistribution is not supported.") | ||||||||||||||||||||||||
| if param_distribution.log: | ||||||||||||||||||||||||
| raise ValueError("Log scale is not supported.") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| self._search_space[param_name] = param_distribution | ||||||||||||||||||||||||
| step = param_distribution.step | ||||||||||||||||||||||||
| is_discrete = step is not None | ||||||||||||||||||||||||
| if trial.number == 0: | ||||||||||||||||||||||||
| right = param_distribution.high + step if is_discrete else param_distribution.high | ||||||||||||||||||||||||
| self._set_left_and_right(study, param_name, left=param_distribution.low, right=right) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| left, right = self._get_current_left_and_right(study, param_name) | ||||||||||||||||||||||||
| if not is_discrete: | ||||||||||||||||||||||||
| mid = (left + right) / 2.0 | ||||||||||||||||||||||||
| return mid | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| possible_param_values = self._get_possible_param_values(param_distribution) | ||||||||||||||||||||||||
| indices = np.arange(len(possible_param_values)) | ||||||||||||||||||||||||
| left_index = indices[np.isclose(possible_param_values, left)][0] | ||||||||||||||||||||||||
| right_index = indices[np.isclose(possible_param_values, right)][0] | ||||||||||||||||||||||||
| mid_index = (right_index + left_index) // 2 | ||||||||||||||||||||||||
| assert mid_index != len(possible_param_values) - 1, "The last element is for convenience." | ||||||||||||||||||||||||
| return possible_param_values[mid_index].item() | ||||||||||||||||||||||||
|
Comment on lines
+105
to
+111
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def _get_possible_param_values( | ||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don’t think we need to list all possible parameter values for discrete distributions. We can avoid this by using the index of the discrete search space, as I’ll suggest in the alternative code below: |
||||||||||||||||||||||||
| self, param_distribution: FloatDistribution | IntDistribution | ||||||||||||||||||||||||
| ) -> np.ndarray: | ||||||||||||||||||||||||
| step = param_distribution.step | ||||||||||||||||||||||||
| low = param_distribution.low | ||||||||||||||||||||||||
| # The last element is padded to code the binary search routine cleaner. | ||||||||||||||||||||||||
| high = param_distribution.high + step | ||||||||||||||||||||||||
| assert step is not None | ||||||||||||||||||||||||
|
Comment on lines
+117
to
+120
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about moving this assertion here, since step is already used in the addition operation, which doesn’t allow
Suggested change
|
||||||||||||||||||||||||
| n_steps = int(np.round((high - low) / step)) + 1 | ||||||||||||||||||||||||
| return np.linspace(low, high, n_steps) | ||||||||||||||||||||||||
|
Comment on lines
+121
to
+122
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, the calculation of the possible parameter values is incorrect. |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
Comment on lines
+113
to
+123
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
| def _is_param_converged(self, study: Study, param_name: str) -> bool: | ||||||||||||||||||||||||
| left, right = self._get_current_left_and_right(study, param_name) | ||||||||||||||||||||||||
| dist = self._search_space[param_name] | ||||||||||||||||||||||||
| is_discrete = dist.step is not None | ||||||||||||||||||||||||
| if not is_discrete: | ||||||||||||||||||||||||
| return math.isclose(left, right, abs_tol=self._atol, rel_tol=self._rtol) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| possible_param_values = self._get_possible_param_values(dist) | ||||||||||||||||||||||||
| indices = np.arange(len(possible_param_values)) | ||||||||||||||||||||||||
| left_index = indices[np.isclose(possible_param_values, left)][0] | ||||||||||||||||||||||||
| right_index = indices[np.isclose(possible_param_values, right)][0] | ||||||||||||||||||||||||
| return right_index - left_index <= 1 | ||||||||||||||||||||||||
|
Comment on lines
+131
to
+135
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||
| def score_func(trial: Trial | FrozenTrial) -> float: | ||||||||||||||||||||||||
| score = 0.0 | ||||||||||||||||||||||||
| for k, param_value in trial.params.items(): | ||||||||||||||||||||||||
| low = trial.distributions[k].low | ||||||||||||||||||||||||
| high = trial.distributions[k].high | ||||||||||||||||||||||||
| is_too_high = trial.user_attrs[f"{k}_is_too_high"] | ||||||||||||||||||||||||
| score += (2 * is_too_high - 1) * (param_value - low) / (high - low) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| return score | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||
| def get_best_param(study: Study) -> dict[str, Any]: | ||||||||||||||||||||||||
| best_param: dict[str, Any] = {} | ||||||||||||||||||||||||
| for t in study.trials: | ||||||||||||||||||||||||
| params = t.params | ||||||||||||||||||||||||
| user_attrs = t.user_attrs | ||||||||||||||||||||||||
| for k, v in params.items(): | ||||||||||||||||||||||||
| if user_attrs[f"{k}_is_too_high"]: | ||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| best_param[k] = max(v, best_param[k]) if k in best_param else v | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| return best_param | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def _enqueue_best_param(self, study: Study) -> None: | ||||||||||||||||||||||||
| study.enqueue_trial(self.get_best_param(study)) | ||||||||||||||||||||||||
| self._stop_flag = True | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def after_trial( | ||||||||||||||||||||||||
| self, study: Study, trial: FrozenTrial, state: TrialState, values: Sequence[float] | None | ||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||
| if values is None or state != TrialState.COMPLETE: | ||||||||||||||||||||||||
| return | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| converged = True | ||||||||||||||||||||||||
| for param_name, param_value in trial.params.items(): | ||||||||||||||||||||||||
| is_too_high_key = f"{param_name}_is_too_high" | ||||||||||||||||||||||||
| too_high = trial.user_attrs.get(is_too_high_key) | ||||||||||||||||||||||||
| if too_high is None or not isinstance(too_high, bool): | ||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||
| f"BisectSampler requires an attribute to judge whether each param is too high." | ||||||||||||||||||||||||
| f' Set it via `trial.set_user_attr("{is_too_high_key}", <True or False>)`.' | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if too_high: # param is too high. | ||||||||||||||||||||||||
| self._set_left_and_right(study, param_name, right=param_value) | ||||||||||||||||||||||||
| else: # param is too low. | ||||||||||||||||||||||||
| self._set_left_and_right(study, param_name, left=param_value) | ||||||||||||||||||||||||
| converged &= self._is_param_converged(study, param_name) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if self._stop_flag: | ||||||||||||||||||||||||
| study.stop() | ||||||||||||||||||||||||
| if converged and not self._stop_flag: | ||||||||||||||||||||||||
| self._enqueue_best_param(study) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if not math.isclose(self.score_func(trial), values[0]): | ||||||||||||||||||||||||
| expected_value = self.score_func(trial) | ||||||||||||||||||||||||
| got_value = values[0] | ||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||
| "Please return `BisectSampler.score_func(trial)` in your objective. " | ||||||||||||||||||||||||
| f"Expected {expected_value}, but got {got_value}" | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| from collections.abc import Callable | ||
|
|
||
| import optuna | ||
| import optunahub | ||
|
|
||
|
|
||
| BisectSampler = optunahub.load_module("samplers/bisect").BisectSampler | ||
|
|
||
|
|
||
| def objective(trial: optuna.Trial, score_func: Callable[[optuna.Trial], float]) -> float: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does it include the |
||
| x = trial.suggest_float("x", -1, 1) | ||
| # For each param, e.g., `ZZZ`, please set `ZZZ_is_too_high`. | ||
| trial.set_user_attr("x_is_too_high", x > 0.5) | ||
| y = trial.suggest_float("y", -1, 1, step=0.2) | ||
| trial.set_user_attr("y_is_too_high", y > 0.2) | ||
| # Please use `BisectSampler.score_func`. | ||
| return BisectSampler.score_func(trial) | ||
|
|
||
|
|
||
| sampler = BisectSampler() | ||
| study = optuna.create_study(sampler=sampler) | ||
| study.optimize(lambda t: objective(t, BisectSampler.score_func), n_trials=20) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.