Skip to content

Commit c7d0f72

Browse files
Merge pull request #328 from msakai/add-const-noisy
Extend RobustGPSampler to handle environmental disturbance
2 parents 0845269 + 895b998 commit c7d0f72

File tree

1 file changed

+71
-16
lines changed

1 file changed

+71
-16
lines changed

package/samplers/value_at_risk/sampler.py

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ class RobustGPSampler(BaseSampler):
101101
The input noise standard deviations for each parameter. For example, when
102102
`{"x": 0.1, "y": 0.2}` is given, the sampler assumes that the input noise of `x` and
103103
`y` follows `N(0, 0.1**2)` and `N(0, 0.2**2)`, respectively.
104+
const_noisy_param_names:
105+
The list of parameters determined externally rather than being decision variables.
106+
For these parameters, `suggest_float` samples random values instead of searching
107+
values that optimize the objective function.
104108
"""
105109

106110
def __init__(
@@ -114,6 +118,7 @@ def __init__(
114118
warn_independent_sampling: bool = True,
115119
uniform_input_noise_rads: dict[str, float] | None = None,
116120
normal_input_noise_stdevs: dict[str, float] | None = None,
121+
const_noisy_param_names: list[str] | None = None,
117122
) -> None:
118123
if uniform_input_noise_rads is None and normal_input_noise_stdevs is None:
119124
raise ValueError(
@@ -125,8 +130,25 @@ def __init__(
125130
"Only one of `uniform_input_noise_rads` and `normal_input_noise_stdevs` "
126131
"can be specified."
127132
)
133+
if const_noisy_param_names is not None:
134+
if uniform_input_noise_rads is not None and len(
135+
const_noisy_param_names & uniform_input_noise_rads.keys()
136+
):
137+
raise ValueError(
138+
"noisy parameters can be specified only in one of "
139+
"`const_noisy_param_names` and `uniform_input_noise_rads`."
140+
)
141+
if normal_input_noise_stdevs is not None and len(
142+
const_noisy_param_names & normal_input_noise_stdevs.keys()
143+
):
144+
raise ValueError(
145+
"noisy parameters can be specified only in one of "
146+
"`const_noisy_param_names` and `normal_input_noise_stdevs`."
147+
)
148+
128149
self._uniform_input_noise_rads = uniform_input_noise_rads
129150
self._normal_input_noise_stdevs = normal_input_noise_stdevs
151+
self._const_noisy_param_names = const_noisy_param_names or []
130152
self._rng = LazyRandomState(seed)
131153
self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
132154
self._intersection_search_space = optuna.search_space.IntersectionSearchSpace()
@@ -183,17 +205,14 @@ def infer_relative_search_space(
183205

184206
return search_space
185207

186-
def _optimize_acqf(
187-
self, acqf: acqf_module.BaseAcquisitionFunc, best_params: np.ndarray | None
188-
) -> np.ndarray:
208+
def _optimize_acqf(self, acqf: acqf_module.BaseAcquisitionFunc) -> np.ndarray:
189209
# Advanced users can override this method to change the optimization algorithm.
190210
# However, we do not make any effort to keep backward compatibility between versions.
191211
# Particularly, we may remove this function in future refactoring.
192-
assert best_params is None or len(best_params.shape) == 2
193212
normalized_params, _acqf_val = optim_mixed.optimize_acqf_mixed(
194213
# We assume acqf_module.BaseAcquisitionFunc is compatible with optuna._gp.acqf.BaseAcquisitionFunc
195214
cast(optuna._gp.acqf.BaseAcquisitionFunc, acqf),
196-
warmstart_normalized_params_array=best_params,
215+
warmstart_normalized_params_array=None,
197216
n_preliminary_samples=self._n_preliminary_samples,
198217
n_local_search=self._n_local_search,
199218
tol=self._tol,
@@ -241,6 +260,14 @@ def _get_constraints_acqf_args(
241260
self._constraints_gprs_cache_list = constraints_gprs
242261
return constraints_gprs, constraints_threshold_list
243262

263+
def _get_internal_search_space_with_fixed_params(
264+
self, search_space: dict[str, BaseDistribution]
265+
) -> gp_search_space.SearchSpace:
266+
search_space_with_fixed_params = search_space.copy()
267+
for param_name in self._const_noisy_param_names:
268+
search_space_with_fixed_params[param_name] = optuna.distributions.IntDistribution(0, 0)
269+
return gp_search_space.SearchSpace(search_space_with_fixed_params)
270+
244271
def _get_value_at_risk(
245272
self,
246273
gpr: gp.GPRegressor,
@@ -280,23 +307,34 @@ def _get_scaled_input_noise_params(
280307
return scaled_input_noise_params
281308

282309
noise_kwargs: _NoiseKWArgs = {}
310+
const_noise_param_inds = [
311+
i
312+
for i, param_name in enumerate(search_space)
313+
if param_name in self._const_noisy_param_names
314+
]
283315
if self._uniform_input_noise_rads is not None:
284316
scaled_input_noise_params = _get_scaled_input_noise_params(
285317
self._uniform_input_noise_rads, "uniform_input_noise_rads"
286318
)
319+
scaled_input_noise_params[const_noise_param_inds] = 0.5
287320
noise_kwargs["uniform_input_noise_rads"] = scaled_input_noise_params
288321
elif self._normal_input_noise_stdevs is not None:
289322
scaled_input_noise_params = _get_scaled_input_noise_params(
290323
self._normal_input_noise_stdevs, "normal_input_noise_stdevs"
291324
)
325+
# NOTE(nabenabe): \pm 2 sigma will cover the domain.
326+
scaled_input_noise_params[const_noise_param_inds] = 0.25
292327
noise_kwargs["normal_input_noise_stdevs"] = scaled_input_noise_params
293328
else:
294329
assert False, "Should not reach here."
295330

331+
search_space_with_fixed_params = self._get_internal_search_space_with_fixed_params(
332+
search_space
333+
)
296334
if constraints_gpr_list is None or constraints_threshold_list is None:
297335
return acqf_module.ValueAtRisk(
298336
gpr=gpr,
299-
search_space=internal_search_space,
337+
search_space=search_space_with_fixed_params,
300338
confidence_level=self._objective_confidence_level,
301339
n_input_noise_samples=self._n_input_noise_samples,
302340
n_qmc_samples=self._n_qmc_samples,
@@ -307,7 +345,7 @@ def _get_scaled_input_noise_params(
307345
else:
308346
return acqf_module.ConstrainedLogValueAtRisk(
309347
gpr=gpr,
310-
search_space=internal_search_space,
348+
search_space=search_space_with_fixed_params,
311349
constraints_gpr_list=constraints_gpr_list,
312350
constraints_threshold_list=constraints_threshold_list,
313351
objective_confidence_level=self._objective_confidence_level,
@@ -379,19 +417,15 @@ def _get_gpr_list(
379417
self._gprs_cache_list = gprs_list
380418
return gprs_list
381419

382-
def sample_relative(
383-
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
420+
def _optimize_params(
421+
self, study: Study, trials: list[FrozenTrial], search_space: dict[str, BaseDistribution]
384422
) -> dict[str, Any]:
385423
if search_space == {}:
386424
return {}
387425

388426
self._verify_search_space(search_space)
389-
trials = study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True)
390-
if len(trials) < self._n_startup_trials:
391-
return {}
392427

393428
gprs_list = self._get_gpr_list(study, search_space)
394-
best_params: np.ndarray | None
395429
acqf: acqf_module.BaseAcquisitionFunc
396430
assert len(gprs_list) == 1
397431
internal_search_space = gp_search_space.SearchSpace(search_space)
@@ -403,7 +437,6 @@ def sample_relative(
403437
search_space,
404438
acqf_type="mean",
405439
)
406-
best_params = None
407440
else:
408441
constraint_vals, _ = _get_constraint_vals_and_feasibility(study, trials)
409442
constr_gpr_list, constr_threshold_list = self._get_constraints_acqf_args(
@@ -420,11 +453,27 @@ def sample_relative(
420453
constraints_gpr_list=constr_gpr_list,
421454
constraints_threshold_list=constr_threshold_list,
422455
)
423-
best_params = None
424456

425-
normalized_param = self._optimize_acqf(acqf, best_params)
457+
normalized_param = self._optimize_acqf(acqf)
426458
return internal_search_space.get_unnormalized_param(normalized_param)
427459

460+
def sample_relative(
461+
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
462+
) -> dict[str, Any]:
463+
trials = study._get_trials(deepcopy=False, states=(TrialState.COMPLETE,), use_cache=True)
464+
if len(trials) < self._n_startup_trials:
465+
return {}
466+
467+
params = self._optimize_params(study, trials, search_space)
468+
469+
# Perturb constant noisy parameter uniformly
470+
for name in self._const_noisy_param_names:
471+
dist = search_space[name]
472+
assert isinstance(dist, optuna.distributions.FloatDistribution)
473+
params[name] = self._rng.rng.uniform(dist.low, dist.high)
474+
475+
return params
476+
428477
def get_robust_trial(self, study: Study) -> FrozenTrial:
429478
states = (TrialState.COMPLETE,)
430479
trials = study._get_trials(deepcopy=False, states=states, use_cache=True)
@@ -457,6 +506,12 @@ def get_robust_trial(self, study: Study) -> FrozenTrial:
457506
best_idx = np.argmax(acqf.eval_acqf_no_grad(X_train)).item()
458507
return trials[best_idx]
459508

509+
def get_robust_params(self, study: Study) -> dict[str, Any]:
510+
states = (TrialState.COMPLETE,)
511+
trials = study._get_trials(deepcopy=False, states=states, use_cache=True)
512+
search_space = self.infer_relative_search_space(study, trials[0])
513+
return self._optimize_params(study, trials, search_space)
514+
460515
def sample_independent(
461516
self,
462517
study: Study,

0 commit comments

Comments
 (0)