88from copy import deepcopy
99
1010import torch
11+ from linear_operator .operators import LinearOperator
1112from torch import Tensor
1213
14+ from gpytorch .distributions import Distribution
15+
1316from .. import settings
1417from ..distributions import MultitaskMultivariateNormal , MultivariateNormal
1518from ..likelihoods import _GaussianLikelihoodBase
@@ -300,7 +303,7 @@ def __call__(self, *args, **kwargs):
300303
301304 # Get the terms that only depend on training data
302305 if self .prediction_strategy is None :
303- train_output = super (). __call__ ( * train_inputs , ** kwargs )
306+ train_output = self . _get_train_prior_distribution ( train_inputs , ** kwargs )
304307
305308 # Create the prediction strategy for
306309 self .prediction_strategy = prediction_strategy (
@@ -309,41 +312,110 @@ def __call__(self, *args, **kwargs):
309312 train_labels = self .train_targets ,
310313 likelihood = self .likelihood ,
311314 )
312-
313- # Concatenate the input to the training input
314- full_inputs = []
315- batch_shape = train_inputs [0 ].shape [:- 2 ]
316- for train_input , input in length_safe_zip (train_inputs , inputs ):
317- # Make sure the batch shapes agree for training/test data
318- if batch_shape != train_input .shape [:- 2 ]:
319- batch_shape = torch .broadcast_shapes (batch_shape , train_input .shape [:- 2 ])
320- train_input = train_input .expand (* batch_shape , * train_input .shape [- 2 :])
321- if batch_shape != input .shape [:- 2 ]:
322- batch_shape = torch .broadcast_shapes (batch_shape , input .shape [:- 2 ])
323- train_input = train_input .expand (* batch_shape , * train_input .shape [- 2 :])
324- input = input .expand (* batch_shape , * input .shape [- 2 :])
325- full_inputs .append (torch .cat ([train_input , input ], dim = - 2 ))
326-
327- # Get the joint distribution for training/test data
328- full_output = super ().__call__ (* full_inputs , ** kwargs )
329- if settings .debug .on ():
330- if not isinstance (full_output , MultivariateNormal ):
331- raise RuntimeError ("ExactGP.forward must return a MultivariateNormal" )
332- full_mean , full_covar = full_output .loc , full_output .lazy_covariance_matrix
333-
334- # Determine the shape of the joint distribution
335- batch_shape = full_output .batch_shape
336- joint_shape = full_output .event_shape
337- tasks_shape = joint_shape [1 :] # For multitask learning
338- test_shape = torch .Size ([joint_shape [0 ] - self .prediction_strategy .train_shape [0 ], * tasks_shape ])
339-
315+ (
316+ test_mean ,
317+ test_test_covar ,
318+ test_train_covar ,
319+ batch_shape ,
320+ test_shape ,
321+ posterior_class ,
322+ ) = self ._get_test_prior_mean_and_covariances (train_inputs = train_inputs , test_inputs = inputs , ** kwargs )
340323 # Make the prediction
341324 with settings .cg_tolerance (settings .eval_cg_tolerance .value ()):
342- (
343- predictive_mean ,
344- predictive_covar ,
345- ) = self .prediction_strategy .exact_prediction (full_mean , full_covar )
325+ (predictive_mean , predictive_covar ,) = self .prediction_strategy .exact_prediction (
326+ test_mean = test_mean ,
327+ test_test_covar = test_test_covar ,
328+ test_train_covar = test_train_covar ,
329+ )
346330
347331 # Reshape predictive mean to match the appropriate event shape
348332 predictive_mean = predictive_mean .view (* batch_shape , * test_shape ).contiguous ()
349- return full_output .__class__ (predictive_mean , predictive_covar )
333+ return posterior_class (predictive_mean , predictive_covar )
334+
def _get_train_prior_distribution(
    self,
    train_inputs: Iterable[Tensor],
    **kwargs,
) -> MultivariateNormal:
    """Evaluate the model prior at the training inputs.

    This is a thin hook around the parent class's ``__call__``: the training
    inputs are unpacked positionally and ``kwargs`` are forwarded unchanged.
    Subclasses can override it to customize how the train-train covariance
    is computed.

    Args:
        train_inputs: The inputs in the training set; each element is passed
            as a separate positional argument.
        kwargs: Additional keyword arguments forwarded to the parent
            ``__call__`` (and from there to the model's forward method).

    Returns:
        The prior distribution evaluated on the training set.
    """
    train_prior = super().__call__(*train_inputs, **kwargs)
    return train_prior
352+
def _get_test_prior_mean_and_covariances(
    self,
    train_inputs: Iterable[Tensor | LinearOperator],
    test_inputs: Iterable[Tensor | LinearOperator],
    **kwargs,
) -> tuple[
    Tensor,
    Tensor | LinearOperator,
    Tensor | LinearOperator,
    torch.Size,
    torch.Size,
    type[Distribution],
]:
    """Compute the prior mean and covariances needed to predict at the test set.

    Each test input is broadcast against and concatenated onto its matching
    training input along dim ``-2``; the parent ``__call__`` is evaluated on
    the concatenated inputs, and the test-related pieces are sliced out of
    the resulting joint distribution.

    Override this method to customize test-set covariance computations, e.g.,
    for models with partial observations or per-component additive inference.

    The returned covariances may have additional leading batch dimensions
    (e.g., for additive component-wise inference). The prediction strategy
    handles broadcasting with the train-train covariance.

    Note: This method is efficient even when test_inputs overlaps with
    train_inputs. Slicing the lazy joint covariance only evaluates
    K(test, [train||test]); K(train, train) is never computed.

    Note: ``self.prediction_strategy`` must already be set, since its
    ``train_shape`` and ``num_train`` attributes are read here.

    Args:
        train_inputs: The training inputs. Must support indexing, because the
            first element is used to seed the batch shape.
        test_inputs: The test inputs, one per training input.
        kwargs: Additional keyword arguments passed to the model's forward.

    Returns:
        A tuple of (test_mean, test_test_covar, test_train_covar, batch_shape,
        test_shape, posterior_class).

    Raises:
        RuntimeError: If debug mode is on and the joint output is not a
            ``MultivariateNormal``.
    """
    # Concatenate each test input onto its matching training input
    full_inputs = []
    batch_shape = train_inputs[0].shape[:-2]
    for train_input, test_input in length_safe_zip(train_inputs, test_inputs):
        # Make sure the batch shapes agree for training/test data
        if batch_shape != train_input.shape[:-2]:
            batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])
            train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
        if batch_shape != test_input.shape[:-2]:
            batch_shape = torch.broadcast_shapes(batch_shape, test_input.shape[:-2])
            train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
            test_input = test_input.expand(*batch_shape, *test_input.shape[-2:])
        full_inputs.append(torch.cat([train_input, test_input], dim=-2))

    # Get joint distribution (lazy when settings.lazily_evaluate_kernels is True)
    full_output = super().__call__(*full_inputs, **kwargs)
    # Query the flag on the class, as the rest of the file does, instead of
    # constructing a throwaway flag instance just to read it.
    if settings.debug.on():
        if not isinstance(full_output, MultivariateNormal):
            raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
    joint_mean, joint_covar = full_output.loc, full_output.lazy_covariance_matrix

    # Determine the shape of the joint distribution. The batch shape of the
    # joint output supersedes the one accumulated while broadcasting inputs.
    batch_shape = full_output.batch_shape
    joint_shape = full_output.event_shape
    tasks_shape = joint_shape[1:]  # For multitask learning

    test_shape = torch.Size([joint_shape[0] - self.prediction_strategy.train_shape[0], *tasks_shape])

    # Find the components of the distribution that contain test data
    num_train = self.prediction_strategy.num_train
    test_mean = joint_mean[..., num_train:]

    # Extract test covariances. Slicing is lazy; K(train, train) is never computed.
    # evaluate_kernel() converts to the linear operator type needed by prediction.
    # NOTE: We must slice row and column indices together (not sequentially) for
    # compatibility with BlockInterleavedLinearOperator used in multitask GPs.
    test_test_covar = joint_covar[..., num_train:, num_train:].evaluate_kernel()
    test_train_covar = joint_covar[..., num_train:, :num_train].evaluate_kernel()

    posterior_class = full_output.__class__
    return (test_mean, test_test_covar, test_train_covar, batch_shape, test_shape, posterior_class)
0 commit comments