[Bug] OOM regression in ExactGP prediction introduced by #2715
🐛 Bug
A regression in ExactGP posterior inference causes joint_covar[..., num_train:, num_train:] (the test-test block of the joint covariance) to be evaluated eagerly. On large test batches (~1.57M points per batch here), this triggers an attempted allocation of ~27 TiB and an immediate CUDA OOM.
This appears to have been introduced by #2715, in commit 3b9e89e (gpytorch 1.15.2.dev23).
To reproduce
** Code snippet to reproduce **
import torch
import gpytorch
from gpytorch import settings


class GPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)


# Toy setup with same structure as real failure case
train_x = torch.randn(3, 818, 16, device="cuda")  # [batch, points, features]
train_y = torch.randn(3, 818, device="cuda")
test_x = torch.randn(3, 1568800, 16, device="cuda")

likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()
model = GPModel(train_x, train_y, likelihood).cuda()
model.eval()
likelihood.eval()

with (
    torch.no_grad(),
    settings.max_cholesky_size(1000),  # ensure exact inference path is used with 818 points
):
    posterior = model(test_x)
    print(posterior, flush=True)

print("Prediction finished successfully!")
** Stack trace/error message **
Output of breaking version (Commit 3b9e89e):
Traceback (most recent call last):
  File "/.../code_snippet.py", line 32, in <module>
    posterior = model(test_x)
                ^^^^^^^^^^^^^
  File "/.../gpytorch/models/exact_gp.py", line 321, in __call__
    ) = self._get_test_prior_mean_and_covariances(train_inputs=train_inputs, test_inputs=inputs, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/models/exact_gp.py", line 425, in _get_test_prior_mean_and_covariances
    test_test_covar = joint_covar[..., num_train:, num_train:].evaluate_kernel()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/utils/memoize.py", line 61, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 27, in wrapped
    output = method(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 355, in evaluate_kernel
    res = self.kernel(
          ^^^^^^^^^^^^
  File "/.../gpytorch/kernels/kernel.py", line 533, in __call__
    res = to_linear_operator(super().__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params))
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/module.py", line 83, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/kernels/scale_kernel.py", line 109, in forward
    orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/kernels/rbf_kernel.py", line 80, in forward
    return RBFCovariance.apply(
           ^^^^^^^^^^^^^^^^^^^^
  File "/.../torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/functions/rbf_covariance.py", line 16, in forward
    unitless_sq_dist = sq_dist_func(x1_, x2_)
                       ^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/kernels/rbf_kernel.py", line 84, in <lambda>
    lambda x1, x2: self.covar_dist(x1, x2, square_dist=True, diag=False, **params),
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/kernels/kernel.py", line 352, in covar_dist
    return dist_func(x1, x2, x1_eq_x2)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/kernels/kernel.py", line 43, in sq_dist
    res = x1_.matmul(x2_.transpose(-2, -1))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 27505.31 GiB. GPU 0 has a total capacity of 95.04 GiB of which 92.37 GiB is free. Including non-PyTorch memory, this process has 2.66 GiB memory in use. Of the allocated memory 2.07 GiB is allocated by PyTorch, and 37.93 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
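The size of the failed allocation is consistent with building the full, dense, batched K(test, test) block. A quick back-of-the-envelope check, assuming float32 entries (4 bytes each, the `torch.randn` default) and the shapes from the reproduction snippet:

```python
# Back-of-the-envelope memory check for the shapes in the reproduction snippet.
# Assumption: float32 kernel entries (4 bytes each), the torch.randn default dtype.
BATCH = 3
N_TRAIN = 818
N_TEST = 1_568_800
BYTES_PER_ENTRY = 4
GIB = 1024 ** 3

# Dense test-test block K(test, test): one N_TEST x N_TEST matrix per batch.
test_test_bytes = BATCH * N_TEST * N_TEST * BYTES_PER_ENTRY
# Dense cross block K(test, train), for comparison.
test_train_bytes = BATCH * N_TEST * N_TRAIN * BYTES_PER_ENTRY

test_test_gib = test_test_bytes / GIB
test_train_gib = test_train_bytes / GIB

print(f"K(test, test):  {test_test_gib:.2f} GiB")
print(f"K(test, train): {test_train_gib:.2f} GiB")
```

The first figure comes out to 27505.31 GiB, matching the error message exactly, which suggests the failing `evaluate_kernel()` call densifies the whole test-test block. The cross block is only about 14.34 GiB, well within the 95 GiB card.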
Expected Behavior
Inference should complete without materializing a full K(test, test) covariance matrix, and should not exhibit O(N_test^2) memory usage.
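For reference, the standard exact-GP predictive equations (conventional notation, not taken from the GPyTorch source) show why only the cross-covariance has to be dense. With training inputs $X$, targets $y$, prior mean $m$, noise variance $\sigma^2$ and test inputs $X_*$:

$$
\mu_* = m(X_*) + K_{*X}\,(K_{XX} + \sigma^2 I)^{-1}\,\bigl(y - m(X)\bigr), \qquad
\Sigma_* = K_{**} - K_{*X}\,(K_{XX} + \sigma^2 I)^{-1}\,K_{X*}
$$

$K_{*X}$ is $N_{\text{test}} \times N_{\text{train}}$, so the mean needs $O(N_{\text{test}} N_{\text{train}})$ memory. The $N_{\text{test}} \times N_{\text{test}}$ term $K_{**}$ only enters $\Sigma_*$, which can remain a lazy operator; the predictive variances need just its diagonal, which for `ScaleKernel(RBFKernel())` is the constant outputscale.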
Output of working version (Commit 685286d):
MultivariateNormal(loc: torch.Size([3, 1568800]))
Prediction finished successfully!
System information
- Python version: 3.12.7
- GPyTorch version: 1.15.2.dev23+g3b9e89eb5
- PyTorch version: 2.6.0+cu124
- Computer OS: Linux HPC cluster (Cray Shasta), kernel 5.14.21
Additional context
This regression was bisected and isolated to:
The following settings are identical in the working and failing environments:
- gpytorch.settings.fast_pred_var = False
- gpytorch.settings.max_eager_kernel_size = 512