Skip to content

[Bug] `optimize_acqf` with `sample_around_best=True` errors out with HOGP #1687

Open
@saitcakmak

Description

@saitcakmak

🐛 Bug

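`sample_points_around_best` implicitly assumes that the posterior mean has shape `batch x m`. This does not hold for HOGP (`HigherOrderGP`), whose outputs have additional tensor dimensions (e.g. `2 x 2` in the example below). The function averages over all dimensions before the last two (i.e. up to dim -2), assuming that these are all batch dimensions. With HOGP this also collapses output dimensions, so `f_pred` can end up with no dimensions at all. The check `f_pred.shape[-1] == 1` then raises `IndexError: tuple index out of range` (see the stack trace below).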

To reproduce

**Code snippet to reproduce**

# Minimal reproduction: optimize_acqf with options={"sample_around_best": True}
# fails when the model is a HigherOrderGP (HOGP), while the same call without
# that option succeeds.
import torch
from botorch.acquisition.objective import GenericMCObjective
from botorch.acquisition import qExpectedImprovement
from botorch.models import HigherOrderGP
from botorch.optim import optimize_acqf
from linear_operator.settings import _fast_solves

# Sets the default of linear_operator's private `_fast_solves` setting to True.
# NOTE(review): the issue does not say why this is needed for the repro — confirm.
_fast_solves.default = True

# 2 x 2 bounds tensor: row 0 is all zeros, row 1 is all ones, i.e. the unit
# square for the 2-d inputs (assuming botorch's row 0 = lower, row 1 = upper
# bounds convention).
bounds = torch.zeros(2, 2)
bounds[1] = 1.0
# Reduces each sample to a scalar by summing over its last two dimensions,
# which correspond to the 2 x 2 tensor-valued output of the HOGP below.
objective = GenericMCObjective(lambda samples: samples.sum(dim=(-1, -2)))
# HOGP fit on 10 random 2-d inputs with 2 x 2 outputs per input, in eval mode.
model = HigherOrderGP(torch.rand(10, 2), torch.randn(10, 2, 2)).eval()
# qEI with a fixed incumbent value best_f=0.0 and the custom objective above.
acqf = qExpectedImprovement(
    model,
    best_f=0.0,
    objective=objective,
)

# Baseline: not using sample_around_best. This works fine.
optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    raw_samples=256,
    num_restarts=10,
)

# With sample_around_best. This errors out with IndexError while generating
# initial conditions (see the stack trace below).
optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    raw_samples=256,
    num_restarts=10,
    options={"sample_around_best": True}
)

**Stack trace/error message**

tuple index out of range
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-4-9b3e99a04183> in <module>
----> 1 optimize_acqf(
      2     acq_function=acqf,
      3     bounds=bounds,
      4     q=1,
      5     raw_samples=256,
/mnt/xarfuse/uid-352651/f20c7833-seed-nspid4026531836_cgpid22290725-ns-4026531840/botorch/optim/optimize.py in optimize_acqf(acq_function, bounds, q, num_restarts, raw_samples, options, inequality_constraints, equality_constraints, nonlinear_inequality_constraints, fixed_features, post_processing_func, batch_initial_conditions, return_best_only, gen_candidates, sequential, **kwargs)
    484         kwargs=kwargs,
    485     )
--> 486     return _optimize_acqf(opt_acqf_inputs)
    487 
    488 
/mnt/xarfuse/uid-352651/f20c7833-seed-nspid4026531836_cgpid22290725-ns-4026531840/botorch/optim/optimize.py in _optimize_acqf(opt_inputs)
    514 
    515     # Batch optimization (including the case q=1)
--> 516     return _optimize_acqf_batch(
    517         opt_inputs=opt_inputs, start_time=start_time, timeout_sec=timeout_sec
    518     )
/mnt/xarfuse/uid-352651/f20c7833-seed-nspid4026531836_cgpid22290725-ns-4026531840/botorch/optim/optimize.py in _optimize_acqf_batch(opt_inputs, start_time, timeout_sec)
    228         batch_initial_conditions = opt_inputs.batch_initial_conditions
    229     else:
--> 230         batch_initial_conditions = opt_inputs.ic_generator(
    231             acq_function=opt_inputs.acq_function,
    232             bounds=opt_inputs.bounds,
/mnt/xarfuse/uid-352651/f20c7833-seed-nspid4026531836_cgpid22290725-ns-4026531840/botorch/optim/initializers.py in gen_batch_initial_conditions(acq_function, bounds, q, num_restarts, raw_samples, fixed_features, options, inequality_constraints, equality_constraints)
    167             # sample points around best
    168             if sample_around_best:
--> 169                 X_best_rnd = sample_points_around_best(
    170                     acq_function=acq_function,
    171                     n_discrete_points=n * q,
/mnt/xarfuse/uid-352651/f20c7833-seed-nspid4026531836_cgpid22290725-ns-4026531840/botorch/optim/initializers.py in sample_points_around_best(acq_function, n_discrete_points, sigma, bounds, best_pct, subset_sigma, prob_perturb)
    688             best_X = X[is_pareto]
    689         else:
--> 690             if f_pred.shape[-1] == 1:
    691                 f_pred = f_pred.squeeze(-1)
    692             n_best = max(1, round(X.shape[0] * best_pct / 100))
IndexError: tuple index out of range

Expected Behavior

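`optimize_acqf` with `options={"sample_around_best": True}` should work with HOGP, just as it does without that option.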

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions