
[Bug] v1.15.2 slow torch.no_grad() predictions of mean and std #2736

@alegresor

🐛 Bug

A simple adaptation of the GPyTorch Regression Tutorial gives significantly slower predictions in GPyTorch v1.15.2 (around 9 s) than in GPyTorch v1.15.1 (under 0.02 s). The script differs from the tutorial in only two ways:

  • it does not use gpytorch.settings.fast_pred_var()
  • it greatly increases the number of prediction sites, to $65536 = 2^{16}$
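For scale, with $n_* = 2^{16}$ prediction sites the predictive mean and variance are each a vector of $2^{16}$ numbers, whereas a dense $n_* \times n_*$ predictive covariance (default float32, 4 bytes per entry) would have

$$
n_*^2 = 2^{32} \approx 4.29 \times 10^{9}\ \text{entries}, \qquad 2^{32} \times 4\ \text{bytes} = 2^{34}\ \text{bytes} = 16\ \text{GiB}.
$$

This is context only; it does not by itself show what v1.15.2 computes differently.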

To reproduce

```python
import math
import time

import torch
import gpytorch

# Training data: 100 noisy samples of sin(2*pi*x), as in the Regression Tutorial
train_x = torch.linspace(0, 1, 100)
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * math.sqrt(0.04)

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

# No training loop: the model is evaluated with its default hyperparameters
model.eval()
likelihood.eval()

# Deliberately NOT using gpytorch.settings.fast_pred_var() here
with torch.no_grad():
    test_x = torch.linspace(0, 1, 2**16)
    t0 = time.perf_counter()
    observed_pred = likelihood(model(test_x))
    print(time.perf_counter() - t0)  # around 9 s with gpytorch==1.15.2, under 0.02 s with gpytorch==1.15.1
```
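The repro above times a single call with one clock read on each side. A slightly more robust measurement is sketched below, assuming only the Python standard library: it runs the callable several times and reports the median. `median_runtime` is a hypothetical helper name, not part of GPyTorch.

```python
import statistics
import time
from typing import Callable


def median_runtime(fn: Callable[[], object], repeats: int = 5, warmup: int = 1) -> float:
    """Return the median wall-clock time of fn() in seconds.

    `warmup` calls run first and are not timed, so one-off costs
    (such as building caches on the first call) can be excluded.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    for _ in range(warmup):
        fn()  # untimed warm-up call
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        times.append(time.perf_counter() - t0)
    return statistics.median(times)
```

Inside the `torch.no_grad()` block of the repro, this could be called as `median_runtime(lambda: likelihood(model(test_x)), repeats=1, warmup=0)` to reproduce the original single measurement. GPyTorch's `ExactGP` builds its prediction strategy on the first eval-mode call and reuses it afterwards, so comparing `warmup=0` with `warmup=1` separates first-call cost from steady-state cost. To also time the mean and standard deviation named in the title, the callable could read `.mean` and `.stddev` from the returned distribution.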

Expected Behavior

I would expect the prediction time to be similar in v1.15.2 and v1.15.1.

System information

  • GPyTorch v1.15.2 and GPyTorch v1.15.1
  • PyTorch v2.10.0
  • macOS on Apple Silicon
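When comparing two releases it helps to record exactly what is installed. The sketch below uses only the standard library (`importlib.metadata` and `platform`); `environment_report` is a hypothetical helper, and including `linear_operator` (the package GPyTorch depends on for its structured-matrix operations) is an assumption about what is worth recording.

```python
import platform
from importlib import metadata


def environment_report(packages=("gpytorch", "torch", "linear_operator")) -> dict:
    """Collect the platform string, Python version and installed package versions."""
    report = {"platform": platform.platform(), "python": platform.python_version()}
    # linear_operator is included on the assumption that it is relevant here
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Running this once under each GPyTorch version gives a version listing that can be pasted directly into the report.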
