
[Bug]: initialize_q_batch does not always include the maximum value when called in batch mode #2772

Open
@JackBuck

Description


What happened?

When calling initialize_q_batch() with a non-trivial batch_shape, the maximum of the input acq_vals for each batch is not guaranteed to be included in the selected result. In fact, as batch_shape grows to contain more batches, the probability that every batch's maximum is included shrinks to zero.
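A rough model of why this probability shrinks, under a simplifying assumption that is not something BoTorch guarantees: each of the B batches (B being the product of the entries of batch_shape) independently includes its own maximum among its n selected indices with the same probability p, where 0 < p < 1. The existing check forces the maximum into every batch only when no batch contains it already; otherwise only the batches that happened to select it include it. Under that assumption:

$$
P(\text{every batch includes its maximum}) = (1 - p)^B + p^B \;\xrightarrow{B \to \infty}\; 0
$$

For B = 1 the expression equals 1, which matches the single-batch case where the check works as intended.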

The code which is supposed to include the maximum is here:

# make sure we get the maximum
if max_idx not in idcs:
    idcs[-1] = max_idx

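The variable max_idx has shape batch_shape and idcs has shape n x batch_shape. The check max_idx not in idcs is evaluated over all batches at once, so it evaluates to False (i.e. the maximum is treated as already present) as soon as the maximum index appears in idcs for at least one batch. Every other batch then keeps its selection unchanged, even if its own maximum is missing. I would expect initialize_q_batch() to ensure that the maximum of the input acq_vals is included in every batch.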
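A small self-contained illustration of the membership semantics, using made-up toy tensors (not taken from BoTorch). Tensor.__contains__ compares element-wise with broadcasting and then reduces with .any() over every element, so a match in a single batch is enough for the test to succeed:

```python
import torch

# n = 2 selected indices for each of batch_shape = 3 batches (n x batch_shape).
idcs = torch.tensor([[0, 3, 5],
                     [2, 1, 4]])
# Argmax index for each batch (batch_shape).
max_idx = torch.tensor([2, 7, 9])

# `max_idx in idcs` reduces (max_idx == idcs).any() over ALL elements,
# so batch 0 containing its max is enough for the membership test to succeed.
print(max_idx not in idcs)  # False, so the "make sure we get the maximum" fix is skipped

# A per-batch membership test reduces only over the n dimension instead.
per_batch = (idcs == max_idx).any(dim=0)
print(per_batch.tolist())  # [True, False, False]: batches 1 and 2 lack their max
```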

Please provide a minimal, reproducible example of the unexpected behavior.

import torch
from botorch.optim import initialize_q_batch

if __name__ == "__main__":
    torch.manual_seed(1234)

    X = torch.rand((20, 100, 1, 2))  # b x batch_shape x q x d
    Y = torch.sum(X**2, dim=[-2, -1])  # b x batch_shape

    true_max, true_max_idx = Y.max(dim=0)  # (batch_shape, batch_shape)

    X_init, acq_init = initialize_q_batch(X, acq_vals=Y, n=1)

    acq_init_max, _ = acq_init.max(dim=0)  # batch_shape
    mask = acq_init_max != true_max

    print(f"{mask.sum()} discrepancies")
    if mask.any():
        idx = torch.arange(X.shape[1])[mask][0]
        print(f"E.g. Batch index {idx}:")
        print(f"  Max input: {true_max[idx]} (index {true_max_idx[idx]})")
        print(f"  Max selected: {acq_init_max[idx]}")
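One possible per-batch version of the "make sure we get the maximum" step is sketched below. The helper name ensure_max_per_batch is hypothetical and this is not BoTorch's actual fix; it assumes idcs has shape n x batch_shape and max_idx has shape batch_shape, as described above.

```python
import torch


def ensure_max_per_batch(idcs: torch.Tensor, max_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: make sure each batch's argmax appears in its selection.

    idcs: n x batch_shape tensor of indices selected along the first dimension.
    max_idx: batch_shape tensor holding the argmax of acq_vals for each batch.
    """
    # Per-batch membership: compare each batch's column only against its own max.
    included = (idcs == max_idx).any(dim=0)  # batch_shape, bool
    # Copy so the caller's tensor is not modified in place.
    idcs = idcs.clone()
    # Overwrite the last selected index only in batches that are missing their max.
    idcs[-1] = torch.where(included, idcs[-1], max_idx)
    return idcs


# Usage with toy values: batch 0 already contains its max (2), batches 1 and 2 do not.
idcs = torch.tensor([[0, 3, 5], [2, 1, 4]])
max_idx = torch.tensor([2, 7, 9])
fixed = ensure_max_per_batch(idcs, max_idx)
print(fixed.tolist())  # [[0, 3, 5], [2, 7, 9]]
```

Using torch.where keeps the replacement vectorized over batch_shape and leaves batches that already contain their maximum untouched.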

Please paste any relevant traceback/logs produced by the example provided.

76 discrepancies
E.g. Batch index 2:
  Max input: 1.706216812133789 (index 1)
  Max selected: 1.3131837844848633

BoTorch Version

0.13.0

Python Version

3.13.2

Operating System

Ubuntu 20.04.6 LTS (Focal Fossa)

Code of Conduct

  • I agree to follow BoTorch's Code of Conduct

Metadata


Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
