From c281d3eb3485b29ac3447622bf6d14caba6527ad Mon Sep 17 00:00:00 2001 From: Kaiwen Wu Date: Fri, 1 May 2026 14:03:22 -0400 Subject: [PATCH 1/2] Fix prediction slowdown from eager kernel evaluation Keep kernel covariances lazy in `_get_test_prior_mean_and_covariances`. Let prediction strategies call `.evaluate_kernel()` only when needed. --- gpytorch/models/exact_gp.py | 11 +++++++---- gpytorch/models/exact_prediction_strategies.py | 12 ++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/gpytorch/models/exact_gp.py b/gpytorch/models/exact_gp.py index 807af2cb7..42e9477af 100644 --- a/gpytorch/models/exact_gp.py +++ b/gpytorch/models/exact_gp.py @@ -8,6 +8,7 @@ from copy import deepcopy import torch +from linear_operator import LinearOperator from torch import Tensor from gpytorch.distributions import Distribution @@ -357,7 +358,7 @@ def _get_test_prior_mean_and_covariances( train_inputs: Iterable[Tensor], test_inputs: Iterable[Tensor], **kwargs, - ) -> tuple[Tensor, Tensor, Tensor, torch.Size, torch.Size, type[Distribution]]: + ) -> tuple[Tensor, LinearOperator, LinearOperator, torch.Size, torch.Size, type[Distribution]]: """Computes the prior mean and covariances on the test set. Override this method to customize test-set covariance computations, e.g., @@ -420,11 +421,13 @@ def _get_test_prior_mean_and_covariances( test_mean = joint_mean[..., num_train:] # Extract test covariances. Slicing is lazy; K(train, train) is never computed. - # evaluate_kernel() converts to the linear operator type needed by prediction. + # NOTE: We do not call ``.evaluate_kernel()`` even for test covariances. Keeping these covariances lazy allows + # downstream code to compute only what's needed (e.g., just the diagonal for variance). Prediction strategies + # should call ``.evaluate_kernel()`` themselves if needed. # NOTE: We must slice row and column indices together (not sequentially) for # compatibility with BlockInterleavedLinearOperator used in multitask GPs. - test_test_covar = joint_covar[..., num_train:, num_train:].evaluate_kernel() - test_train_covar = joint_covar[..., num_train:, :num_train].evaluate_kernel() + test_test_covar = joint_covar[..., num_train:, num_train:] + test_train_covar = joint_covar[..., num_train:, :num_train] posterior_class = full_output.__class__ return (test_mean, test_test_covar, test_train_covar, batch_shape, test_shape, posterior_class) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 73c83bc20..cfdc1cb24 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -765,6 +765,13 @@ def exact_prediction( ``(*batch_shape, num_test, num_tasks)`` - ``predictive_covar``: LinearOperator with same shape as ``test_test_covar`` """ + # Evaluate kernels to get concrete InterpolatedLinearOperator types + # needed for accessing interpolation indices/values. + if hasattr(test_test_covar, "evaluate_kernel"): + test_test_covar = test_test_covar.evaluate_kernel() + if hasattr(test_train_covar, "evaluate_kernel"): + test_train_covar = test_train_covar.evaluate_kernel() + return ( self.exact_predictive_mean(test_mean, test_train_covar), self.exact_predictive_covar(test_test_covar, test_train_covar), @@ -1068,6 +1075,11 @@ def exact_prediction( **test_test_covar.params, ) + # Evaluate test_train_covar to get concrete type for isinstance checks + # in exact_predictive_covar (MatmulLinearOperator, etc.) + if hasattr(test_train_covar, "evaluate_kernel"): + test_train_covar = test_train_covar.evaluate_kernel() + return ( self.exact_predictive_mean(test_mean, test_train_covar), self.exact_predictive_covar(test_test_covar, test_train_covar), From fe45da60f5fc732f17283104e74f7bac981b6b96 Mon Sep 17 00:00:00 2001 From: Kaiwen Wu Date: Thu, 28 May 2026 20:11:26 -0400 Subject: [PATCH 2/2] remove `test_test_covar.evaluate_kernel` --- gpytorch/models/exact_prediction_strategies.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index cfdc1cb24..52e19dbe4 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -765,10 +765,9 @@ def exact_prediction( ``(*batch_shape, num_test, num_tasks)`` - ``predictive_covar``: LinearOperator with same shape as ``test_test_covar`` """ - # Evaluate kernels to get concrete InterpolatedLinearOperator types - # needed for accessing interpolation indices/values. - if hasattr(test_test_covar, "evaluate_kernel"): - test_test_covar = test_test_covar.evaluate_kernel() + # Only ``test_train_covar`` must be a concrete ``InterpolatedLinearOperator`` so that we can access its + # interpolation indices and values. ``test_test_covar`` is only ever used in an addition and thus it is left + # lazy to avoid eagerly materializing the test-test covariance. if hasattr(test_train_covar, "evaluate_kernel"): test_train_covar = test_train_covar.evaluate_kernel()