@@ -1,7 +1,9 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import cast
 from typing import TYPE_CHECKING
+from typing import TypedDict
 
 import numpy as np
 import optuna
@@ -17,6 +19,7 @@
 from optuna.study import StudyDirection
 from optuna.trial import FrozenTrial
 from optuna.trial import TrialState
+from typing_extensions import NotRequired
 
 
 if TYPE_CHECKING:
@@ -39,6 +42,11 @@
 EPS = 1e-10
 
 
+class _NoiseKWArgs(TypedDict):
+    uniform_input_noise_rads: NotRequired[torch.Tensor]
+    normal_input_noise_stdevs: NotRequired[torch.Tensor]
+
+
 def _standardize_values(values: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
     clipped_values = gp.warn_and_convert_inf(values)
     means = np.mean(clipped_values, axis=0)
@@ -123,7 +131,10 @@ def __init__(
         self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
         self._intersection_search_space = optuna.search_space.IntersectionSearchSpace()
         self._n_startup_trials = n_startup_trials
-        self._log_prior: Callable[[gp.GPRegressor], torch.Tensor] = prior.default_log_prior
+        # We assume gp.GPRegressor is compatible with optuna._gp.gp.GPRegressor
+        self._log_prior: Callable[[gp.GPRegressor], torch.Tensor] = cast(
+            Callable[[gp.GPRegressor], torch.Tensor], prior.default_log_prior
+        )
         self._minimum_noise: float = prior.DEFAULT_MINIMUM_NOISE_VAR
         # We cache the kernel parameters for initial values of fitting the next time.
         # TODO(nabenabe): Make the cache lists system_attrs to make GPSampler stateless.
@@ -180,7 +191,8 @@ def _optimize_acqf(
         # Particularly, we may remove this function in future refactoring.
         assert best_params is None or len(best_params.shape) == 2
         normalized_params, _acqf_val = optim_mixed.optimize_acqf_mixed(
-            acqf,
+            # We assume acqf_module.BaseAcquisitionFunc is compatible with optuna._gp.acqf.BaseAcquisitionFunc
+            cast(optuna._gp.acqf.BaseAcquisitionFunc, acqf),
             warmstart_normalized_params_array=best_params,
             n_preliminary_samples=self._n_preliminary_samples,
             n_local_search=self._n_local_search,
@@ -267,7 +279,7 @@ def _get_scaled_input_noise_params(
                 scaled_input_noise_params[i] = input_noise_param / (dist.high - dist.low)
             return scaled_input_noise_params
 
-        noise_kwargs: dict[str, torch.Tensor] = {}
+        noise_kwargs: _NoiseKWArgs = {}
         if self._uniform_input_noise_rads is not None:
             scaled_input_noise_params = _get_scaled_input_noise_params(
                 self._uniform_input_noise_rads, "uniform_input_noise_rads"
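Note on the `_NoiseKWArgs` change: replacing `dict[str, torch.Tensor]` with a `TypedDict` whose keys are `NotRequired` (PEP 655) keeps the empty-dict default valid while letting a type checker reject misspelled or unexpected noise kwargs. Below is a minimal, self-contained sketch of that pattern; the `describe_noise` helper and the example tensor values are illustrative only and are not part of this PR.

```python
from __future__ import annotations

from typing import TypedDict

import torch
from typing_extensions import NotRequired


class _NoiseKWArgs(TypedDict):
    # Both keys are optional: a type checker accepts an empty dict as well as
    # a dict containing either or both tensor-valued entries.
    uniform_input_noise_rads: NotRequired[torch.Tensor]
    normal_input_noise_stdevs: NotRequired[torch.Tensor]


def describe_noise(**noise_kwargs: torch.Tensor) -> None:
    # Hypothetical consumer of the noise kwargs, standing in for the GP code.
    for name, tensor in noise_kwargs.items():
        print(f"{name}: shape={tuple(tensor.shape)}")


noise_kwargs: _NoiseKWArgs = {}  # empty is valid because every key is NotRequired
noise_kwargs["uniform_input_noise_rads"] = torch.full((3,), 0.05)
describe_noise(**noise_kwargs)
```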