Skip to content

Commit c4ee5f5

Browse files
Separate covariance computations for customizable posterior inference
1 parent 38d7713 commit c4ee5f5

5 files changed

Lines changed: 418 additions & 114 deletions

File tree

gpytorch/kernels/rff_kernel.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#!/usr/bin/env python3
22

3-
from __future__ import annotations
4-
53
import math
64

75
import torch
@@ -120,22 +118,24 @@ def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch:
120118
if not hasattr(self, "randn_weights"):
121119
self._init_weights(num_dims, self.num_samples)
122120
x1_eq_x2 = torch.equal(x1, x2)
123-
z1 = self._featurize(x1, normalize=False)
121+
# Always use normalized features (scaled by 1/sqrt(D)) to ensure consistent
122+
# feature matrices regardless of whether x1 == x2 or not. This is important
123+
# for LinearPredictionStrategy, which extracts features from the LinearOperator.
124+
z1 = self._featurize(x1, normalize=True)
124125
if not x1_eq_x2:
125-
z2 = self._featurize(x2, normalize=False)
126+
z2 = self._featurize(x2, normalize=True)
126127
else:
127128
z2 = z1
128-
D = float(self.num_samples)
129129
if diag:
130-
return (z1 * z2).sum(-1) / D
130+
return (z1 * z2).sum(-1)
131131
if x1_eq_x2:
132132
# Exploit low rank structure, if there are fewer features than data points
133133
if z1.size(-1) < z2.size(-2):
134-
return LowRankRootLinearOperator(z1 / math.sqrt(D))
134+
return LowRankRootLinearOperator(z1)
135135
else:
136-
return RootLinearOperator(z1 / math.sqrt(D))
136+
return RootLinearOperator(z1)
137137
else:
138-
return MatmulLinearOperator(z1 / D, z2.transpose(-1, -2))
138+
return MatmulLinearOperator(z1, z2.transpose(-1, -2))
139139

140140
def _featurize(self, x: Tensor, normalize: bool = False) -> Tensor:
141141
# Recompute division each time to allow backprop through lengthscale

gpytorch/models/exact_gp.py

Lines changed: 113 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88
from copy import deepcopy
99

1010
import torch
11+
from linear_operator.operators import LinearOperator
1112
from torch import Tensor
1213

14+
from gpytorch.distributions import Distribution
15+
1316
from .. import settings
1417
from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
1518
from ..likelihoods import _GaussianLikelihoodBase
@@ -300,7 +303,7 @@ def __call__(self, *args, **kwargs):
300303

301304
# Get the terms that only depend on training data
302305
if self.prediction_strategy is None:
303-
train_output = super().__call__(*train_inputs, **kwargs)
306+
train_output = self._get_train_prior_distribution(train_inputs, **kwargs)
304307

305308
# Create the prediction strategy for
306309
self.prediction_strategy = prediction_strategy(
@@ -309,41 +312,117 @@ def __call__(self, *args, **kwargs):
309312
train_labels=self.train_targets,
310313
likelihood=self.likelihood,
311314
)
312-
313-
# Concatenate the input to the training input
314-
full_inputs = []
315-
batch_shape = train_inputs[0].shape[:-2]
316-
for train_input, input in length_safe_zip(train_inputs, inputs):
317-
# Make sure the batch shapes agree for training/test data
318-
if batch_shape != train_input.shape[:-2]:
319-
batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])
320-
train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
321-
if batch_shape != input.shape[:-2]:
322-
batch_shape = torch.broadcast_shapes(batch_shape, input.shape[:-2])
323-
train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
324-
input = input.expand(*batch_shape, *input.shape[-2:])
325-
full_inputs.append(torch.cat([train_input, input], dim=-2))
326-
327-
# Get the joint distribution for training/test data
328-
full_output = super().__call__(*full_inputs, **kwargs)
329-
if settings.debug.on():
330-
if not isinstance(full_output, MultivariateNormal):
331-
raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
332-
full_mean, full_covar = full_output.loc, full_output.lazy_covariance_matrix
333-
334-
# Determine the shape of the joint distribution
335-
batch_shape = full_output.batch_shape
336-
joint_shape = full_output.event_shape
337-
tasks_shape = joint_shape[1:] # For multitask learning
338-
test_shape = torch.Size([joint_shape[0] - self.prediction_strategy.train_shape[0], *tasks_shape])
339-
315+
(
316+
test_mean,
317+
test_test_covar,
318+
test_train_covar,
319+
batch_shape,
320+
test_shape,
321+
posterior_class,
322+
) = self._get_test_prior_mean_and_covariances(train_inputs=train_inputs, test_inputs=inputs, **kwargs)
340323
# Make the prediction
341324
with settings.cg_tolerance(settings.eval_cg_tolerance.value()):
342-
(
343-
predictive_mean,
344-
predictive_covar,
345-
) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
325+
predictive_mean, predictive_covar = self.prediction_strategy.exact_prediction(
326+
test_mean=test_mean,
327+
test_test_covar=test_test_covar,
328+
test_train_covar=test_train_covar,
329+
)
346330

347331
# Reshape predictive mean to match the appropriate event shape
348332
predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()
349-
return full_output.__class__(predictive_mean, predictive_covar)
333+
return posterior_class(predictive_mean, predictive_covar)
334+
335+
def _get_train_prior_distribution(
336+
self,
337+
train_inputs: Iterable[Tensor],
338+
**kwargs,
339+
) -> MultivariateNormal:
340+
"""Computes the prior distribution on the training set.
341+
342+
Override this method to customize train-train covariance computation.
343+
344+
Args:
345+
train_inputs: The inputs in the training set.
346+
kwargs: Additional keyword arguments passed to the model's forward method.
347+
348+
Returns:
349+
The prior distribution evaluated on the training set.
350+
"""
351+
return super().__call__(*train_inputs, **kwargs)
352+
353+
def _get_test_prior_mean_and_covariances(
354+
self,
355+
train_inputs: Iterable[Tensor | LinearOperator],
356+
test_inputs: Iterable[Tensor | LinearOperator],
357+
**kwargs,
358+
) -> tuple[Tensor, Tensor, Tensor, torch.Size, torch.Size, type[Distribution]]:
359+
"""Computes the prior mean and covariances on the test set.
360+
361+
Override this method to customize test-set covariance computations, e.g.,
362+
for models with partial observations or per-component additive inference.
363+
364+
The returned covariances may have additional leading batch dimensions
365+
(e.g., for additive component-wise inference). The prediction strategy
366+
handles broadcasting with the train-train covariance.
367+
368+
Note: This method is efficient even when test_inputs overlaps with
369+
train_inputs. Slicing the lazy joint covariance only evaluates
370+
K(test, [train||test]); K(train, train) is never computed.
371+
372+
Args:
373+
train_inputs: The training inputs.
374+
test_inputs: The test inputs.
375+
kwargs: Additional keyword arguments passed to the model's forward.
376+
377+
Returns:
378+
A tuple of (test_mean, test_test_covar, test_train_covar, batch_shape,
379+
test_shape, posterior_class).
380+
"""
381+
# Concatenate the input to the training input
382+
full_inputs = []
383+
batch_shape = train_inputs[0].shape[:-2]
384+
for train_input, input in length_safe_zip(train_inputs, test_inputs):
385+
# Make sure the batch shapes agree for training/test data
386+
if batch_shape != train_input.shape[:-2]:
387+
batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])
388+
train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
389+
if batch_shape != input.shape[:-2]:
390+
batch_shape = torch.broadcast_shapes(batch_shape, input.shape[:-2])
391+
train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
392+
input = input.expand(*batch_shape, *input.shape[-2:])
393+
full_inputs.append(torch.cat([train_input, input], dim=-2))
394+
395+
# Get joint distribution (lazy when settings.lazily_evaluate_kernels is True)
396+
full_output = super().__call__(*full_inputs, **kwargs)
397+
if settings.debug().on():
398+
if not isinstance(full_output, MultivariateNormal):
399+
raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
400+
joint_mean, joint_covar = full_output.loc, full_output.lazy_covariance_matrix
401+
402+
# Determine the shape of the joint distribution
403+
batch_shape = full_output.batch_shape
404+
joint_shape = full_output.event_shape
405+
# For single-task GPs: event_shape = (num_points,), so tasks_shape = ()
406+
# For multitask GPs: event_shape = (num_points, num_tasks), so tasks_shape = (num_tasks,)
407+
# This captures any task dimensions beyond the primary data dimension.
408+
tasks_shape = joint_shape[1:]
409+
410+
# Compute test_shape: the event shape for test predictions.
411+
# For single-task GPs: test_shape = (num_test,)
412+
# For multitask GPs: test_shape = (num_test, num_tasks)
413+
num_test = joint_shape[0] - self.prediction_strategy.train_shape[0]
414+
test_shape = torch.Size([num_test, *tasks_shape])
415+
416+
# Find the components of the distribution that contain test data
417+
num_train = self.prediction_strategy.num_train
418+
test_mean = joint_mean[..., num_train:]
419+
420+
# Extract test covariances. Slicing is lazy; K(train, train) is never computed.
421+
# evaluate_kernel() converts to the linear operator type needed by prediction.
422+
# NOTE: We must slice row and column indices together (not sequentially) for
423+
# compatibility with BlockInterleavedLinearOperator used in multitask GPs.
424+
test_test_covar = joint_covar[..., num_train:, num_train:].evaluate_kernel()
425+
test_train_covar = joint_covar[..., num_train:, :num_train].evaluate_kernel()
426+
427+
posterior_class = full_output.__class__
428+
return (test_mean, test_test_covar, test_train_covar, batch_shape, test_shape, posterior_class)

0 commit comments

Comments
 (0)