
Commit 78c04e2

hvarfner authored and facebook-github-bot committed
Performance & runtime improvements to info-theoretic acquisition functions (0/N) - Restructuring of sampling methods (#2753)
Summary:
Reshuffling of sampling methods that are not directly related to acquisition function optimization (i.e., methods that do not take an acquisition function as an argument), based on [this discussion](#2748 (comment)). To remove code duplication specifically related to the optimization of info-theoretic acquisition functions, these seemed like sensible moves!

Pull Request resolved: #2753

Test Plan: Moved the existing unit tests and added a new one for `boltzmann_sample`, which was already used throughout and is used again in subsequent PRs.

## Related PRs
First of a series, like [this one](#2748).

Reviewed By: esantorella

Differential Revision: D70131981

Pulled By: saitcakmak

fbshipit-source-id: 48dd86e7e06006054294d7cd8b9a3d318b0b0ad1
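For context, `boltzmann_sample` itself is not shown in this diff (it lives in `botorch.utils.sampling`). The sketch below is a hypothetical reconstruction inferred purely from the call sites changed here, where it replaces the inline pattern `weights = torch.exp(eta * standardize(values)); idx = torch.multinomial(weights, ...)`. The name `boltzmann_sample_sketch`, the standardization details, and the batch handling are assumptions, not the actual BoTorch implementation:

```python
import torch
from torch import Tensor


def boltzmann_sample_sketch(
    function_values: Tensor,
    num_samples: int,
    eta: float = 1.0,
    replacement: bool = False,
) -> Tensor:
    # Hypothetical stand-in for botorch.utils.sampling.boltzmann_sample,
    # reconstructed from the call sites in this diff: draw indices over the
    # last dimension with probability proportional to
    # exp(eta * standardized function value).
    mean = function_values.mean(dim=-1, keepdim=True)
    std = function_values.std(dim=-1, keepdim=True).clamp_min(1e-9)
    # softmax normalizes the exponentials in a numerically stable way.
    probs = torch.softmax(eta * (function_values - mean) / std, dim=-1)
    # torch.multinomial accepts only 1-d or 2-d input, so flatten batch dims.
    flat = probs.reshape(-1, probs.shape[-1])
    idcs = torch.multinomial(flat, num_samples, replacement=replacement)
    return idcs.reshape(*probs.shape[:-1], num_samples)
```

This mirrors each updated call site below, e.g. `boltzmann_sample(function_values=fantasy_vals, num_samples=num_restarts * n_value, eta=options.get("eta", 2.0), replacement=True)`.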
1 parent 0be800e commit 78c04e2

File tree

4 files changed: +352 −266 lines


botorch/optim/initializers.py

+24 −156
@@ -17,7 +17,6 @@
 
 import warnings
 from collections.abc import Callable
-from math import ceil
 from typing import Optional, Union
 
 import torch
@@ -43,14 +42,15 @@
 from botorch.optim.utils import fix_features, get_X_baseline
 from botorch.utils.multi_objective.pareto import is_non_dominated
 from botorch.utils.sampling import (
-    batched_multinomial,
+    boltzmann_sample,
     draw_sobol_samples,
     get_polytope_samples,
     manual_seed,
+    sample_perturbed_subset_dims,
+    sample_truncated_normal_perturbations,
 )
-from botorch.utils.transforms import normalize, standardize, unnormalize
+from botorch.utils.transforms import unnormalize
 from torch import Tensor
-from torch.distributions import Normal
 from torch.quasirandom import SobolEngine
 
 TGenInitialConditions = Callable[
@@ -578,10 +578,12 @@ def gen_one_shot_kg_initial_conditions(
 
     # sampling from the optimizers
     n_value = int((1 - frac_random) * (q_aug - q))  # number of non-random ICs
-    eta = options.get("eta", 2.0)
-    weights = torch.exp(eta * standardize(fantasy_vals))
-    idx = torch.multinomial(weights, num_restarts * n_value, replacement=True)
-
+    idx = boltzmann_sample(
+        function_values=fantasy_vals,
+        num_samples=num_restarts * n_value,
+        eta=options.get("eta", 2.0),
+        replacement=True,
+    )
     # set the respective initial conditions to the sampled optimizers
     ics[..., -n_value:, :] = fantasy_cands[idx, 0].view(num_restarts, n_value, -1)
     return ics
@@ -699,14 +701,14 @@ def gen_one_shot_hvkg_initial_conditions(
         sequential=False,
     )
     # sampling from the optimizers
-    eta = options.get("eta", 2.0)
     if num_optim_restarts > 0:
-        probs = torch.nn.functional.softmax(eta * standardize(fantasy_vals), dim=0)
-        idx = torch.multinomial(
-            probs,
-            num_optim_restarts * acq_function.num_fantasies,
+        idx = boltzmann_sample(
+            function_values=fantasy_vals,
+            num_samples=num_optim_restarts * acq_function.num_fantasies,
+            eta=options.get("eta", 2.0),
             replacement=True,
         )
+
         optim_ics = fantasy_cands[idx]
         if is_mf_hvkg:
             # add fixed features
@@ -885,11 +887,10 @@ def gen_value_function_initial_conditions(
     # sampling from the optimizers
     n_value = int((1 - frac_random) * raw_samples)  # number of non-random ICs
     if n_value > 0:
-        eta = options.get("eta", 2.0)
-        weights = torch.exp(eta * standardize(fantasy_vals))
-        idx = batched_multinomial(
-            weights=weights.expand(*batch_shape, -1),
+        idx = boltzmann_sample(
+            function_values=fantasy_vals.expand(*batch_shape, -1),
             num_samples=n_value,
+            eta=options.get("eta", 2.0),
             replacement=True,
         ).permute(-1, *range(len(batch_shape)))
         resampled = fantasy_cands[idx]
@@ -979,18 +980,12 @@ def initialize_q_batch(
         return X[idcs], acq_vals[idcs]
 
     max_val, max_idx = torch.max(acq_vals, dim=0)
-    Z = (acq_vals - acq_vals.mean(dim=0)) / Ystd
-    etaZ = eta * Z
-    weights = torch.exp(etaZ)
-    while torch.isinf(weights).any():
-        etaZ *= 0.5
-        weights = torch.exp(etaZ)
-    if batch_shape == torch.Size():
-        idcs = torch.multinomial(weights, n)
-    else:
-        idcs = batched_multinomial(
-            weights=weights.permute(*range(1, len(batch_shape) + 1), 0), num_samples=n
-        ).permute(-1, *range(len(batch_shape)))
+    idcs = boltzmann_sample(
+        acq_vals.permute(*range(1, len(batch_shape) + 1), 0),
+        num_samples=n,
+        eta=eta,
+    ).permute(-1, *range(len(batch_shape)))
+
     # make sure we get the maximum
     if max_idx not in idcs:
         idcs[-1] = max_idx
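Worth noting on the hunk above: the deleted `while torch.isinf(weights).any()` loop existed because `torch.exp(eta * Z)` overflows for large standardized values. A softmax-based formulation makes such a guard unnecessary, since softmax subtracts the maximum logit before exponentiating; assuming `boltzmann_sample` normalizes this way (not confirmed by this diff), the loop becomes redundant. A minimal illustration, not BoTorch code:

```python
import torch

logits = torch.tensor([1000.0, 999.0, 0.0])
naive = torch.exp(logits)              # tensor([inf, inf, 1.]): overflow
stable = torch.softmax(logits, dim=0)  # tensor([0.7311, 0.2689, 0.0000])
```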
@@ -1239,133 +1234,6 @@ def sample_points_around_best(
     return perturbed_X
 
 
-def sample_truncated_normal_perturbations(
-    X: Tensor,
-    n_discrete_points: int,
-    sigma: float,
-    bounds: Tensor,
-    qmc: bool = True,
-) -> Tensor:
-    r"""Sample points around `X`.
-
-    Sample perturbed points around `X` such that the added perturbations
-    are sampled from N(0, sigma^2 I) and truncated to be within [0,1]^d.
-
-    Args:
-        X: A `n x d`-dim tensor starting points.
-        n_discrete_points: The number of points to sample.
-        sigma: The standard deviation of the additive gaussian noise for
-            perturbing the points.
-        bounds: A `2 x d`-dim tensor containing the bounds.
-        qmc: A boolean indicating whether to use qmc.
-
-    Returns:
-        A `n_discrete_points x d`-dim tensor containing the sampled points.
-    """
-    X = normalize(X, bounds=bounds)
-    d = X.shape[1]
-    # sample points from N(X_center, sigma^2 I), truncated to be within
-    # [0, 1]^d.
-    if X.shape[0] > 1:
-        rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device)
-        X = X[rand_indices]
-    if qmc:
-        std_bounds = torch.zeros(2, d, dtype=X.dtype, device=X.device)
-        std_bounds[1] = 1
-        u = draw_sobol_samples(bounds=std_bounds, n=n_discrete_points, q=1).squeeze(1)
-    else:
-        u = torch.rand((n_discrete_points, d), dtype=X.dtype, device=X.device)
-    # compute bounds to sample from
-    a = -X
-    b = 1 - X
-    # compute z-score of bounds
-    alpha = a / sigma
-    beta = b / sigma
-    normal = Normal(0, 1)
-    cdf_alpha = normal.cdf(alpha)
-    # use inverse transform
-    perturbation = normal.icdf(cdf_alpha + u * (normal.cdf(beta) - cdf_alpha)) * sigma
-    # add perturbation and clip points that are still outside
-    perturbed_X = (X + perturbation).clamp(0.0, 1.0)
-    return unnormalize(perturbed_X, bounds=bounds)
-
-
-def sample_perturbed_subset_dims(
-    X: Tensor,
-    bounds: Tensor,
-    n_discrete_points: int,
-    sigma: float = 1e-1,
-    qmc: bool = True,
-    prob_perturb: float | None = None,
-) -> Tensor:
-    r"""Sample around `X` by perturbing a subset of the dimensions.
-
-    By default, dimensions are perturbed with probability equal to
-    `min(20 / d, 1)`. As shown in [Regis]_, perturbing a small number
-    of dimensions can be beneficial. The perturbations are sampled
-    from N(0, sigma^2 I) and truncated to be within [0,1]^d.
-
-    Args:
-        X: A `n x d`-dim tensor starting points. `X`
-            must be normalized to be within `[0, 1]^d`.
-        bounds: The bounds to sample perturbed values from
-        n_discrete_points: The number of points to sample.
-        sigma: The standard deviation of the additive gaussian noise for
-            perturbing the points.
-        qmc: A boolean indicating whether to use qmc.
-        prob_perturb: The probability of perturbing each dimension. If omitted,
-            defaults to `min(20 / d, 1)`.
-
-    Returns:
-        A `n_discrete_points x d`-dim tensor containing the sampled points.
-
-    """
-    if bounds.ndim != 2:
-        raise BotorchTensorDimensionError("bounds must be a `2 x d`-dim tensor.")
-    elif X.ndim != 2:
-        raise BotorchTensorDimensionError("X must be a `n x d`-dim tensor.")
-    d = bounds.shape[-1]
-    if prob_perturb is None:
-        # Only perturb a subset of the features
-        prob_perturb = min(20.0 / d, 1.0)
-
-    if X.shape[0] == 1:
-        X_cand = X.repeat(n_discrete_points, 1)
-    else:
-        rand_indices = torch.randint(X.shape[0], (n_discrete_points,), device=X.device)
-        X_cand = X[rand_indices]
-    pert = sample_truncated_normal_perturbations(
-        X=X_cand,
-        n_discrete_points=n_discrete_points,
-        sigma=sigma,
-        bounds=bounds,
-        qmc=qmc,
-    )
-
-    # find cases where we are not perturbing any dimensions
-    mask = (
-        torch.rand(
-            n_discrete_points,
-            d,
-            dtype=bounds.dtype,
-            device=bounds.device,
-        )
-        <= prob_perturb
-    )
-    ind = (~mask).all(dim=-1).nonzero()
-    # perturb `n_perturb` of the dimensions
-    n_perturb = ceil(d * prob_perturb)
-    perturb_mask = torch.zeros(d, dtype=mask.dtype, device=mask.device)
-    perturb_mask[:n_perturb].fill_(1)
-    # TODO: use batched `torch.randperm` when available:
-    # https://github.com/pytorch/pytorch/issues/42502
-    for idx in ind:
-        mask[idx] = perturb_mask[torch.randperm(d, device=bounds.device)]
-    # Create candidate points
-    X_cand[mask] = pert[mask]
-    return X_cand
-
-
 def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
     r"""Determine whether a given acquisition function is non-negative.
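Per the summary, `sample_truncated_normal_perturbations` and `sample_perturbed_subset_dims` are moved rather than deleted. Judging by the updated import block at the top of this file (the receiving module is among the other changed files not shown on this page), downstream code would import them from their new location:

```python
from botorch.utils.sampling import (
    sample_perturbed_subset_dims,
    sample_truncated_normal_perturbations,
)
```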
