Skip to content

Commit da19131

Browse files
committed
Fix mypy errors of value_at_risk sampler
1 parent 1497b6d commit da19131

File tree

3 files changed

+19
-6
lines changed

3 files changed

+19
-6
lines changed

package/samplers/value_at_risk/_gp/acqf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
if TYPE_CHECKING:
1212
from typing import Protocol
1313

14-
from optuna._gp.gp import GPRegressor
1514
from optuna._gp.search_space import SearchSpace
1615
import torch
1716

17+
from .gp import GPRegressor
18+
1819
class SobolGenerator(Protocol):
1920
def __call__(self, dim: int, n_samples: int, seed: int | None) -> torch.Tensor:
2021
raise NotImplementedError

package/samplers/value_at_risk/sampler.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
22

3+
from typing import cast
34
from typing import Any
5+
from typing import NotRequired
46
from typing import TYPE_CHECKING
7+
from typing import TypedDict
58

69
import numpy as np
710
import optuna
@@ -39,6 +42,11 @@
3942
EPS = 1e-10
4043

4144

45+
class _NoiseKWArgs(TypedDict):
46+
uniform_input_noise_rads: NotRequired[torch.Tensor]
47+
normal_input_noise_stdevs: NotRequired[torch.Tensor]
48+
49+
4250
def _standardize_values(values: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
4351
clipped_values = gp.warn_and_convert_inf(values)
4452
means = np.mean(clipped_values, axis=0)
@@ -123,7 +131,10 @@ def __init__(
123131
self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
124132
self._intersection_search_space = optuna.search_space.IntersectionSearchSpace()
125133
self._n_startup_trials = n_startup_trials
126-
self._log_prior: Callable[[gp.GPRegressor], torch.Tensor] = prior.default_log_prior
134+
# We assume gp.GPRegressor is compatible with optuna._gp.gp.GPRegressor
135+
self._log_prior: Callable[[gp.GPRegressor], torch.Tensor] = cast(
136+
Callable[[gp.GPRegressor], torch.Tensor], prior.default_log_prior
137+
)
127138
self._minimum_noise: float = prior.DEFAULT_MINIMUM_NOISE_VAR
128139
# We cache the kernel parameters for initial values of fitting the next time.
129140
# TODO(nabenabe): Make the cache lists system_attrs to make GPSampler stateless.
@@ -180,7 +191,8 @@ def _optimize_acqf(
180191
# Particularly, we may remove this function in future refactoring.
181192
assert best_params is None or len(best_params.shape) == 2
182193
normalized_params, _acqf_val = optim_mixed.optimize_acqf_mixed(
183-
acqf,
194+
# We assume acqf_module.BaseAcquisitionFunc is compatible with optuna._gp.acqf.BaseAcquisitionFunc
195+
cast(optuna._gp.acqf.BaseAcquisitionFunc, acqf),
184196
warmstart_normalized_params_array=best_params,
185197
n_preliminary_samples=self._n_preliminary_samples,
186198
n_local_search=self._n_local_search,
@@ -267,7 +279,7 @@ def _get_scaled_input_noise_params(
267279
scaled_input_noise_params[i] = input_noise_param / (dist.high - dist.low)
268280
return scaled_input_noise_params
269281

270-
noise_kwargs: dict[str, torch.Tensor] = {}
282+
noise_kwargs: _NoiseKWArgs = {}
271283
if self._uniform_input_noise_rads is not None:
272284
scaled_input_noise_params = _get_scaled_input_noise_params(
273285
self._uniform_input_noise_rads, "uniform_input_noise_rads"

package/samplers/value_at_risk/tests/test_samplers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,8 @@ def sample() -> float:
280280
@parametrize_noise_type
281281
def test_sample_relative_numerical(
282282
relative_sampler_class: Callable[[], BaseSampler],
283-
x_distribution: BaseDistribution,
284-
y_distribution: BaseDistribution,
283+
x_distribution: FloatDistribution | IntDistribution,
284+
y_distribution: FloatDistribution | IntDistribution,
285285
noise_type: str,
286286
) -> None:
287287
can_x_be_noisy = x_distribution.step is None and not x_distribution.log

0 commit comments

Comments
 (0)