diff --git a/package/samplers/value_at_risk/_gp/acqf.py b/package/samplers/value_at_risk/_gp/acqf.py
index e52a6632..40ff03d0 100644
--- a/package/samplers/value_at_risk/_gp/acqf.py
+++ b/package/samplers/value_at_risk/_gp/acqf.py
@@ -11,10 +11,11 @@
 if TYPE_CHECKING:
     from typing import Protocol
 
-    from optuna._gp.gp import GPRegressor
     from optuna._gp.search_space import SearchSpace
     import torch
 
+    from .gp import GPRegressor
+
     class SobolGenerator(Protocol):
         def __call__(self, dim: int, n_samples: int, seed: int | None) -> torch.Tensor:
             raise NotImplementedError
diff --git a/package/samplers/value_at_risk/requirements.txt b/package/samplers/value_at_risk/requirements.txt
index e3f6b6a4..fa41cc4b 100644
--- a/package/samplers/value_at_risk/requirements.txt
+++ b/package/samplers/value_at_risk/requirements.txt
@@ -1,2 +1,3 @@
 scipy
 torch
+typing-extensions
diff --git a/package/samplers/value_at_risk/sampler.py b/package/samplers/value_at_risk/sampler.py
index fd54324e..75d13867 100644
--- a/package/samplers/value_at_risk/sampler.py
+++ b/package/samplers/value_at_risk/sampler.py
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import cast
 from typing import TYPE_CHECKING
+from typing import TypedDict
 
 import numpy as np
 import optuna
@@ -17,6 +19,7 @@
 from optuna.study import StudyDirection
 from optuna.trial import FrozenTrial
 from optuna.trial import TrialState
+from typing_extensions import NotRequired
 
 
 if TYPE_CHECKING:
@@ -39,6 +42,11 @@
 EPS = 1e-10
 
+
+class _NoiseKWArgs(TypedDict):
+    uniform_input_noise_rads: NotRequired[torch.Tensor]
+    normal_input_noise_stdevs: NotRequired[torch.Tensor]
+
 
 def _standardize_values(values: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
     clipped_values = gp.warn_and_convert_inf(values)
     means = np.mean(clipped_values, axis=0)
@@ -123,7 +131,10 @@ def __init__(
         self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
         self._intersection_search_space = optuna.search_space.IntersectionSearchSpace()
         self._n_startup_trials = n_startup_trials
-        self._log_prior: Callable[[gp.GPRegressor], torch.Tensor] = prior.default_log_prior
+        # We assume gp.GPRegressor is compatible with optuna._gp.gp.GPRegressor
+        self._log_prior: Callable[[gp.GPRegressor], torch.Tensor] = cast(
+            Callable[[gp.GPRegressor], torch.Tensor], prior.default_log_prior
+        )
         self._minimum_noise: float = prior.DEFAULT_MINIMUM_NOISE_VAR
         # We cache the kernel parameters for initial values of fitting the next time.
         # TODO(nabenabe): Make the cache lists system_attrs to make GPSampler stateless.
@@ -180,7 +191,8 @@ def _optimize_acqf(
         # Particularly, we may remove this function in future refactoring.
         assert best_params is None or len(best_params.shape) == 2
         normalized_params, _acqf_val = optim_mixed.optimize_acqf_mixed(
-            acqf,
+            # We assume acqf_module.BaseAcquisitionFunc is compatible with optuna._gp.acqf.BaseAcquisitionFunc
+            cast(optuna._gp.acqf.BaseAcquisitionFunc, acqf),
             warmstart_normalized_params_array=best_params,
             n_preliminary_samples=self._n_preliminary_samples,
             n_local_search=self._n_local_search,
@@ -267,7 +279,7 @@ def _get_scaled_input_noise_params(
                 scaled_input_noise_params[i] = input_noise_param / (dist.high - dist.low)
             return scaled_input_noise_params
 
-        noise_kwargs: dict[str, torch.Tensor] = {}
+        noise_kwargs: _NoiseKWArgs = {}
         if self._uniform_input_noise_rads is not None:
             scaled_input_noise_params = _get_scaled_input_noise_params(
                 self._uniform_input_noise_rads, "uniform_input_noise_rads"
diff --git a/package/samplers/value_at_risk/tests/test_samplers.py b/package/samplers/value_at_risk/tests/test_samplers.py
index 727467c5..4a9e46a6 100644
--- a/package/samplers/value_at_risk/tests/test_samplers.py
+++ b/package/samplers/value_at_risk/tests/test_samplers.py
@@ -280,8 +280,8 @@ def sample() -> float:
 @parametrize_noise_type
 def test_sample_relative_numerical(
     relative_sampler_class: Callable[[], BaseSampler],
-    x_distribution: BaseDistribution,
-    y_distribution: BaseDistribution,
+    x_distribution: FloatDistribution | IntDistribution,
+    y_distribution: FloatDistribution | IntDistribution,
     noise_type: str,
 ) -> None:
     can_x_be_noisy = x_distribution.step is None and not x_distribution.log
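For reference, a minimal, self-contained sketch of the `TypedDict`/`NotRequired` pattern that `_NoiseKWArgs` relies on, which is also why `typing-extensions` is added to `requirements.txt` (`NotRequired` only landed in the standard `typing` module in Python 3.11). The helper `build_noise_kwargs` below is hypothetical and only illustrates how a type checker treats the optional keys; it is not part of the sampler.

```python
from __future__ import annotations

from typing import TypedDict

import torch
from typing_extensions import NotRequired  # available in `typing` only from Python 3.11


class _NoiseKWArgs(TypedDict):
    # Both keys are optional: an empty dict, one key, or both keys are all valid.
    uniform_input_noise_rads: NotRequired[torch.Tensor]
    normal_input_noise_stdevs: NotRequired[torch.Tensor]


def build_noise_kwargs(uniform_rads: torch.Tensor | None) -> _NoiseKWArgs:
    # Hypothetical helper mirroring how the sampler fills the dict conditionally.
    kwargs: _NoiseKWArgs = {}  # valid even though no key is present yet
    if uniform_rads is not None:
        kwargs["uniform_input_noise_rads"] = uniform_rads
    return kwargs
```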