Skip to content

Commit c1c1a7b

Browse files
Separate covariance computations for customizable posterior inference
1 parent 685286d commit c1c1a7b

5 files changed

Lines changed: 420 additions & 114 deletions

File tree

gpytorch/kernels/rff_kernel.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#!/usr/bin/env python3
22

3-
from __future__ import annotations
4-
53
import math
64

75
import torch
@@ -120,22 +118,24 @@ def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch:
120118
if not hasattr(self, "randn_weights"):
121119
self._init_weights(num_dims, self.num_samples)
122120
x1_eq_x2 = torch.equal(x1, x2)
123-
z1 = self._featurize(x1, normalize=False)
121+
# Always use normalized features (scaled by 1/sqrt(D)) to ensure consistent
122+
# feature matrices regardless of whether x1 == x2 or not. This is important
123+
# for LinearPredictionStrategy, which extracts features from the LinearOperator.
124+
z1 = self._featurize(x1, normalize=True)
124125
if not x1_eq_x2:
125-
z2 = self._featurize(x2, normalize=False)
126+
z2 = self._featurize(x2, normalize=True)
126127
else:
127128
z2 = z1
128-
D = float(self.num_samples)
129129
if diag:
130-
return (z1 * z2).sum(-1) / D
130+
return (z1 * z2).sum(-1)
131131
if x1_eq_x2:
132132
# Exploit low rank structure, if there are fewer features than data points
133133
if z1.size(-1) < z2.size(-2):
134-
return LowRankRootLinearOperator(z1 / math.sqrt(D))
134+
return LowRankRootLinearOperator(z1)
135135
else:
136-
return RootLinearOperator(z1 / math.sqrt(D))
136+
return RootLinearOperator(z1)
137137
else:
138-
return MatmulLinearOperator(z1 / D, z2.transpose(-1, -2))
138+
return MatmulLinearOperator(z1, z2.transpose(-1, -2))
139139

140140
def _featurize(self, x: Tensor, normalize: bool = False) -> Tensor:
141141
# Recompute division each time to allow backprop through lengthscale

gpytorch/models/exact_gp.py

Lines changed: 114 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import torch
1111
from torch import Tensor
1212

13+
from gpytorch.distributions import Distribution
14+
1315
from .. import settings
1416
from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
1517
from ..likelihoods import _GaussianLikelihoodBase
@@ -300,7 +302,7 @@ def __call__(self, *args, **kwargs):
300302

301303
# Get the terms that only depend on training data
302304
if self.prediction_strategy is None:
303-
train_output = super().__call__(*train_inputs, **kwargs)
305+
train_output = self._get_train_prior_distribution(train_inputs, **kwargs)
304306

305307
# Create the prediction strategy for
306308
self.prediction_strategy = prediction_strategy(
@@ -309,41 +311,119 @@ def __call__(self, *args, **kwargs):
309311
train_labels=self.train_targets,
310312
likelihood=self.likelihood,
311313
)
312-
313-
# Concatenate the input to the training input
314-
full_inputs = []
315-
batch_shape = train_inputs[0].shape[:-2]
316-
for train_input, input in length_safe_zip(train_inputs, inputs):
317-
# Make sure the batch shapes agree for training/test data
318-
if batch_shape != train_input.shape[:-2]:
319-
batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])
320-
train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
321-
if batch_shape != input.shape[:-2]:
322-
batch_shape = torch.broadcast_shapes(batch_shape, input.shape[:-2])
323-
train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
324-
input = input.expand(*batch_shape, *input.shape[-2:])
325-
full_inputs.append(torch.cat([train_input, input], dim=-2))
326-
327-
# Get the joint distribution for training/test data
328-
full_output = super().__call__(*full_inputs, **kwargs)
329-
if settings.debug.on():
330-
if not isinstance(full_output, MultivariateNormal):
331-
raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
332-
full_mean, full_covar = full_output.loc, full_output.lazy_covariance_matrix
333-
334-
# Determine the shape of the joint distribution
335-
batch_shape = full_output.batch_shape
336-
joint_shape = full_output.event_shape
337-
tasks_shape = joint_shape[1:] # For multitask learning
338-
test_shape = torch.Size([joint_shape[0] - self.prediction_strategy.train_shape[0], *tasks_shape])
339-
314+
(
315+
test_mean,
316+
test_test_covar,
317+
test_train_covar,
318+
batch_shape,
319+
test_shape,
320+
posterior_class,
321+
) = self._get_test_prior_mean_and_covariances(train_inputs=train_inputs, test_inputs=inputs, **kwargs)
340322
# Make the prediction
341323
with settings.cg_tolerance(settings.eval_cg_tolerance.value()):
342-
(
343-
predictive_mean,
344-
predictive_covar,
345-
) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
324+
predictive_mean, predictive_covar = self.prediction_strategy.exact_prediction(
325+
test_mean=test_mean,
326+
test_test_covar=test_test_covar,
327+
test_train_covar=test_train_covar,
328+
)
346329

347330
# Reshape predictive mean to match the appropriate event shape
348331
predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()
349-
return full_output.__class__(predictive_mean, predictive_covar)
332+
return posterior_class(predictive_mean, predictive_covar)
333+
334+
def _get_train_prior_distribution(self, train_inputs: Iterable[Tensor], **kwargs) -> MultivariateNormal:
    """Evaluate the model prior on the training inputs.

    This is the hook used to obtain the train-train prior when the prediction
    strategy is first built; subclasses may override it to customize how the
    train-train covariance is computed.

    Args:
        train_inputs: The training inputs, unpacked positionally into the
            parent ``__call__``.
        kwargs: Extra keyword arguments forwarded unchanged to the parent
            ``__call__`` (and from there to the model's forward method).

    Returns:
        The prior distribution evaluated on the training set.
    """
    # Calling the parent __call__ skips ExactGP.__call__ and dispatches through
    # Module.__call__() to forward(), so the result is already the prior and no
    # prior_mode context is required here.
    train_prior = super().__call__(*train_inputs, **kwargs)
    return train_prior
353+
354+
def _get_test_prior_mean_and_covariances(
    self,
    train_inputs: Iterable[Tensor],
    test_inputs: Iterable[Tensor],
    **kwargs,
) -> tuple[Tensor, Tensor, Tensor, torch.Size, torch.Size, type[Distribution]]:
    """Computes the prior mean and covariances on the test set.

    Override this method to customize test-set covariance computations, e.g.,
    for models with partial observations or per-component additive inference.

    The returned covariances may have additional leading batch dimensions
    (e.g., for additive component-wise inference). The prediction strategy
    handles broadcasting with the train-train covariance.

    Note: This method is efficient even when test_inputs overlaps with
    train_inputs. Slicing the lazy joint covariance only evaluates
    K(test, [train||test]); K(train, train) is never computed.

    Requires ``self.prediction_strategy`` to be set already, since its
    ``train_shape`` and ``num_train`` are read to locate the test entries.

    Args:
        train_inputs: The training inputs. The first entry is indexed
            (``train_inputs[0]``), so an indexable sequence is expected.
        test_inputs: The test inputs; must pair up one-to-one with
            ``train_inputs`` (they are zipped with ``length_safe_zip``).
        kwargs: Additional keyword arguments passed to the model's forward.

    Returns:
        A tuple of (test_mean, test_test_covar, test_train_covar, batch_shape,
        test_shape, posterior_class). The two covariances are the results of
        ``evaluate_kernel()`` on slices of the joint lazy covariance;
        ``posterior_class`` is the class of the joint prior distribution.

    Raises:
        RuntimeError: If debug mode is on and the joint prior is not a
            ``MultivariateNormal``.
    """
    # Concatenate the input to the training input
    full_inputs = []
    batch_shape = train_inputs[0].shape[:-2]
    for train_input, test_input in length_safe_zip(train_inputs, test_inputs):
        # Make sure the batch shapes agree for training/test data
        if batch_shape != train_input.shape[:-2]:
            batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])
            train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
        if batch_shape != test_input.shape[:-2]:
            batch_shape = torch.broadcast_shapes(batch_shape, test_input.shape[:-2])
            train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
            test_input = test_input.expand(*batch_shape, *test_input.shape[-2:])
        full_inputs.append(torch.cat([train_input, test_input], dim=-2))

    # Get joint distribution (lazy when settings.lazily_evaluate_kernels is True)
    full_output = super().__call__(*full_inputs, **kwargs)
    # Query the flag on the class itself (as the pre-refactor code did) instead
    # of constructing a throwaway flag instance just to read its state.
    if settings.debug.on():
        if not isinstance(full_output, MultivariateNormal):
            raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
    joint_mean, joint_covar = full_output.loc, full_output.lazy_covariance_matrix

    # Determine the shape of the joint distribution. From here on batch_shape is
    # the one reported by the joint prior, not the one broadcast in the loop above.
    batch_shape = full_output.batch_shape
    joint_shape = full_output.event_shape
    # For single-task GPs: event_shape = (num_points,), so tasks_shape = ()
    # For multitask GPs: event_shape = (num_points, num_tasks), so tasks_shape = (num_tasks,)
    # This captures any task dimensions beyond the primary data dimension.
    tasks_shape = joint_shape[1:]

    # Compute test_shape: the event shape for test predictions.
    # For single-task GPs: test_shape = (num_test,)
    # For multitask GPs: test_shape = (num_test, num_tasks)
    num_test = joint_shape[0] - self.prediction_strategy.train_shape[0]
    test_shape = torch.Size([num_test, *tasks_shape])

    # Find the components of the distribution that contain test data.
    # NOTE(review): num_train is read from the strategy separately from
    # train_shape[0] above; presumably it counts flattened training entries
    # (points x tasks) -- confirm against the prediction strategy.
    num_train = self.prediction_strategy.num_train
    test_mean = joint_mean[..., num_train:]

    # Extract test covariances. Slicing is lazy; K(train, train) is never computed.
    # evaluate_kernel() converts to the linear operator type needed by prediction.
    # NOTE: We must slice row and column indices together (not sequentially) for
    # compatibility with BlockInterleavedLinearOperator used in multitask GPs.
    test_test_covar = joint_covar[..., num_train:, num_train:].evaluate_kernel()
    test_train_covar = joint_covar[..., num_train:, :num_train].evaluate_kernel()

    posterior_class = full_output.__class__
    return (test_mean, test_test_covar, test_train_covar, batch_shape, test_shape, posterior_class)

0 commit comments

Comments
 (0)