🐛 Bug
Evaluating an acquisition function with q=1 on a SaasFullyBayesianSingleTaskGP requires an unnecessarily large amount of memory, due to an inefficient broadcasted matmul operation.
In the example below, the following line multiplies a tensor of size [256, 16, 1, 2048] with a tensor of size [16, 2048, 2048], which requires the allocation of 128 GiB of memory:
https://github.com/cornellius-gp/gpytorch/blob/9551eba889adf835b69cfd86e9a5d584fb61cdcc/gpytorch/models/exact_prediction_strategies.py#L118
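As a sanity check on that number (shapes taken from the description above; this is illustrative arithmetic only, not the gpytorch code), the broadcasted operand alone accounts for the 128 GiB:

```python
# torch.matmul broadcasts the [16, 2048, 2048] cache against the [256, 16, 1, 2048]
# test-train covariance, effectively materializing a [256, 16, 2048, 2048] operand.
num_elements = 256 * 16 * 2048 * 2048  # elements of the broadcasted operand
bytes_per_element = 8                  # torch.double
print(f"{num_elements * bytes_per_element / 1024**3:.0f} GiB")  # -> 128 GiB
```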
To reproduce
**Code snippet to reproduce**
```python
import torch

from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.transforms import Standardize
from botorch import fit_fully_bayesian_model_nuts
from botorch.acquisition import UpperConfidenceBound

n_train = 2048
n_test = 256
d = 256
tkwargs = {
    "device": torch.device("cuda:3" if torch.cuda.is_available() else "cpu"),
    "dtype": torch.double,
}

train_X = torch.rand(n_train, d, **tkwargs)
test_X = torch.rand(n_test, d, **tkwargs)
train_Y = torch.sin(train_X[:, :1])
test_Y = torch.sin(test_X[:, :1])

gp = SaasFullyBayesianSingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    outcome_transform=Standardize(m=1),
)
fit_fully_bayesian_model_nuts(
    gp,
    warmup_steps=4,
    num_samples=16,
    thinning=1,
)

ucb = UpperConfidenceBound(gp, beta=2.5)
acq_values = ucb(test_X[:, None, :])
```
**Stack trace/error message**
```
Traceback (most recent call last):
  File "/tmp/ipykernel_3377365/3398296989.py", line 3, in <module>
    acq_values = ucb(test_X[:, None, :])
                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/utils/transforms.py", line 259, in decorated
    output = method(acqf, X, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 786, in forward
    mean, sigma = self._mean_and_sigma(X)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 106, in _mean_and_sigma
    posterior = self.model.posterior(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/fully_bayesian.py", line 536, in posterior
    posterior = super().posterior(
                ^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/gpytorch.py", line 383, in posterior
    mvn = self(X)
          ^^^^^^^
  ...
    return test_train_covar.matmul(precomputed_cache)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 GiB. GPU 3 has a total capacity of 79.15 GiB of which 44.45 GiB is free. Including non-PyTorch memory, this process has 34.69 GiB memory in use. Of the allocated memory 23.74 GiB is allocated by PyTorch, and 10.42 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
```
Expected Behavior
The memory usage of this operation is very high because torch.matmul is inefficient for such broadcasted batched matrix-vector multiplications. If the same operation is written as an einsum, or the operands are transposed so that it becomes a batched matrix-matrix multiplication, the memory usage and computation time are substantially reduced.
For example, below is a demonstration of two alternative formulations that reduce the memory usage by orders of magnitude and the computation time by roughly a factor of two:
```python
import torch

device = "cuda:3"

# Matrices to multiply
torch.manual_seed(50)
a = torch.randn((256, 16, 1, 1024), device=device)
b = torch.randn((16, 1024, 1024), device=device)


def profile(func):
    torch.cuda.reset_peak_memory_stats(device=device)
    m0 = torch.cuda.max_memory_allocated(device=device)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = func()
    end.record()
    torch.cuda.synchronize()
    t_ms = start.elapsed_time(end)  # Event.elapsed_time already returns milliseconds
    m1 = torch.cuda.max_memory_allocated(device=device)
    print(f"Memory used: {(m1 - m0) / 1024**3:.2f}GB")
    print(f"Time: {t_ms:.6f} ms")
    return out


with torch.no_grad():
    print("matmul")
    c = profile(lambda: torch.matmul(a, b))

    print("\neinsum")
    c_einsum = profile(lambda: torch.einsum("...ij,...jk", a, b))
    print(f"Max error: {(c_einsum - c).abs().max().cpu().item():.7f}")

    print("\ntransposed matmul")
    c_transpose = profile(lambda: torch.matmul(a.transpose(0, 2), b).transpose(0, 2))
    print(f"Max error: {(c_transpose - c).abs().max().cpu().item():.7f}")
```
```
matmul
Memory used: 16.02GB
Time: 261.343986 ms

einsum
Memory used: 0.02GB
Time: 160.416007 ms
Max error: 0.0002327

transposed matmul
Memory used: 0.02GB
Time: 118.303999 ms
Max error: 0.0002327
```
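For reference, here is a minimal sketch of how the offending product could be computed without the broadcast blow-up, assuming the shapes above. The helper below is hypothetical and only illustrates the einsum formulation; it is not the existing gpytorch code.

```python
import torch


def cache_matmul(test_train_covar: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: compute test_train_covar @ cache for shapes like
    test_train_covar: [n_test, n_mcmc, q, n] (q = 1 here) and cache: [n_mcmc, n, m],
    without materializing a broadcasted [n_test, n_mcmc, n, m] copy of `cache`.
    """
    # einsum contracts over the shared dimension j and broadcasts the batch
    # dimensions lazily instead of expanding `cache` in memory.
    return torch.einsum("...ij,...jk->...ik", test_train_covar, cache)


# With the shapes from the error message above:
# test_train_covar: [256, 16, 1, 2048], cache: [16, 2048, 2048]
# the result has shape [256, 16, 1, 2048] and peak memory stays small.
```

Based on the timings above, the transposed-matmul formulation would work just as well and was slightly faster in this particular benchmark.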
System information
Please complete the following information:
- BoTorch Version: 0.9.5
- GPyTorch Version: 1.11
- PyTorch Version: 2.2.0+cu121
- Computer OS: Rocky Linux release 8.9
- GPU: NVIDIA A100 80GB PCIe