Add BisectSampler #271

base: main

Conversation
@kAIto47802 Could you review this PR?
```diff
@@ -0,0 +1,56 @@
+---
+author: Shuhei Watanabe
+title: A Sampler Using Parameter-Wise Bisection, aka Binary, Search
```

Suggested change:

```diff
-title: A Sampler Using Parameter-Wise Bisection, aka Binary, Search
+title: A Sampler Using Parameter-Wise Bisection, aka Binary Search
```
```python
    def infer_relative_search_space(
        self, study: optuna.Study, trial: optuna.trial.FrozenTrial
    ) -> dict[str, optuna.distributions.BaseDistribution]:
```

The addition of `optuna.distributions` is inconsistent. We can remove it because `BaseDistribution` is already imported.

Suggested change:

```diff
-    ) -> dict[str, optuna.distributions.BaseDistribution]:
+    ) -> dict[str, BaseDistribution]:
```
```python
        param_name: str,
        param_distribution: BaseDistribution,
    ) -> Any:
        if isinstance(param_distribution, optuna.distributions.CategoricalDistribution):
```

Same as above. Note that we have to add `from optuna.distributions import CategoricalDistribution` at the beginning of this file.

Suggested change:

```diff
-        if isinstance(param_distribution, optuna.distributions.CategoricalDistribution):
+        if isinstance(param_distribution, CategoricalDistribution):
```
How about moving this assertion up, since `step` is already used in the addition operation, which doesn't allow `None`?

Suggested change:

```diff
-        low = param_distribution.low
-        # The last element is padded to code the binary search routine cleaner.
-        high = param_distribution.high + step
-        assert step is not None
+        assert step is not None
+        low = param_distribution.low
+        # The last element is padded to code the binary search routine cleaner.
+        high = param_distribution.high + step
```
```python
        assert mid_index != len(possible_param_values) - 1, "The last element is for convenience."
        return possible_param_values[mid_index].item()

    def _get_possible_param_values(
```

I don't think we need to list all possible parameter values for discrete distributions. Doing so would make each sampling O(n_steps), which defeats the benefit of using binary search. Note that n_steps can be large, e.g., in `trial.suggest_int("x", 0, 1 << 30, step=2)`.

We can avoid this by using the index of the discrete search space, as I'll suggest in the alternative code below:

Suggested change:

```diff
-        possible_param_values = self._get_possible_param_values(dist)
-        indices = np.arange(len(possible_param_values))
-        left_index = indices[np.isclose(possible_param_values, left)][0]
-        right_index = indices[np.isclose(possible_param_values, right)][0]
-        return right_index - left_index <= 1
+        left_index = int(np.round((left - dist.low) / dist.step))
+        right_index = int(np.round((right - dist.low) / dist.step))
+        return right_index - left_index <= 1
```
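For illustration, the index arithmetic can be sketched as a standalone helper (the name `is_adjacent` and the bare `low`/`step` parameters are hypothetical stand-ins for the distribution's attributes). The point is that mapping values to grid indices is O(1), while enumerating the grid is O(n_steps):

```python
import numpy as np

def is_adjacent(low: float, step: float, left: float, right: float) -> bool:
    # Map both interval endpoints onto integer grid indices directly,
    # instead of materializing every grid point with np.linspace.
    left_index = int(np.round((left - low) / step))
    right_index = int(np.round((right - low) / step))
    # Adjacent (or identical) indices mean the search interval is exhausted.
    return right_index - left_index <= 1
```

For a grid like `suggest_int("x", 0, 1 << 30, step=2)`, this avoids allocating an array of roughly 5 × 10⁸ candidate values on every call.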
Suggested change (delete `_get_possible_param_values` entirely):

```diff
-    def _get_possible_param_values(
-        self, param_distribution: FloatDistribution | IntDistribution
-    ) -> np.ndarray:
-        step = param_distribution.step
-        low = param_distribution.low
-        # The last element is padded to code the binary search routine cleaner.
-        high = param_distribution.high + step
-        assert step is not None
-        n_steps = int(np.round((high - low) / step)) + 1
-        return np.linspace(low, high, n_steps)
```
Suggested change:

```diff
-        possible_param_values = self._get_possible_param_values(param_distribution)
-        indices = np.arange(len(possible_param_values))
-        left_index = indices[np.isclose(possible_param_values, left)][0]
-        right_index = indices[np.isclose(possible_param_values, right)][0]
-        mid_index = (right_index + left_index) // 2
-        assert mid_index != len(possible_param_values) - 1, "The last element is for convenience."
-        return possible_param_values[mid_index].item()
+        left_index = int(np.round((left - param_distribution.low) / step))
+        right_index = int(np.round((right - param_distribution.low) / step))
+        mid_index = (left_index + right_index) // 2
+        return param_distribution.low + mid_index * step
```
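The same trick applied to the midpoint step can be sketched in isolation (hypothetical function name; `low` and `step` stand in for the distribution's attributes):

```python
import numpy as np

def bisect_midpoint(low: float, step: float, left: float, right: float) -> float:
    # Bisect in index space, then map the midpoint index back to a parameter
    # value, so no array of candidate values is ever built.
    left_index = int(np.round((left - low) / step))
    right_index = int(np.round((right - low) / step))
    mid_index = (left_index + right_index) // 2
    return low + mid_index * step
```

Because `mid_index` is derived from the endpoints rather than looked up in a padded array, the assertion guarding the padded last element also becomes unnecessary.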
Also, I still do not understand the motivation for adding this sampler. The name …
```python
BisectSampler = optunahub.load_module("samplers/bisect").BisectSampler
```
```python
def objective(trial: optuna.Trial, score_func: Callable[[optuna.Trial], float]) -> float:
```

Why does it include the `score_func` argument, which will never be used and requires partial application before being passed to `study.optimize`?
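To illustrate the friction: with the extra parameter, the objective no longer matches the single-argument callable that `study.optimize` invokes, so a caller must pre-bind `score_func` first. The names and the trivial lambda below are illustrative only:

```python
import functools

def objective(trial, score_func):
    # The extra parameter means objective cannot be called as objective(trial),
    # so passing it to study.optimize directly would fail.
    return score_func(trial)

# Partial application restores the one-argument form study.optimize expects.
bound_objective = functools.partial(objective, score_func=lambda trial: 0.0)
```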
Suggested change (prefix the constants with an underscore to make them module-private):

```diff
-PREFIX_LEFT = "bisect:left_"
-PREFIX_RIGHT = "bisect:right_"
+_PREFIX_LEFT = "bisect:left_"
+_PREFIX_RIGHT = "bisect:right_"
```
```python
        n_steps = int(np.round((high - low) / step)) + 1
        return np.linspace(low, high, n_steps)
```

Also, the calculation of the possible parameter values is incorrect.
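The review does not spell out the failure case, but one plausible one (an assumption on my part, not stated above) is bounds where `(high - low)` is not an exact multiple of `step`: the rounded `n_steps` then makes `np.linspace` space its points differently from `step`, so the produced grid is not the intended `low + k * step` lattice:

```python
import numpy as np

low, high, step = 0.0, 10.0, 3.0   # (high - low) is not a multiple of step
padded_high = high + step           # same padding as in the PR's helper
n_steps = int(np.round((padded_high - low) / step)) + 1
grid = np.linspace(low, padded_high, n_steps)
# The grid the binary search assumes: low, low + step, low + 2 * step, ...
intended = low + step * np.arange(n_steps)
# linspace spacing is (padded_high - low) / (n_steps - 1), which is not `step`.
```

(Optuna normally aligns `high` to a multiple of `step` when building the distribution, so whether this exact input can reach the helper depends on validation elsewhere.)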
Contributor Agreements

Please read the contributor agreements and if you agree, please click the checkbox below.

Tip: Please follow the Quick TODO list to smoothly merge your PR.

Motivation

Description of the changes

TODO List towards PR Merge

Please remove this section if this PR is not an addition of a new package. Otherwise, please check the following TODO list:

- Copy `./template/` to create your package
- Replace `<COPYRIGHT HOLDER>` in `LICENSE` of your package with your name
- Fill out `README.md` in your package
- … `__init__.py`
- Add `from __future__ import annotations` at the head of any Python files that include typing to support older Python versions
- … `README.md`
- … `README.md`