Skip to content

Commit 83f2c4c

Browse files
SebastianAmentmeta-codesync[bot]
authored andcommitted
Add untransform support to efficient LOO cross-validation (#3288)
Summary: Pull Request resolved: #3288 Previously, efficient_loo_cv, loo_cv, and ensemble_loo_cv always returned results in the model's internal (transformed) space. This was inconsistent with model.posterior() and batch_cross_validation, which automatically untransform predictions via outcome_transform.untransform_posterior(). This adds an untransform parameter (default True) to all three LOO CV functions, so results are returned in the original outcome space by default. When the model has an outcome transform (e.g. Standardize), the LOO posterior and observed values are mapped back to the original space. The untransform=False escape hatch preserves access to the internal space for users who need it. Safety analysis: - For Standardize (most common case), the untransform is an exact analytical rescaling of the mean and covariance — no approximation involved. - For Log transforms, the untransform uses the exact log-normal mean (exp(mu + sigma^2/2)) and variance formulas, not naive exp(mu). - The only production consumer is Ax (ax/adapter/cross_validation.py), which performs its own Ax-level untransform chain on top. Because Ax's Standardize and botorch's Standardize on already-standardized data were approximately a no-op, the old behavior was approximately correct. With this change, botorch now correctly maps from internal space to Ax model space before Ax applies its own transforms, making the pipeline more correct with negligible numerical differences. No double-untransform risk exists because botorch and Ax untransforms operate at different abstraction levels (tensors vs Observation objects). Reviewed By: esantorella Differential Revision: D102340841 fbshipit-source-id: b20aa332006eec88e87da4e7804237f93b7289a2
1 parent 5ea736d commit 83f2c4c

4 files changed

Lines changed: 672 additions & 65 deletions

File tree

botorch/cross_validation.py

Lines changed: 199 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@
1616
from botorch.exceptions.errors import UnsupportedError
1717
from botorch.fit import fit_gpytorch_mll
1818
from botorch.models.gpytorch import GPyTorchModel
19+
from botorch.models.likelihoods.sparse_outlier_noise import (
20+
SparseOutlierGaussianLikelihood,
21+
)
1922
from botorch.models.multitask import MultiTaskGP
2023
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
2124
from botorch.posteriors.gpytorch import GPyTorchPosterior
25+
from botorch.posteriors.posterior import Posterior
2226
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
2327
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
2428
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
@@ -51,10 +55,15 @@ class CVResults(NamedTuple):
5155
For ``efficient_loo_cv``, the posterior has the same shape structure to maintain
5256
consistency, though the underlying distribution is constructed from the
5357
efficient LOO formulas rather than from separate model fits.
58+
59+
NOTE: When ``untransform=True`` is used with a nonlinear outcome transform
60+
(e.g., ``Log``), the posterior will be a ``TransformedPosterior`` rather than
61+
a ``GPyTorchPosterior``. For ensemble models, it will be a
62+
``GaussianMixturePosterior``.
5463
"""
5564

5665
model: GPyTorchModel
57-
posterior: GPyTorchPosterior
66+
posterior: Posterior
5867
observed_Y: Tensor
5968
observed_Yvar: Tensor | None = None
6069

@@ -95,8 +104,7 @@ def gen_loo_cv_folds(
95104
>>> cv_folds.train_X.shape
96105
torch.Size([10, 9, 1])
97106
"""
98-
masks = torch.eye(train_X.shape[-2], dtype=torch.uint8, device=train_X.device)
99-
masks = masks.to(dtype=torch.bool)
107+
masks = torch.eye(train_X.shape[-2], dtype=torch.bool, device=train_X.device)
100108
if train_Y.dim() < train_X.dim():
101109
# add output dimension
102110
train_Y = train_Y.unsqueeze(-1)
@@ -223,7 +231,11 @@ def batch_cross_validation(
223231
)
224232

225233

226-
def loo_cv(model: GPyTorchModel, observation_noise: bool = True) -> CVResults:
234+
def loo_cv(
235+
model: GPyTorchModel,
236+
observation_noise: bool = True,
237+
untransform: bool = True,
238+
) -> CVResults:
227239
r"""Compute efficient Leave-One-Out cross-validation for a GP model.
228240
229241
This is a high-level convenience function that automatically dispatches to
@@ -243,6 +255,11 @@ def loo_cv(model: GPyTorchModel, observation_noise: bool = True) -> CVResults:
243255
LOO CV. For models where hyperparameter changes are significant, consider
244256
using ``batch_cross_validation`` instead.
245257
258+
NOTE: The ``untransform`` parameter defaults to True, which means results
259+
are returned in the original outcome space. Callers that previously relied
260+
on results in the model's internal (transformed) space should pass
261+
``untransform=False`` explicitly.
262+
246263
Args:
247264
model: A fitted GPyTorchModel. The model type determines which LOO CV
248265
implementation is used.
@@ -252,6 +269,12 @@ def loo_cv(model: GPyTorchModel, observation_noise: bool = True) -> CVResults:
252269
observation noise). The posterior variance is computed by
253270
subtracting the observation noise from the posterior predictive
254271
variance.
272+
untransform: If True (default), untransform the LOO predictions and
273+
observed values back to the original outcome space when the model
274+
has an outcome transform (e.g., ``Standardize``). This makes the
275+
results consistent with ``model.posterior()`` and
276+
``batch_cross_validation``. If False, return results in the
277+
model's internal (transformed) space.
255278
256279
Returns:
257280
CVResults: A named tuple containing:
@@ -283,15 +306,16 @@ def loo_cv(model: GPyTorchModel, observation_noise: bool = True) -> CVResults:
283306
- ``ensemble_loo_cv``: Direct access to the ensemble model implementation.
284307
- ``batch_cross_validation``: Full LOO CV with model refitting.
285308
"""
286-
if getattr(model, "_is_ensemble", False):
287-
return ensemble_loo_cv(model, observation_noise=observation_noise)
288-
else:
289-
return efficient_loo_cv(model, observation_noise=observation_noise)
309+
loo_fun = (
310+
ensemble_loo_cv if getattr(model, "_is_ensemble", False) else efficient_loo_cv
311+
)
312+
return loo_fun(model, observation_noise=observation_noise, untransform=untransform)
290313

291314

292315
def efficient_loo_cv(
293316
model: GPyTorchModel,
294317
observation_noise: bool = True,
318+
untransform: bool = True,
295319
) -> CVResults:
296320
r"""Compute efficient Leave-One-Out cross-validation for a GP model.
297321
@@ -332,12 +356,20 @@ def efficient_loo_cv(
332356
predictive variance (including observation noise). If False,
333357
return the posterior variance of the latent function (excluding
334358
observation noise).
359+
untransform: If True (default), untransform the LOO predictions and
360+
observed values back to the original outcome space when the model
361+
has an outcome transform (e.g., ``Standardize``). This makes the
362+
results consistent with ``model.posterior()`` and
363+
``batch_cross_validation``. If False, return results in the
364+
model's internal (transformed) space.
335365
336366
Returns:
337367
CVResults: A named tuple containing:
338368
- model: The fitted GP model.
339-
- posterior: A GPyTorchPosterior with the LOO predictive distributions.
340-
The posterior mean and variance have shape ``n x 1 x m`` or
369+
- posterior: The LOO predictive distributions (typically a
370+
``GPyTorchPosterior``; with nonlinear outcome transforms like
371+
``Log``, a ``TransformedPosterior``). The posterior mean and
372+
variance have shape ``n x 1 x m`` or
341373
``batch_shape x n x 1 x m``, matching the structure of
342374
``batch_cross_validation`` (n folds, 1 held-out point per fold,
343375
m outputs). The underlying distribution has diagonal covariance
@@ -385,6 +417,15 @@ def efficient_loo_cv(
385417
if isinstance(model.likelihood, FixedNoiseGaussianLikelihood):
386418
observed_Yvar = _reshape_to_loo_cv_format(model.likelihood.noise, num_outputs)
387419

420+
# Untransform predictions and observed values back to original space
421+
if untransform and hasattr(model, "outcome_transform"):
422+
posterior, observed_Y, observed_Yvar = _untransform_loo_results(
423+
model=model,
424+
posterior=posterior,
425+
observed_Y=observed_Y,
426+
observed_Yvar=observed_Yvar,
427+
)
428+
388429
return CVResults(
389430
model=model,
390431
posterior=posterior,
@@ -393,6 +434,78 @@ def efficient_loo_cv(
393434
)
394435

395436

437+
def _untransform_loo_results(
438+
model: GPyTorchModel,
439+
posterior: Posterior,
440+
observed_Y: Tensor,
441+
observed_Yvar: Tensor | None,
442+
) -> tuple[Posterior, Tensor, Tensor | None]:
443+
r"""Untransform LOO CV results from model-internal space to original space.
444+
445+
Applies the model's outcome transform to map the LOO posterior and observed
446+
values back to the original (untransformed) outcome space. This uses
447+
``outcome_transform.untransform_posterior`` for the posterior and
448+
``outcome_transform.untransform`` for the observed values.
449+
450+
For linear transforms like ``Standardize``, the posterior is analytically
451+
rescaled and remains a ``GPyTorchPosterior``. For nonlinear transforms like
452+
``Log``, the posterior is wrapped in a ``TransformedPosterior``.
453+
454+
Args:
455+
model: The GP model with an ``outcome_transform`` attribute.
456+
posterior: The LOO posterior in the model's internal (transformed) space.
457+
Shape: ``n x 1 x m`` or ``batch_shape x n x 1 x m``.
458+
observed_Y: The observed Y values in transformed space with shape
459+
``n x 1 x m`` or ``batch_shape x n x 1 x m``.
460+
observed_Yvar: The observed noise variances in transformed space (if
461+
applicable) with shape ``n x 1 x m`` or ``batch_shape x n x 1 x m``.
462+
463+
Returns:
464+
A tuple of (posterior, observed_Y, observed_Yvar) in the original
465+
(untransformed) outcome space. Note: if the outcome transform has a
466+
``batch_shape`` (e.g., ``Standardize(batch_shape=[num_models])``),
467+
broadcasting during untransform may add leading dimensions to
468+
``observed_Y`` that the caller must align with the posterior layout.
469+
"""
470+
outcome_transform = model.outcome_transform
471+
472+
# Validate observed_Y shape before any transforms.
473+
if observed_Y.shape[-2] != 1:
474+
raise ValueError(
475+
"Expected observed_Y to have size 1 at dim -2 (q dimension), "
476+
f"got shape {observed_Y.shape}."
477+
)
478+
479+
posterior = outcome_transform.untransform_posterior(posterior)
480+
481+
# Untransform observed_Y (and observed_Yvar if present).
482+
# observed_Y has shape n x 1 x m; Standardize.untransform expects
483+
# batch_shape x n x m. We squeeze the q=1 dim, untransform, and restore it.
484+
observed_Y_squeezed = observed_Y.squeeze(-2)
485+
observed_Yvar_squeezed = (
486+
observed_Yvar.squeeze(-2) if observed_Yvar is not None else None
487+
)
488+
observed_Y_utf, observed_Yvar_utf = outcome_transform.untransform(
489+
observed_Y_squeezed, observed_Yvar_squeezed
490+
)
491+
observed_Y = observed_Y_utf.unsqueeze(-2)
492+
observed_Yvar = (
493+
observed_Yvar_utf.unsqueeze(-2) if observed_Yvar_utf is not None else None
494+
)
495+
496+
return posterior, observed_Y, observed_Yvar
497+
498+
499+
def _likelihood_requires_X(likelihood: object) -> bool:
500+
r"""Check if a likelihood requires training inputs for noise computation.
501+
502+
Returns True for likelihoods like SparseOutlierGaussianLikelihood that
503+
need training inputs to determine per-point noise (e.g., outlier variances).
504+
Standard GaussianLikelihood does not need training inputs.
505+
"""
506+
return isinstance(likelihood, SparseOutlierGaussianLikelihood)
507+
508+
396509
def _subtract_observation_noise(model: GPyTorchModel, loo_variance: Tensor) -> Tensor:
397510
r"""Subtract observation noise from LOO variance to get posterior variance.
398511
@@ -426,16 +539,18 @@ def _subtract_observation_noise(model: GPyTorchModel, loo_variance: Tensor) -> T
426539
noise_shape, dtype=loo_variance.dtype, device=loo_variance.device
427540
)
428541

429-
# Some likelihoods (e.g., SparseOutlierGaussianLikelihood) require training
430-
# inputs to be passed to correctly compute the noise. We pass the model's
431-
# train_inputs if available.
432-
train_inputs = getattr(model, "train_inputs", None)
433-
434-
# Call forward to get the observation noise distribution.
435-
# We pass train_inputs as a positional argument so it flows through *params
436-
# to the noise model, which is compatible with both standard Noise classes
437-
# (that use *params) and SparseOutlierNoise (that uses X as the first arg).
438-
noise_dist = likelihood.forward(zeros, train_inputs)
542+
# Only pass training inputs when the likelihood needs them (e.g.,
543+
# SparseOutlierGaussianLikelihood for per-point outlier noise).
544+
# Standard GaussianLikelihood doesn't need inputs, and passing
545+
# pre-transform inputs can cause shape mismatches with models that
546+
# use dimension-changing input transforms (e.g., AppendFeatures).
547+
if _likelihood_requires_X(likelihood):
548+
train_inputs = getattr(model, "train_inputs", None)
549+
noise_dist = likelihood.forward(
550+
zeros, train_inputs[0] if train_inputs else None
551+
)
552+
else:
553+
noise_dist = likelihood.forward(zeros)
439554

440555
# Extract noise variance and reshape to match loo_variance
441556
noise = noise_dist.variance.unsqueeze(-1) # ... x n x 1
@@ -522,12 +637,15 @@ def _compute_loo_predictions(
522637

523638
# Add observation noise to the diagonal via the likelihood
524639
# The likelihood adds noise: K_noisy = K + sigma^2 * I
525-
# Some likelihoods (e.g., SparseOutlierGaussianLikelihood) require training
526-
# inputs to be passed to correctly apply the noise model. We pass them as
527-
# a positional argument for compatibility with both standard likelihoods
528-
# and SparseOutlierGaussianLikelihood.
529-
train_inputs = model.train_inputs
530-
noisy_mvn = model.likelihood(prior_dist, train_inputs)
640+
# Only pass training inputs when the likelihood needs them (e.g.,
641+
# SparseOutlierGaussianLikelihood for per-point outlier noise).
642+
# Standard GaussianLikelihood doesn't need inputs, and passing
643+
# pre-transform inputs can cause shape mismatches with models that
644+
# use dimension-changing input transforms (e.g., AppendFeatures).
645+
if _likelihood_requires_X(model.likelihood):
646+
noisy_mvn = model.likelihood(prior_dist, model.train_inputs[0])
647+
else:
648+
noisy_mvn = model.likelihood(prior_dist)
531649

532650
# Get the covariance matrix - use lazy representation for potential caching
533651
K_noisy = noisy_mvn.lazy_covariance_matrix.to_dense()
@@ -554,7 +672,7 @@ def _compute_loo_predictions(
554672
# K_inv_diag has shape ... x n, so after unsqueeze(-1) we get ... x n x 1
555673
# (the last dim is 1 because each GP is single-output).
556674
loo_variance = (1.0 / K_inv_diag).unsqueeze(-1) # ... x n x 1
557-
loo_mean = train_Y.unsqueeze(-1) - K_inv_residuals * loo_variance # ... x n x 1
675+
loo_mean = train_Y.unsqueeze(-1) - K_inv_residuals * loo_variance
558676

559677
# If we want the posterior (noiseless) variance, subtract the noise
560678
if not observation_noise:
@@ -636,6 +754,7 @@ def _reshape_to_loo_cv_format(tensor: Tensor, num_outputs: int) -> Tensor:
636754
def ensemble_loo_cv(
637755
model: GPyTorchModel,
638756
observation_noise: bool = True,
757+
untransform: bool = True,
639758
) -> CVResults:
640759
r"""Compute efficient LOO cross-validation for ensemble models.
641760
@@ -673,17 +792,27 @@ def ensemble_loo_cv(
673792
predictive variance (including observation noise). If False,
674793
return the posterior variance of the latent function (excluding
675794
observation noise).
795+
untransform: If True (default), untransform the LOO predictions and
796+
observed values back to the original outcome space when the model
797+
has an outcome transform (e.g., ``Standardize``). This makes the
798+
results consistent with ``model.posterior()`` and
799+
``batch_cross_validation``. If False, return results in the
800+
model's internal (transformed) space.
676801
677802
Returns:
678803
CVResults: A named tuple containing:
679804
- model: The fitted ensemble GP model.
680805
- posterior: A ``GaussianMixturePosterior`` with per-member shape
681-
``n x num_models x 1 x 1``. Access per-member statistics via
806+
``n x num_models x 1 x m``. Access per-member statistics via
682807
``posterior.mean`` and ``posterior.variance``, and mixture
683808
statistics via ``posterior.mixture_mean`` and
684809
``posterior.mixture_variance``.
685-
- observed_Y: The observed Y values with shape ``n x 1 x 1``.
686-
- observed_Yvar: The observed noise variances (if provided).
810+
- observed_Y: The observed Y values with shape
811+
``n x num_models x 1 x m``, matching the posterior layout so
812+
that element-wise operations (e.g., ``posterior.mean -
813+
observed_Y``) work correctly.
814+
- observed_Yvar: The observed noise variances (if provided) with
815+
the same shape as ``observed_Y``.
687816
688817
Example:
689818
>>> import torch
@@ -723,7 +852,7 @@ def ensemble_loo_cv(
723852
)
724853

725854
# Get the number of outputs
726-
num_outputs = getattr(model, "_num_outputs", 1)
855+
num_outputs = model.num_outputs
727856

728857
# Build the GaussianMixturePosterior
729858
posterior = _build_ensemble_loo_posterior(
@@ -735,6 +864,44 @@ def ensemble_loo_cv(
735864
model=model, train_Y=train_Y, num_outputs=num_outputs
736865
)
737866

867+
# Untransform predictions and observed values back to original space
868+
if untransform and hasattr(model, "outcome_transform"):
869+
observed_Y_ndim = observed_Y.dim()
870+
posterior, observed_Y, observed_Yvar = _untransform_loo_results(
871+
model=model,
872+
posterior=posterior,
873+
observed_Y=observed_Y,
874+
observed_Yvar=observed_Yvar,
875+
)
876+
# After untransform with a batch-shaped outcome transform (e.g.,
877+
# Standardize(batch_shape=[num_models])), observed_Y gains leading
878+
# batch dimensions from the transform. For example, with
879+
# batch_shape=[num_models], observed_Y goes from [n, 1, m] to
880+
# [num_models, n, 1, m]. The posterior has num_models at MCMC_DIM=-3,
881+
# giving shape [n, num_models, 1, m]. Align observed_Y by moving all
882+
# added leading dims to just before the q=1 dim (MCMC_DIM position).
883+
n_added = observed_Y.dim() - observed_Y_ndim
884+
if n_added > 1:
885+
raise UnsupportedError(
886+
"ensemble_loo_cv does not support outcome transforms that add "
887+
f"more than 1 batch dimension. Got {n_added} added dimensions "
888+
f"(shape went from {observed_Y_ndim}D to {observed_Y.dim()}D)."
889+
)
890+
if n_added == 1:
891+
observed_Y = observed_Y.movedim(0, -3)
892+
if observed_Yvar is not None:
893+
observed_Yvar = observed_Yvar.movedim(0, -3)
894+
895+
# Ensure observed_Y has the num_models dimension at MCMC_DIM=-3 to match
896+
# the posterior shape [n, num_models, 1, m]. Without this, element-wise
897+
# operations like `posterior.mean - observed_Y` would fail due to
898+
# incompatible broadcasting (3D [n, 1, m] vs 4D [n, num_models, 1, m]).
899+
posterior_mean = posterior.mean
900+
if observed_Y.dim() < posterior_mean.dim():
901+
observed_Y = observed_Y.unsqueeze(-3).expand_as(posterior_mean)
902+
if observed_Yvar is not None:
903+
observed_Yvar = observed_Yvar.unsqueeze(-3).expand_as(posterior_mean)
904+
738905
return CVResults(
739906
model=model,
740907
posterior=posterior,
@@ -867,7 +1034,7 @@ def _verify_ensemble_data_consistency(
8671034
first_member = tensor.select(num_models_dim, 0)
8681035
first_expanded = first_member.unsqueeze(num_models_dim).expand_as(tensor)
8691036

870-
if not torch.allclose(tensor, first_expanded):
1037+
if not torch.equal(tensor, first_expanded):
8711038
raise UnsupportedError(
8721039
f"Ensemble members have different {tensor_name}. "
8731040
"ensemble_loo_cv only supports ensembles where all members share the "

0 commit comments

Comments
 (0)