Performance & runtime improvements to info-theoretic acquisition functions (1/N) #2748

Open · wants to merge 1 commit into base: main
12 changes: 11 additions & 1 deletion botorch/acquisition/utils.py
@@ -534,6 +534,7 @@ def get_optimal_samples(
posterior_transform: ScalarizedPosteriorTransform | None = None,
objective: MCAcquisitionObjective | None = None,
return_transformed: bool = False,
options: dict | None = None,
) -> tuple[Tensor, Tensor]:
"""Draws sample paths from the posterior and maximizes the samples using GD.

@@ -551,7 +552,8 @@ def get_optimal_samples(
objective: An MCAcquisitionObjective, used to negate the objective or otherwise
transform sample outputs. Cannot be combined with `posterior_transform`.
return_transformed: If True, return the transformed samples.
options: Options for generation of initial candidates, passed to
gen_batch_initial_conditions.

Returns:
The optimal input locations and corresponding outputs, x* and f*.

@@ -576,12 +578,20 @@ def get_optimal_samples(
sample_transform = None

paths = get_matheron_path_model(model=model, sample_shape=torch.Size([num_optima]))
suggested_points = prune_inferior_points(
model=model,
X=model.train_inputs[0],
posterior_transform=posterior_transform,
objective=objective,
)
optimal_inputs, optimal_outputs = optimize_posterior_samples(
paths=paths,
bounds=bounds,
raw_samples=raw_samples,
num_restarts=num_restarts,
sample_transform=sample_transform,
return_transformed=return_transformed,
suggested_points=suggested_points,
options=options,
)
return optimal_inputs, optimal_outputs
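
For reference, below is a usage sketch of the updated `get_optimal_samples` signature. The synthetic training data and the `SingleTaskGP` are illustrative assumptions (not part of this PR); the `options` keys shown are the ones consumed downstream in `optimize_posterior_samples`.

```python
import torch

from botorch.acquisition.utils import get_optimal_samples
from botorch.models import SingleTaskGP

dtype = torch.float64
bounds = torch.tensor([[0.0] * 3, [1.0] * 3], dtype=dtype)
train_X = torch.rand(10, 3, dtype=dtype)
train_Y = torch.sin(2 * torch.pi * train_X).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)

# Draw and maximize 16 posterior sample paths; `options` is forwarded to the
# candidate-initialization step inside optimize_posterior_samples.
optimal_inputs, optimal_outputs = get_optimal_samples(
    model=model,
    bounds=bounds,
    num_optima=16,
    raw_samples=2048,  # new default
    num_restarts=4,    # new default
    options={"frac_random": 0.9, "sample_around_best_sigma": 1e-2, "eta": 2.0},
)
# optimal_inputs has shape 16 x 3, optimal_outputs has shape 16 x 1
```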
45 changes: 35 additions & 10 deletions botorch/utils/sampling.py
@@ -999,10 +999,12 @@ def sparse_to_dense_constraints(
def optimize_posterior_samples(
paths: GenericDeterministicModel,
bounds: Tensor,
raw_samples: int = 1024,
num_restarts: int = 20,
raw_samples: int = 2048,
num_restarts: int = 4,
sample_transform: Callable[[Tensor], Tensor] | None = None,
return_transformed: bool = False,
suggested_points: Tensor | None = None,
options: dict | None = None,
) -> tuple[Tensor, Tensor]:
r"""Cheaply maximizes posterior samples by random querying followed by
gradient-based optimization using SciPy's L-BFGS-B routine.
@@ -1011,19 +1013,27 @@ def optimize_posterior_samples(
paths: Random Fourier Feature-based sample paths from the GP
bounds: The bounds on the search space.
raw_samples: The number of samples with which to query the samples initially.
Raw samples are cheap to evaluate, so this should ideally be set much higher
than num_restarts.
num_restarts: The number of points selected for gradient-based optimization.
Should be set low relative to the number of raw samples for time-efficiency.
sample_transform: A callable transform of the sample outputs (e.g.
MCAcquisitionObjective or ScalarizedPosteriorTransform.evaluate) used to
negate the objective or otherwise transform the output.
return_transformed: A boolean indicating whether to return the transformed
or non-transformed samples.
suggested_points: Tensor of suggested input locations that are high-valued.
These are more densely evaluated during the sampling phase of optimization.
options: Options for generation of initial candidates, passed to
gen_batch_initial_conditions.

Returns:
A two-element tuple containing:
- X_opt: A `num_optima x [batch_size] x d`-dim tensor of optimal inputs x*.
- f_opt: A `num_optima x [batch_size] x m`-dim, optionally
`num_optima x [batch_size] x 1`-dim, tensor of optimal outputs f*.
"""
options = {} if options is None else options

def path_func(x) -> Tensor:
res = paths(x)
@@ -1032,21 +1042,35 @@ def path_func(x) -> Tensor:

return res.squeeze(-1)

candidate_set = unnormalize(
SobolEngine(dimension=bounds.shape[1], scramble=True).draw(n=raw_samples),
bounds=bounds,
)
# queries all samples on all candidates - output shape
# raw_samples * num_optima * num_models
frac_random = 1 if suggested_points is None else options.get("frac_random", 0.9)
candidate_set = draw_sobol_samples(
bounds=bounds, n=round(raw_samples * frac_random), q=1
).squeeze(-2)
if frac_random < 1:
perturbed_suggestions = sample_truncated_normal_perturbations(
X=suggested_points,
n_discrete_points=round(raw_samples * (1 - frac_random)),
sigma=options.get("sample_around_best_sigma", 1e-2),
bounds=bounds,
)
candidate_set = torch.cat((candidate_set, perturbed_suggestions))

candidate_queries = path_func(candidate_set)
argtop_k = torch.topk(candidate_queries, num_restarts, dim=-1).indices
X_top_k = candidate_set[argtop_k, :]
idx = boltzmann_sample(
function_values=candidate_queries.unsqueeze(-1),
num_samples=num_restarts,
eta=options.get("eta", 2.0),
replacement=False,
)
ics = candidate_set[idx, :]

# to avoid circular import, the import occurs here
from botorch.generation.gen import gen_candidates_scipy

X_top_k, f_top_k = gen_candidates_scipy(
X_top_k,
ics,
path_func,
lower_bounds=bounds[0],
upper_bounds=bounds[1],
@@ -1101,8 +1125,9 @@ def boltzmann_sample(
eta *= temp_decrease
weights = torch.exp(eta * norm_weights)

# squeeze in case of m = 1 (mono-output provided as batch_size x N x 1)
return batched_multinomial(
weights=weights, num_samples=num_samples, replacement=replacement
weights=weights.squeeze(-1), num_samples=num_samples, replacement=replacement
)
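
Taken together, the changes to `optimize_posterior_samples` (a) spend part of the raw-sample budget on perturbations of `suggested_points` and (b) pick the gradient-based restart points by Boltzmann sampling rather than a deterministic top-k. The sketch below illustrates that scheme in plain torch; the clamped-Gaussian perturbation is a simplified stand-in for `sample_truncated_normal_perturbations`, and the temperature/standardization details of `boltzmann_sample` are omitted.

```python
import torch
from torch.quasirandom import SobolEngine


def propose_restart_points(
    path_func,                       # callable: (n x d) inputs -> n path values
    bounds: torch.Tensor,            # 2 x d lower/upper bounds
    suggested_points: torch.Tensor,  # m x d known high-value inputs
    raw_samples: int = 2048,
    num_restarts: int = 4,
    frac_random: float = 0.9,
    sigma: float = 1e-2,
    eta: float = 2.0,
) -> torch.Tensor:
    d = bounds.shape[1]
    # 1) Space-filling Sobol candidates over most of the budget.
    n_random = round(raw_samples * frac_random)
    sobol = SobolEngine(dimension=d, scramble=True).draw(n_random).to(bounds)
    sobol = bounds[0] + (bounds[1] - bounds[0]) * sobol
    # 2) Remaining budget: small perturbations of the suggested points, clamped
    #    to the bounds (simplified stand-in for sample_truncated_normal_perturbations).
    n_perturbed = raw_samples - n_random
    idx = torch.randint(suggested_points.shape[0], (n_perturbed,))
    perturbed = suggested_points[idx] + sigma * torch.randn(
        n_perturbed, d, dtype=bounds.dtype
    )
    perturbed = perturbed.clamp(bounds[0], bounds[1])
    candidates = torch.cat((sobol, perturbed))
    # 3) Evaluate the sample path cheaply on all candidates.
    values = path_func(candidates)
    # 4) Boltzmann-sample the restart points: higher-valued candidates are more
    #    likely to be selected, but (unlike top-k) some diversity is retained.
    weights = torch.exp(eta * (values - values.mean()) / values.std())
    restart_idx = torch.multinomial(weights, num_restarts, replacement=False)
    # These correspond to `ics` in the diff, later refined by gen_candidates_scipy.
    return candidates[restart_idx]
```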


29 changes: 22 additions & 7 deletions test/acquisition/test_utils.py
@@ -33,6 +33,7 @@
UnsupportedError,
)
from botorch.models import SingleTaskGP
from botorch.utils.test_helpers import get_fully_bayesian_model, get_model
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from gpytorch.distributions import MultivariateNormal

@@ -413,17 +414,14 @@ def test_project_to_sample_points(self):


class TestGetOptimalSamples(BotorchTestCase):
def test_get_optimal_samples(self):
dims = 3
dtype = torch.float64
def _test_get_optimal_samples_base(self, model):
dims = model.train_inputs[0].shape[1]
dtype = model.train_targets.dtype
batch_shape = model.batch_shape
for_testing_speed_kwargs = {"raw_samples": 20, "num_restarts": 2}
num_optima = 7
batch_shape = (3,)

bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype)
model = SingleTaskGP(train_X=X, train_Y=Y)
posterior_transform = ScalarizedPosteriorTransform(
weights=torch.ones(1, dtype=dtype)
)
@@ -438,6 +436,7 @@ def test_get_optimal_samples(self):
num_optima=num_optima,
**for_testing_speed_kwargs,
)

correct_X_shape = (num_optima,) + batch_shape + (dims,)
correct_f_shape = (num_optima,) + batch_shape + (1,)
self.assertEqual(X_opt_def.shape, correct_X_shape)
@@ -519,6 +518,22 @@ def test_get_optimal_samples(self):
**for_testing_speed_kwargs,
)

def test_optimal_samples(self):
dims = 3
dtype = torch.float64
X = torch.rand(4, dims, dtype=dtype)
Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype)
model = get_model(train_X=X, train_Y=Y)
self._test_get_optimal_samples_base(model)
fully_bayesian_model = get_fully_bayesian_model(
train_X=X,
train_Y=Y,
num_models=4,
standardize_model=True,
infer_noise=True,
)
self._test_get_optimal_samples_base(fully_bayesian_model)


class TestPreferenceUtils(BotorchTestCase):
def test_repeat_to_match_aug_dim(self):
21 changes: 13 additions & 8 deletions test/utils/test_sampling.py
@@ -578,9 +578,13 @@ def test_optimize_posterior_samples(self):
dims = 2
dtype = torch.float64
eps = 1e-4
for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
nums_optima = (1, 7)
batch_shapes = ((), (2,), (3, 2))
for_testing_speed_kwargs = {
"raw_samples": 64,
"num_restarts": 2,
"options": {"eta": 10},
}
nums_optima = (1, 5)
batch_shapes = ((), (3,))
posterior_transforms = (
None,
ScalarizedPosteriorTransform(weights=-torch.ones(1, dtype=dtype)),
@@ -589,16 +593,19 @@
nums_optima, batch_shapes, posterior_transforms
):
bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
X = torch.rand(*batch_shape, 3, dims, dtype=dtype)
Y = torch.pow(X - 0.5, 2).sum(dim=-1, keepdim=True)

# having a noiseless model all but guarantees that the found optima
# will be better than the observations
model = SingleTaskGP(X, Y, torch.full_like(Y, eps))
model = SingleTaskGP(
train_X=X, train_Y=Y, train_Yvar=torch.full_like(Y, eps)
)
model.covar_module.lengthscale = 0.5
paths = get_matheron_path_model(
model=model, sample_shape=torch.Size([num_optima])
)

X_opt, f_opt = optimize_posterior_samples(
paths=paths,
bounds=bounds,
@@ -616,8 +623,6 @@
self.assertTrue(torch.all(X_opt >= bounds[0]))
self.assertTrue(torch.all(X_opt <= bounds[1]))

# Check that the all found optima are larger than the observations
# This is not 100% deterministic, but just about.
Y_queries = paths(X)
# this is when we negate, so the values should be smaller
if posterior_transform:
@@ -642,7 +647,7 @@ def test_optimize_posterior_samples_multi_objective(self):
dims = 2
dtype = torch.float64
eps = 1e-4
for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
for_testing_speed_kwargs = {"raw_samples": 64, "num_restarts": 2}
num_optima = 5
batch_shape = (3,)
