
[Bug] OOM regression in ExactGP prediction introduced by #2715 #2754

@illorens


🐛 Bug

A regression in ExactGP posterior inference causes the test-test covariance block, `joint_covar[..., num_train:, num_train:]`, to be evaluated eagerly. On large test batches (~1.57M test points in each of 3 batches), this triggers an attempted allocation of ~27,505 GiB (about 27 TiB) and an immediate CUDA OOM.
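
The size of that allocation is consistent with a fully dense float32 K(test, test) block for all three batches. A quick back-of-the-envelope check in plain Python (`dense_block_bytes` is a hypothetical helper; float32, PyTorch's default dtype, is assumed):

```python
GIB = 2**30  # bytes per GiB, the unit used in the CUDA OOM message


def dense_block_bytes(batch: int, n_rows: int, n_cols: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed for a dense batch x n_rows x n_cols tensor (float32 by default)."""
    return batch * n_rows * n_cols * bytes_per_elem


n_train, n_test, batch = 818, 1_568_800, 3

# Test-test block K(test, test), i.e. joint_covar[..., num_train:, num_train:]
test_test = dense_block_bytes(batch, n_test, n_test)
# Train-test cross block K(train, test), shown for comparison
train_test = dense_block_bytes(batch, n_train, n_test)

print(f"K(test, test):  {test_test / GIB:.2f} GiB")   # 27505.31 GiB
print(f"K(train, test): {train_test / GIB:.2f} GiB")  # 14.34 GiB
```

The first figure matches the 27505.31 GiB in the error message exactly, which supports the reading that the whole test-test block is being materialized.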

This appears to have been introduced by #2715, in commit 3b9e89e (GPyTorch 1.15.2.dev23+g3b9e89eb5).

To reproduce

**Code snippet to reproduce**

```python
import torch
import gpytorch
from gpytorch import settings


class GPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)


# Toy setup with same structure as real failure case
train_x = torch.randn(3, 818, 16, device="cuda")  # [batch, points, features]
train_y = torch.randn(3, 818, device="cuda")
test_x = torch.randn(3, 1568800, 16, device="cuda")

likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()
model = GPModel(train_x, train_y, likelihood).cuda()

model.eval()
likelihood.eval()
with (
    torch.no_grad(),
    settings.max_cholesky_size(1000),  # ensure exact inference path is used with 818 points
):
    posterior = model(test_x)
    print(posterior, flush=True)

print("Prediction finished successfully!")
```

**Stack trace/error message**

Output of breaking version (commit 3b9e89e):

```
Traceback (most recent call last):
  File "/.../code_snippet.py", line 32, in <module>
    posterior = model(test_x)
                ^^^^^^^^^^^^^
  File "/.../gpytorch/models/exact_gp.py", line 321, in __call__
    ) = self._get_test_prior_mean_and_covariances(train_inputs=train_inputs, test_inputs=inputs, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/models/exact_gp.py", line 425, in _get_test_prior_mean_and_covariances
    test_test_covar = joint_covar[..., num_train:, num_train:].evaluate_kernel()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/utils/memoize.py", line 61, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 27, in wrapped
    output = method(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 355, in evaluate_kernel
    res = self.kernel(
          ^^^^^^^^^^^^
  File "/.../gpytorch/kernels/kernel.py", line 533, in __call__
    res = to_linear_operator(super().__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params))
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/module.py", line 83, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/kernels/scale_kernel.py", line 109, in forward
    orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/kernels/rbf_kernel.py", line 80, in forward
    return RBFCovariance.apply(
           ^^^^^^^^^^^^^^^^^^^^
  File "/.../torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/functions/rbf_covariance.py", line 16, in forward
    unitless_sq_dist = sq_dist_func(x1_, x2_)
                       ^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/kernels/rbf_kernel.py", line 84, in <lambda>
    lambda x1, x2: self.covar_dist(x1, x2, square_dist=True, diag=False, **params),
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/kernels/kernel.py", line 352, in covar_dist
    return dist_func(x1, x2, x1_eq_x2)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../gpytorch/kernels/kernel.py", line 43, in sq_dist
    res = x1_.matmul(x2_.transpose(-2, -1))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 27505.31 GiB. GPU 0 has a total capacity of 95.04 GiB of which 92.37 GiB is free. Including non-PyTorch memory, this process has 2.66 GiB memory in use. Of the allocated memory 2.07 GiB is allocated by PyTorch, and 37.93 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
```
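
The failing frame is the `x1_.matmul(x2_.transpose(-2, -1))` in `sq_dist`. GPyTorch's `sq_dist` appears to compute all pairwise squared distances with that single product over padded inputs, using the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so its output is always the full len(x1) x len(x2) block. A minimal pure-Python sketch of the identity (the function name is hypothetical; GPyTorch's mean-centering and batching are omitted):

```python
def sq_dist_via_matmul(x1, x2):
    """Pairwise squared distances from one "matmul" over augmented rows.

    Row a of x1 becomes [-2a, ||a||^2, 1] and row b of x2 becomes [b, 1, ||b||^2],
    so their dot product is -2 a.b + ||a||^2 + ||b||^2 = ||a - b||^2.
    """
    x1_aug = [[-2.0 * v for v in a] + [sum(v * v for v in a), 1.0] for a in x1]
    x2_aug = [list(b) + [1.0, sum(v * v for v in b)] for b in x2]
    # The product is a dense len(x1) x len(x2) result: every pair is stored.
    return [[sum(p * q for p, q in zip(r1, r2)) for r2 in x2_aug] for r1 in x1_aug]


print(sq_dist_via_matmul([[0.0, 0.0], [1.0, 2.0]], [[3.0, 4.0]]))  # [[25.0], [8.0]]
```

With x1 = x2 = the 1,568,800 test points of one batch, that output is a 1,568,800 x 1,568,800 matrix per batch, which is the allocation that fails above.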

Expected Behavior

Inference should complete without materializing a full K(test, test) covariance matrix, and should not exhibit O(N_test^2) memory usage.
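
For reference, the latent posterior variance of an exact GP at a test point x needs only k(x, x), the n_train-length vector k_* = k(X_train, x), and a factorization of the n_train x n_train training covariance: var(x) = k(x, x) - k_*^T (K + noise * I)^{-1} k_*. So a per-point variance is achievable with O(N_test * n_train) work and no K(test, test) at all. The pure-Python sketch below illustrates this under stated assumptions (hypothetical helper names; unit lengthscale and outputscale; a Cholesky solve). It is not GPyTorch's actual prediction path, which is batched and structured differently.

```python
import math


def rbf(a, b, lengthscale=1.0, outputscale=1.0):
    """Scaled RBF kernel: outputscale * exp(-||a - b||^2 / (2 * lengthscale^2))."""
    sq = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return outputscale * math.exp(-0.5 * sq / lengthscale**2)


def cholesky(A):
    """Lower-triangular L with A = L L^T (A symmetric positive definite)."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def forward_solve(L, b):
    """Solve L y = b for lower-triangular L."""
    y = []
    for i, bi in enumerate(b):
        y.append((bi - sum(L[i][k] * y[k] for k in range(i))) / L[i][i])
    return y


def predictive_variances(train_x, test_x, noise=0.1):
    """Latent posterior variance at each test point, one point at a time.

    var(x) = k(x, x) - k_*^T (K + noise*I)^{-1} k_* = k(x, x) - ||L^{-1} k_*||^2.
    Working memory is O(n_train^2) for L plus O(n_train) per test point;
    the n_test x n_test matrix K(test, test) is never formed.
    """
    n = len(train_x)
    K = [[rbf(train_x[i], train_x[j]) + (noise if i == j else 0.0) for j in range(n)] for i in range(n)]
    L = cholesky(K)
    variances = []
    for x in test_x:
        k_star = [rbf(x, t) for t in train_x]  # length n_train
        v = forward_solve(L, k_star)  # v = L^{-1} k_star
        variances.append(rbf(x, x) - sum(vi * vi for vi in v))
    return variances
```

Only the diagonal term k(x, x) is evaluated for each test point, which is what keeps memory linear in N_test.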

Output of working version (commit 685286d):

```
MultivariateNormal(loc: torch.Size([3, 1568800]))
Prediction finished successfully!
```

System information

  • Python version: 3.12.7
  • GPyTorch version: 1.15.2.dev23+g3b9e89eb5
  • PyTorch version: 2.6.0+cu124
  • Computer OS: Linux HPC cluster (Cray Shasta), kernel 5.14.21

Additional context

This regression was bisected between the working commit 685286d and the failing commit 3b9e89e.

Settings such as `gpytorch.settings.fast_pred_var = False` and `gpytorch.settings.max_eager_kernel_size = 512` are identical in the working and failing environments.
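
Until a fix lands, one possible (untested) workaround is to predict on slices of `test_x` small enough that even an eagerly evaluated test-test block fits in GPU memory. The sketch below sizes those slices. The helper names and the 50% headroom factor are assumptions; the 92.37 GiB budget is the free memory reported in the OOM message.

```python
import math

GIB = 2**30


def max_chunk_for_budget(budget_bytes: int, batch: int, bytes_per_elem: int = 4) -> int:
    """Largest n such that a dense batch x n x n float32 block fits in budget_bytes."""
    return math.isqrt(budget_bytes // (batch * bytes_per_elem))


def chunk_bounds(n_total: int, chunk: int):
    """Yield (start, stop) pairs that cover range(n_total) in steps of `chunk`."""
    for start in range(0, n_total, chunk):
        yield start, min(start + chunk, n_total)


# Assumed budget: half of the 92.37 GiB reported free, leaving room for
# intermediates such as the sq_dist matmul output and the kernel result.
budget = int(0.5 * 92.37 * GIB)
chunk = max_chunk_for_budget(budget, batch=3)
bounds = list(chunk_bounds(1_568_800, chunk))
print(chunk, len(bounds))
```

Each `(start, stop)` pair would then be used as `model(test_x[:, start:stop])`, collecting `.mean` and `.variance` from each posterior and joining them with `torch.cat(..., dim=-1)`. This only bounds peak memory; it does not address the regression itself.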
