1010import torch
1111from torch import Tensor
1212
13+ from gpytorch .distributions import Distribution
14+
1315from .. import settings
1416from ..distributions import MultitaskMultivariateNormal , MultivariateNormal
1517from ..likelihoods import _GaussianLikelihoodBase
@@ -300,7 +302,7 @@ def __call__(self, *args, **kwargs):
300302
301303 # Get the terms that only depend on training data
302304 if self .prediction_strategy is None :
303- train_output = super (). __call__ ( * train_inputs , ** kwargs )
305+ train_output = self . _get_train_prior_distribution ( train_inputs , ** kwargs )
304306
305307 # Create the prediction strategy for
306308 self .prediction_strategy = prediction_strategy (
@@ -309,41 +311,119 @@ def __call__(self, *args, **kwargs):
309311 train_labels = self .train_targets ,
310312 likelihood = self .likelihood ,
311313 )
312-
313- # Concatenate the input to the training input
314- full_inputs = []
315- batch_shape = train_inputs [0 ].shape [:- 2 ]
316- for train_input , input in length_safe_zip (train_inputs , inputs ):
317- # Make sure the batch shapes agree for training/test data
318- if batch_shape != train_input .shape [:- 2 ]:
319- batch_shape = torch .broadcast_shapes (batch_shape , train_input .shape [:- 2 ])
320- train_input = train_input .expand (* batch_shape , * train_input .shape [- 2 :])
321- if batch_shape != input .shape [:- 2 ]:
322- batch_shape = torch .broadcast_shapes (batch_shape , input .shape [:- 2 ])
323- train_input = train_input .expand (* batch_shape , * train_input .shape [- 2 :])
324- input = input .expand (* batch_shape , * input .shape [- 2 :])
325- full_inputs .append (torch .cat ([train_input , input ], dim = - 2 ))
326-
327- # Get the joint distribution for training/test data
328- full_output = super ().__call__ (* full_inputs , ** kwargs )
329- if settings .debug .on ():
330- if not isinstance (full_output , MultivariateNormal ):
331- raise RuntimeError ("ExactGP.forward must return a MultivariateNormal" )
332- full_mean , full_covar = full_output .loc , full_output .lazy_covariance_matrix
333-
334- # Determine the shape of the joint distribution
335- batch_shape = full_output .batch_shape
336- joint_shape = full_output .event_shape
337- tasks_shape = joint_shape [1 :] # For multitask learning
338- test_shape = torch .Size ([joint_shape [0 ] - self .prediction_strategy .train_shape [0 ], * tasks_shape ])
339-
314+ (
315+ test_mean ,
316+ test_test_covar ,
317+ test_train_covar ,
318+ batch_shape ,
319+ test_shape ,
320+ posterior_class ,
321+ ) = self ._get_test_prior_mean_and_covariances (train_inputs = train_inputs , test_inputs = inputs , ** kwargs )
340322 # Make the prediction
341323 with settings .cg_tolerance (settings .eval_cg_tolerance .value ()):
342- (
343- predictive_mean ,
344- predictive_covar ,
345- ) = self .prediction_strategy .exact_prediction (full_mean , full_covar )
324+ predictive_mean , predictive_covar = self .prediction_strategy .exact_prediction (
325+ test_mean = test_mean ,
326+ test_test_covar = test_test_covar ,
327+ test_train_covar = test_train_covar ,
328+ )
346329
347330 # Reshape predictive mean to match the appropriate event shape
348331 predictive_mean = predictive_mean .view (* batch_shape , * test_shape ).contiguous ()
349- return full_output .__class__ (predictive_mean , predictive_covar )
332+ return posterior_class (predictive_mean , predictive_covar )
333+
def _get_train_prior_distribution(
    self,
    train_inputs: Iterable[Tensor],
    **kwargs,
) -> MultivariateNormal:
    """Evaluate the model prior on the training inputs.

    This is an extension hook: subclasses can override it to change how the
    train-train covariance is produced.

    Args:
        train_inputs: The training-set inputs; they are unpacked positionally
            into the parent class's ``__call__``.
        kwargs: Extra keyword arguments forwarded unchanged to the parent
            class's ``__call__`` (and from there to the model's forward).

    Returns:
        The prior distribution evaluated on the training set.
    """
    # The call is routed through the parent class's __call__ on purpose: it
    # skips ExactGP.__call__ and reaches Module.__call__() -> forward(), which
    # yields the prior, so no separate "prior mode" context is required here.
    train_prior = super().__call__(*train_inputs, **kwargs)
    return train_prior
353+
def _get_test_prior_mean_and_covariances(
    self,
    train_inputs: Iterable[Tensor],
    test_inputs: Iterable[Tensor],
    **kwargs,
) -> tuple[Tensor, Tensor, Tensor, torch.Size, torch.Size, type[Distribution]]:
    """Compute the prior mean and covariances on the test set.

    Override this method to customize test-set covariance computations, e.g.,
    for models with partial observations or per-component additive inference.

    The returned covariances may have additional leading batch dimensions
    (e.g., for additive component-wise inference). The prediction strategy
    handles broadcasting with the train-train covariance.

    Note: This method is intended to stay efficient even when ``test_inputs``
    overlaps with ``train_inputs``. Slicing the lazy joint covariance only
    evaluates K(test, [train||test]); K(train, train) is never computed.

    ``self.prediction_strategy`` must already be set when this is called: its
    ``train_shape`` and ``num_train`` attributes are read to split the joint
    distribution into its train and test parts.

    Args:
        train_inputs: The training inputs. Indexed with ``[0]`` to seed the
            batch shape, so it must support indexing as well as iteration.
        test_inputs: The test inputs, paired one-to-one with ``train_inputs``.
        kwargs: Additional keyword arguments passed to the model's forward.

    Returns:
        A tuple of (test_mean, test_test_covar, test_train_covar, batch_shape,
        test_shape, posterior_class). The two covariances are the results of
        ``evaluate_kernel()`` on slices of the joint covariance, so they may be
        linear-operator objects rather than dense tensors (NOTE(review): the
        ``Tensor`` annotation is kept for interface stability; confirm).

    Raises:
        RuntimeError: If debug mode is on and the model's forward did not
            return a ``MultivariateNormal``.
    """
    # Concatenate each test input onto its matching training input
    full_inputs = []
    batch_shape = train_inputs[0].shape[:-2]
    for train_input, test_input in length_safe_zip(train_inputs, test_inputs):
        # Make sure the batch shapes agree for training/test data
        if batch_shape != train_input.shape[:-2]:
            batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])
            train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
        if batch_shape != test_input.shape[:-2]:
            batch_shape = torch.broadcast_shapes(batch_shape, test_input.shape[:-2])
            train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
            test_input = test_input.expand(*batch_shape, *test_input.shape[-2:])
        full_inputs.append(torch.cat([train_input, test_input], dim=-2))

    # Get joint distribution (lazy when settings.lazily_evaluate_kernels is True)
    full_output = super().__call__(*full_inputs, **kwargs)
    # Query the flag on the class itself; instantiating ``settings.debug()``
    # would build a context-manager object just to read the same state.
    if settings.debug.on():
        if not isinstance(full_output, MultivariateNormal):
            raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
    joint_mean, joint_covar = full_output.loc, full_output.lazy_covariance_matrix

    # Determine the shape of the joint distribution. From here on the batch
    # shape reported by the joint distribution supersedes the one used above
    # for broadcasting the inputs.
    batch_shape = full_output.batch_shape
    joint_shape = full_output.event_shape
    # For single-task GPs: event_shape = (num_points,), so tasks_shape = ()
    # For multitask GPs: event_shape = (num_points, num_tasks), so tasks_shape = (num_tasks,)
    # This captures any task dimensions beyond the primary data dimension.
    tasks_shape = joint_shape[1:]

    # Compute test_shape: the event shape for test predictions.
    # For single-task GPs: test_shape = (num_test,)
    # For multitask GPs: test_shape = (num_test, num_tasks)
    num_test = joint_shape[0] - self.prediction_strategy.train_shape[0]
    test_shape = torch.Size([num_test, *tasks_shape])

    # Find the components of the distribution that contain test data
    num_train = self.prediction_strategy.num_train
    test_mean = joint_mean[..., num_train:]

    # Extract test covariances. Slicing is lazy; K(train, train) is never computed.
    # evaluate_kernel() converts to the linear operator type needed by prediction.
    # NOTE: We must slice row and column indices together (not sequentially) for
    # compatibility with BlockInterleavedLinearOperator used in multitask GPs.
    test_test_covar = joint_covar[..., num_train:, num_train:].evaluate_kernel()
    test_train_covar = joint_covar[..., num_train:, :num_train].evaluate_kernel()

    posterior_class = full_output.__class__
    return (test_mean, test_test_covar, test_train_covar, batch_shape, test_shape, posterior_class)
0 commit comments