[Bug] Inefficient posterior evaluation of SaasFullyBayesianSingleTaskGP when q=1 #2310

Open

@slishak-PX

Description

🐛 Bug

Evaluating an acquisition function with q=1 on a SaasFullyBayesianSingleTaskGP requires an unnecessarily large amount of memory due to an inefficiently broadcast matmul operation.

In the example below, the following line multiplies a tensor of size [256, 16, 1, 2048] by a tensor of size [16, 2048, 2048], which requires allocating 128 GiB of memory:
https://github.com/cornellius-gp/gpytorch/blob/9551eba889adf835b69cfd86e9a5d584fb61cdcc/gpytorch/models/exact_prediction_strategies.py#L118
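
As a back-of-the-envelope check (a sketch, not code from the issue), the 128 GiB figure matches the cost of materializing the broadcast of the [16, 2048, 2048] train-train cache against the [256, 16, 1, 2048] test batch as a single [256, 16, 2048, 2048] tensor of doubles:

n_test, n_mcmc, n_train = 256, 16, 2048
bytes_per_double = 8
# torch.matmul broadcasts the batch dimensions before the batched multiply,
# materializing an n_test x n_mcmc x n_train x n_train float64 tensor.
broadcast_bytes = n_test * n_mcmc * n_train * n_train * bytes_per_double
print(f"{broadcast_bytes / 2**30:.0f} GiB")  # -> 128 GiB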

To reproduce

**Code snippet to reproduce**

import torch
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.transforms import Standardize
from botorch import fit_fully_bayesian_model_nuts
from botorch.acquisition import UpperConfidenceBound

n_train = 2048
n_test = 256
d = 256

tkwargs = {
    "device": torch.device("cuda:3" if torch.cuda.is_available() else "cpu"),
    "dtype": torch.double,
}

train_X = torch.rand(n_train, d, **tkwargs)
test_X = torch.rand(n_test, d, **tkwargs)
train_Y = torch.sin(train_X[:, :1])
test_Y = torch.sin(test_X[:, :1])

gp = SaasFullyBayesianSingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    outcome_transform=Standardize(m=1),
)
# A deliberately short NUTS run: 16 posterior samples yield a batch of 16 GPs.
fit_fully_bayesian_model_nuts(
    gp,
    warmup_steps=4,
    num_samples=16,
    thinning=1,
)

ucb = UpperConfidenceBound(gp, beta=2.5)
# q=1: one candidate point per q-batch across 256 test points.
acq_values = ucb(test_X[:, None, :])

**Stack trace/error message**

Traceback (most recent call last):
  File "/tmp/ipykernel_3377365/3398296989.py", line 3, in <module>
    acq_values = ucb(test_X[:, None, :])
                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/utils/transforms.py", line 259, in decorated
    output = method(acqf, X, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 786, in forward
    mean, sigma = self._mean_and_sigma(X)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 106, in _mean_and_sigma
    posterior = self.model.posterior(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/fully_bayesian.py", line 536, in posterior
    posterior = super().posterior(
                ^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/gpytorch.py", line 383, in posterior
    mvn = self(X)
          ^^^^^^^
...
    return test_train_covar.matmul(precomputed_cache)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 GiB. GPU 3 has a total capacity of 79.15 GiB of which 44.45 GiB is free. Including non-PyTorch memory, this process has 34.69 GiB memory in use. Of the allocated memory 23.74 GiB is allocated by PyTorch, and 10.42 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Expected Behavior

The memory usage of this operation is very high because torch.matmul handles such batched matrix-vector multiplications inefficiently: the broadcast operand is materialized in full. If the same operation is written as an einsum, or the batch dimensions are transposed so that it becomes a matrix-matrix multiplication, the memory usage and computation time are substantially reduced.

For example, below is a demonstration of two alternative operations that reduce the memory usage and computation time by orders of magnitude:

import torch
device = "cuda:3"

# Matrices to multiply
torch.manual_seed(50)
a = torch.randn((256, 16, 1, 1024), device=device)
b = torch.randn((16, 1024, 1024), device=device)

def profile(func):
    torch.cuda.reset_peak_memory_stats(device=device)
    m0 = torch.cuda.max_memory_allocated(device=device)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    out = func()
    end.record()

    torch.cuda.synchronize()
    # Event.elapsed_time already returns milliseconds.
    t = start.elapsed_time(end)

    m1 = torch.cuda.max_memory_allocated(device=device)

    print(f"Memory used: {(m1 - m0) / 1024**3:.2f}GB")
    print(f"Time: {t:.6f} ms")

    return out

with torch.no_grad():
    print("matmul")
    c = profile(lambda: torch.matmul(a, b))

    print("\neinsum")
    c_einsum = profile(lambda: torch.einsum("...ij,...jk", a, b))
    print(f"Max error: {(c_einsum - c).abs().max().cpu().item():.7f}")

    print("\ntransposed matmul")
    c_transpose = profile(lambda: torch.matmul(a.transpose(0, 2), b).transpose(0, 2))
    print(f"Max error: {(c_transpose - c).abs().max().cpu().item():.7f}")

**Output**

matmul
Memory used: 16.02GB
Time: 261.343986 ms

einsum
Memory used: 0.02GB
Time: 160.416007 ms
Max error: 0.0002327

transposed matmul
Memory used: 0.02GB
Time: 118.303999 ms
Max error: 0.0002327
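
For reference, the transposed variant generalizes to a small helper. This is a sketch under the shape assumptions above; the batched_matvec name is hypothetical and not part of BoTorch or GPyTorch:

import torch

def batched_matvec(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: [n_test, n_mcmc, 1, n] row vectors (the q=1 test-train covariances)
    # b: [n_mcmc, n, n] matrices shared across the n_test batch
    # Moving the n_test dim inboard turns the batched matrix-vector product
    # into one matrix-matrix product per MCMC sample, so b is never copied:
    # [1, n_mcmc, n_test, n] @ [n_mcmc, n, n] -> [1, n_mcmc, n_test, n]
    return torch.matmul(a.transpose(0, 2), b).transpose(0, 2)

The einsum variant achieves the same effect by letting the contraction handle the broadcasting without expanding b.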

System information

  • BoTorch Version 0.9.5
  • GPyTorch Version 1.11
  • PyTorch Version 2.2.0+cu121
  • Computer OS: Rocky Linux release 8.9
  • GPU: NVIDIA A100 80GB PCIe
