Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion package/samplers/value_at_risk/_gp/acqf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
if TYPE_CHECKING:
from typing import Protocol

from optuna._gp.gp import GPRegressor
from optuna._gp.search_space import SearchSpace
import torch

from .gp import GPRegressor

class SobolGenerator(Protocol):
def __call__(self, dim: int, n_samples: int, seed: int | None) -> torch.Tensor:
raise NotImplementedError
Expand Down
1 change: 1 addition & 0 deletions package/samplers/value_at_risk/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
scipy
torch
typing-extensions
18 changes: 15 additions & 3 deletions package/samplers/value_at_risk/sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import TypedDict

import numpy as np
import optuna
Expand All @@ -17,6 +19,7 @@
from optuna.study import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from typing_extensions import NotRequired


if TYPE_CHECKING:
Expand All @@ -39,6 +42,11 @@
EPS = 1e-10


class _NoiseKWArgs(TypedDict):
uniform_input_noise_rads: NotRequired[torch.Tensor]
normal_input_noise_stdevs: NotRequired[torch.Tensor]


def _standardize_values(values: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
clipped_values = gp.warn_and_convert_inf(values)
means = np.mean(clipped_values, axis=0)
Expand Down Expand Up @@ -123,7 +131,10 @@ def __init__(
self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
self._intersection_search_space = optuna.search_space.IntersectionSearchSpace()
self._n_startup_trials = n_startup_trials
self._log_prior: Callable[[gp.GPRegressor], torch.Tensor] = prior.default_log_prior
# We assume gp.GPRegressor is compatible with optuna._gp.gp.GPRegressor
self._log_prior: Callable[[gp.GPRegressor], torch.Tensor] = cast(
Callable[[gp.GPRegressor], torch.Tensor], prior.default_log_prior
)
self._minimum_noise: float = prior.DEFAULT_MINIMUM_NOISE_VAR
# We cache the kernel parameters for initial values of fitting the next time.
# TODO(nabenabe): Make the cache lists system_attrs to make GPSampler stateless.
Expand Down Expand Up @@ -180,7 +191,8 @@ def _optimize_acqf(
# Particularly, we may remove this function in future refactoring.
assert best_params is None or len(best_params.shape) == 2
normalized_params, _acqf_val = optim_mixed.optimize_acqf_mixed(
acqf,
# We assume acqf_module.BaseAcquisitionFunc is compatible with optuna._gp.acqf.BaseAcquisitionFunc
cast(optuna._gp.acqf.BaseAcquisitionFunc, acqf),
warmstart_normalized_params_array=best_params,
n_preliminary_samples=self._n_preliminary_samples,
n_local_search=self._n_local_search,
Expand Down Expand Up @@ -267,7 +279,7 @@ def _get_scaled_input_noise_params(
scaled_input_noise_params[i] = input_noise_param / (dist.high - dist.low)
return scaled_input_noise_params

noise_kwargs: dict[str, torch.Tensor] = {}
noise_kwargs: _NoiseKWArgs = {}
if self._uniform_input_noise_rads is not None:
scaled_input_noise_params = _get_scaled_input_noise_params(
self._uniform_input_noise_rads, "uniform_input_noise_rads"
Expand Down
4 changes: 2 additions & 2 deletions package/samplers/value_at_risk/tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ def sample() -> float:
@parametrize_noise_type
def test_sample_relative_numerical(
relative_sampler_class: Callable[[], BaseSampler],
x_distribution: BaseDistribution,
y_distribution: BaseDistribution,
x_distribution: FloatDistribution | IntDistribution,
y_distribution: FloatDistribution | IntDistribution,
noise_type: str,
) -> None:
can_x_be_noisy = x_distribution.step is None and not x_distribution.log
Expand Down