What happened?
When calling `initialize_q_batch()` with a non-trivial `batch_shape`, the maximum of the input `acq_values` for each batch is not guaranteed to be included in the selected result. In fact, as `batch_shape` grows, the probability of including it shrinks to zero.

The code that is supposed to include the maximum is here:

botorch/optim/initializers.py, lines 989 to 991 at commit 01aa45e

The variable `max_idx` has shape `batch_shape` and `idcs` has shape `n x batch_shape`. The check `max_idx not in idcs` only verifies that the maximum is present in `idcs` for at least one batch. I would expect `initialize_q_batch()` to ensure the maximum of the input `acq_values` is included in every batch.
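As a minimal standalone illustration (not BoTorch code) of why this containment check is too weak: PyTorch's `in` operator on tensors broadcasts an elementwise comparison and then reduces with `.any()` over all elements, so with a `batch_shape`-shaped `max_idx` it returns `True` as soon as a single batch contains its own maximum.

```python
import torch

# idcs has shape n x batch_shape (n=2, batch_shape=2);
# max_idx holds the per-batch argmax indices, shape batch_shape.
idcs = torch.tensor([[0, 1],
                     [3, 5]])
max_idx = torch.tensor([7, 5])

# Batch 1 contains its max (5), but batch 0 does not (7 is missing).
# `in` still returns True, so no correction would be triggered.
print(max_idx in idcs)               # True
# A per-batch check exposes the gap:
print((idcs == max_idx).any(dim=0))  # tensor([False,  True])
```

The per-batch reduction `(idcs == max_idx).any(dim=0)` reports the membership status of each batch separately, which is the property one would expect the check to enforce.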
Please provide a minimal, reproducible example of the unexpected behavior.
```python
import torch

from botorch.optim import initialize_q_batch

if __name__ == "__main__":
    torch.manual_seed(1234)

    X = torch.rand((20, 100, 1, 2))  # b x batch_shape x q x d
    Y = torch.sum(X**2, dim=[-2, -1])  # b x batch_shape

    true_max, true_max_idx = Y.max(dim=0)  # (batch_shape, batch_shape)

    X_init, acq_init = initialize_q_batch(X, acq_vals=Y, n=1)
    acq_init_max, _ = acq_init.max(dim=0)  # batch_shape

    mask = acq_init_max != true_max
    print(f"{mask.sum()} discrepancies")
    if mask.any():
        idx = torch.arange(X.shape[1])[mask][0]
        print(f"E.g. Batch index {idx}:")
        print(f"  Max input: {true_max[idx]} (index {true_max_idx[idx]})")
        print(f"  Max selected: {acq_init_max[idx]}")
```
Please paste any relevant traceback/logs produced by the example provided.
```
76 discrepancies
E.g. Batch index 2:
  Max input: 1.706216812133789 (index 1)
  Max selected: 1.3131837844848633
```
BoTorch Version
0.13.0
Python Version
3.13.2
Operating System
Ubuntu 20.04.6 LTS (Focal Fossa)
Code of Conduct
- I agree to follow BoTorch's Code of Conduct