|
17 | 17 |
|
18 | 18 | import warnings
|
19 | 19 | from collections.abc import Callable
|
20 |
| -from math import ceil |
21 | 20 | from typing import Optional, Union
|
22 | 21 |
|
23 | 22 | import torch
|
|
43 | 42 | from botorch.optim.utils import fix_features, get_X_baseline
|
44 | 43 | from botorch.utils.multi_objective.pareto import is_non_dominated
|
45 | 44 | from botorch.utils.sampling import (
|
46 |
| - batched_multinomial, |
| 45 | + boltzmann_sample, |
47 | 46 | draw_sobol_samples,
|
48 | 47 | get_polytope_samples,
|
49 | 48 | manual_seed,
|
| 49 | + sample_perturbed_subset_dims, |
| 50 | + sample_truncated_normal_perturbations, |
50 | 51 | )
|
51 |
| -from botorch.utils.transforms import normalize, standardize, unnormalize |
| 52 | +from botorch.utils.transforms import unnormalize |
52 | 53 | from torch import Tensor
|
53 |
| -from torch.distributions import Normal |
54 | 54 | from torch.quasirandom import SobolEngine
|
55 | 55 |
|
56 | 56 | TGenInitialConditions = Callable[
|
@@ -578,10 +578,12 @@ def gen_one_shot_kg_initial_conditions(
|
578 | 578 |
|
579 | 579 | # sampling from the optimizers
|
580 | 580 | n_value = int((1 - frac_random) * (q_aug - q)) # number of non-random ICs
|
581 |
| - eta = options.get("eta", 2.0) |
582 |
| - weights = torch.exp(eta * standardize(fantasy_vals)) |
583 |
| - idx = torch.multinomial(weights, num_restarts * n_value, replacement=True) |
584 |
| - |
| 581 | + idx = boltzmann_sample( |
| 582 | + function_values=fantasy_vals, |
| 583 | + num_samples=num_restarts * n_value, |
| 584 | + eta=options.get("eta", 2.0), |
| 585 | + replacement=True, |
| 586 | + ) |
585 | 587 | # set the respective initial conditions to the sampled optimizers
|
586 | 588 | ics[..., -n_value:, :] = fantasy_cands[idx, 0].view(num_restarts, n_value, -1)
|
587 | 589 | return ics
|
@@ -699,14 +701,14 @@ def gen_one_shot_hvkg_initial_conditions(
|
699 | 701 | sequential=False,
|
700 | 702 | )
|
701 | 703 | # sampling from the optimizers
|
702 |
| - eta = options.get("eta", 2.0) |
703 | 704 | if num_optim_restarts > 0:
|
704 |
| - probs = torch.nn.functional.softmax(eta * standardize(fantasy_vals), dim=0) |
705 |
| - idx = torch.multinomial( |
706 |
| - probs, |
707 |
| - num_optim_restarts * acq_function.num_fantasies, |
| 705 | + idx = boltzmann_sample( |
| 706 | + function_values=fantasy_vals, |
| 707 | + num_samples=num_optim_restarts * acq_function.num_fantasies, |
| 708 | + eta=options.get("eta", 2.0), |
708 | 709 | replacement=True,
|
709 | 710 | )
|
| 711 | + |
710 | 712 | optim_ics = fantasy_cands[idx]
|
711 | 713 | if is_mf_hvkg:
|
712 | 714 | # add fixed features
|
@@ -885,11 +887,10 @@ def gen_value_function_initial_conditions(
|
885 | 887 | # sampling from the optimizers
|
886 | 888 | n_value = int((1 - frac_random) * raw_samples) # number of non-random ICs
|
887 | 889 | if n_value > 0:
|
888 |
| - eta = options.get("eta", 2.0) |
889 |
| - weights = torch.exp(eta * standardize(fantasy_vals)) |
890 |
| - idx = batched_multinomial( |
891 |
| - weights=weights.expand(*batch_shape, -1), |
| 890 | + idx = boltzmann_sample( |
| 891 | + function_values=fantasy_vals.expand(*batch_shape, -1), |
892 | 892 | num_samples=n_value,
|
| 893 | + eta=options.get("eta", 2.0), |
893 | 894 | replacement=True,
|
894 | 895 | ).permute(-1, *range(len(batch_shape)))
|
895 | 896 | resampled = fantasy_cands[idx]
|
@@ -979,18 +980,12 @@ def initialize_q_batch(
|
979 | 980 | return X[idcs], acq_vals[idcs]
|
980 | 981 |
|
981 | 982 | max_val, max_idx = torch.max(acq_vals, dim=0)
|
982 |
| - Z = (acq_vals - acq_vals.mean(dim=0)) / Ystd |
983 |
| - etaZ = eta * Z |
984 |
| - weights = torch.exp(etaZ) |
985 |
| - while torch.isinf(weights).any(): |
986 |
| - etaZ *= 0.5 |
987 |
| - weights = torch.exp(etaZ) |
988 |
| - if batch_shape == torch.Size(): |
989 |
| - idcs = torch.multinomial(weights, n) |
990 |
| - else: |
991 |
| - idcs = batched_multinomial( |
992 |
| - weights=weights.permute(*range(1, len(batch_shape) + 1), 0), num_samples=n |
993 |
| - ).permute(-1, *range(len(batch_shape))) |
| 983 | + idcs = boltzmann_sample( |
| 984 | + acq_vals.permute(*range(1, len(batch_shape) + 1), 0), |
| 985 | + num_samples=n, |
| 986 | + eta=eta, |
| 987 | + ).permute(-1, *range(len(batch_shape))) |
| 988 | + |
994 | 989 | # make sure we get the maximum
|
995 | 990 | if max_idx not in idcs:
|
996 | 991 | idcs[-1] = max_idx
|
@@ -1239,133 +1234,6 @@ def sample_points_around_best(
|
1239 | 1234 | return perturbed_X
|
1240 | 1235 |
|
1241 | 1236 |
|
1242 |
| -def sample_truncated_normal_perturbations( |
1243 |
| - X: Tensor, |
1244 |
| - n_discrete_points: int, |
1245 |
| - sigma: float, |
1246 |
| - bounds: Tensor, |
1247 |
| - qmc: bool = True, |
1248 |
| -) -> Tensor: |
1249 |
| - r"""Sample points around `X`. |
1250 |
| -
|
1251 |
| - Sample perturbed points around `X` such that the added perturbations |
1252 |
| - are sampled from N(0, sigma^2 I) and truncated to be within [0,1]^d. |
1253 |
| -
|
1254 |
| - Args: |
1255 |
| - X: A `n x d`-dim tensor starting points. |
1256 |
| - n_discrete_points: The number of points to sample. |
1257 |
| - sigma: The standard deviation of the additive gaussian noise for |
1258 |
| - perturbing the points. |
1259 |
| - bounds: A `2 x d`-dim tensor containing the bounds. |
1260 |
| - qmc: A boolean indicating whether to use qmc. |
1261 |
| -
|
1262 |
| - Returns: |
1263 |
| - A `n_discrete_points x d`-dim tensor containing the sampled points. |
1264 |
| - """ |
1265 |
| - X = normalize(X, bounds=bounds) |
1266 |
| - d = X.shape[1] |
1267 |
| - # sample points from N(X_center, sigma^2 I), truncated to be within |
1268 |
| - # [0, 1]^d. |
1269 |
| - if X.shape[0] > 1: |
1270 |
| - rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device) |
1271 |
| - X = X[rand_indices] |
1272 |
| - if qmc: |
1273 |
| - std_bounds = torch.zeros(2, d, dtype=X.dtype, device=X.device) |
1274 |
| - std_bounds[1] = 1 |
1275 |
| - u = draw_sobol_samples(bounds=std_bounds, n=n_discrete_points, q=1).squeeze(1) |
1276 |
| - else: |
1277 |
| - u = torch.rand((n_discrete_points, d), dtype=X.dtype, device=X.device) |
1278 |
| - # compute bounds to sample from |
1279 |
| - a = -X |
1280 |
| - b = 1 - X |
1281 |
| - # compute z-score of bounds |
1282 |
| - alpha = a / sigma |
1283 |
| - beta = b / sigma |
1284 |
| - normal = Normal(0, 1) |
1285 |
| - cdf_alpha = normal.cdf(alpha) |
1286 |
| - # use inverse transform |
1287 |
| - perturbation = normal.icdf(cdf_alpha + u * (normal.cdf(beta) - cdf_alpha)) * sigma |
1288 |
| - # add perturbation and clip points that are still outside |
1289 |
| - perturbed_X = (X + perturbation).clamp(0.0, 1.0) |
1290 |
| - return unnormalize(perturbed_X, bounds=bounds) |
1291 |
| - |
1292 |
| - |
1293 |
| -def sample_perturbed_subset_dims( |
1294 |
| - X: Tensor, |
1295 |
| - bounds: Tensor, |
1296 |
| - n_discrete_points: int, |
1297 |
| - sigma: float = 1e-1, |
1298 |
| - qmc: bool = True, |
1299 |
| - prob_perturb: float | None = None, |
1300 |
| -) -> Tensor: |
1301 |
| - r"""Sample around `X` by perturbing a subset of the dimensions. |
1302 |
| -
|
1303 |
| - By default, dimensions are perturbed with probability equal to |
1304 |
| - `min(20 / d, 1)`. As shown in [Regis]_, perturbing a small number |
1305 |
| - of dimensions can be beneificial. The perturbations are sampled |
1306 |
| - from N(0, sigma^2 I) and truncated to be within [0,1]^d. |
1307 |
| -
|
1308 |
| - Args: |
1309 |
| - X: A `n x d`-dim tensor starting points. `X` |
1310 |
| - must be normalized to be within `[0, 1]^d`. |
1311 |
| - bounds: The bounds to sample perturbed values from |
1312 |
| - n_discrete_points: The number of points to sample. |
1313 |
| - sigma: The standard deviation of the additive gaussian noise for |
1314 |
| - perturbing the points. |
1315 |
| - qmc: A boolean indicating whether to use qmc. |
1316 |
| - prob_perturb: The probability of perturbing each dimension. If omitted, |
1317 |
| - defaults to `min(20 / d, 1)`. |
1318 |
| -
|
1319 |
| - Returns: |
1320 |
| - A `n_discrete_points x d`-dim tensor containing the sampled points. |
1321 |
| -
|
1322 |
| - """ |
1323 |
| - if bounds.ndim != 2: |
1324 |
| - raise BotorchTensorDimensionError("bounds must be a `2 x d`-dim tensor.") |
1325 |
| - elif X.ndim != 2: |
1326 |
| - raise BotorchTensorDimensionError("X must be a `n x d`-dim tensor.") |
1327 |
| - d = bounds.shape[-1] |
1328 |
| - if prob_perturb is None: |
1329 |
| - # Only perturb a subset of the features |
1330 |
| - prob_perturb = min(20.0 / d, 1.0) |
1331 |
| - |
1332 |
| - if X.shape[0] == 1: |
1333 |
| - X_cand = X.repeat(n_discrete_points, 1) |
1334 |
| - else: |
1335 |
| - rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device) |
1336 |
| - X_cand = X[rand_indices] |
1337 |
| - pert = sample_truncated_normal_perturbations( |
1338 |
| - X=X_cand, |
1339 |
| - n_discrete_points=n_discrete_points, |
1340 |
| - sigma=sigma, |
1341 |
| - bounds=bounds, |
1342 |
| - qmc=qmc, |
1343 |
| - ) |
1344 |
| - |
1345 |
| - # find cases where we are not perturbing any dimensions |
1346 |
| - mask = ( |
1347 |
| - torch.rand( |
1348 |
| - n_discrete_points, |
1349 |
| - d, |
1350 |
| - dtype=bounds.dtype, |
1351 |
| - device=bounds.device, |
1352 |
| - ) |
1353 |
| - <= prob_perturb |
1354 |
| - ) |
1355 |
| - ind = (~mask).all(dim=-1).nonzero() |
1356 |
| - # perturb `n_perturb` of the dimensions |
1357 |
| - n_perturb = ceil(d * prob_perturb) |
1358 |
| - perturb_mask = torch.zeros(d, dtype=mask.dtype, device=mask.device) |
1359 |
| - perturb_mask[:n_perturb].fill_(1) |
1360 |
| - # TODO: use batched `torch.randperm` when available: |
1361 |
| - # https://github.com/pytorch/pytorch/issues/42502 |
1362 |
| - for idx in ind: |
1363 |
| - mask[idx] = perturb_mask[torch.randperm(d, device=bounds.device)] |
1364 |
| - # Create candidate points |
1365 |
| - X_cand[mask] = pert[mask] |
1366 |
| - return X_cand |
1367 |
| - |
1368 |
| - |
1369 | 1237 | def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
|
1370 | 1238 | r"""Determine whether a given acquisition function is non-negative.
|
1371 | 1239 |
|
|
0 commit comments