|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# |
| 4 | +# This source code is licensed under the MIT license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +from contextlib import contextmanager |
| 10 | +from typing import Iterator, TYPE_CHECKING |
| 11 | + |
| 12 | +import torch |
| 13 | +from botorch.models.gp_regression import SingleTaskGP |
| 14 | + |
| 15 | +if TYPE_CHECKING: |
| 16 | + from botorch.acquisition.objective import PosteriorTransform |
| 17 | +from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel |
| 18 | +from botorch.models.transforms.input import InputTransform |
| 19 | +from botorch.models.transforms.outcome import OutcomeTransform |
| 20 | +from botorch.posteriors.gpytorch import GPyTorchPosterior |
| 21 | +from botorch.utils.types import _DefaultType, DEFAULT |
| 22 | +from gpytorch.distributions import Distribution, MultivariateNormal |
| 23 | +from gpytorch.likelihoods import Likelihood |
| 24 | +from gpytorch.means import Mean |
| 25 | +from linear_operator.operators import LinearOperator |
| 26 | +from torch import Tensor |
| 27 | + |
| 28 | + |
class OrthogonalAdditiveGP(SingleTaskGP):
    """A Gaussian Process with Orthogonal Additive Kernel for interpretable modeling.

    This GP model uses an OrthogonalAdditiveKernel which decomposes the function into
    interpretable additive components: a bias term, first-order effects for each input
    dimension, and optionally second-order interaction terms.

    The model supports posterior inference of individual additive components when
    `infer_all_components=True` is passed to the `posterior` method.
    """

    # Per-call switch read by `_get_test_prior_mean_and_covariances`. It is declared
    # at class level so the attribute always resolves to False unless
    # `_component_inference_context` has temporarily overridden it on the instance.
    _infer_all_components: bool = False

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        covar_module: OrthogonalAdditiveKernel | None = None,
        second_order: bool = False,
        likelihood: Likelihood | None = None,
        mean_module: Mean | None = None,
        outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
        input_transform: InputTransform | None = None,
    ) -> None:
        """Initialize the OrthogonalAdditiveGP.

        Args:
            train_X: Training inputs (batch_shape x n x d) in [0, 1]^d.
            train_Y: Training outputs (batch_shape x n x 1).
            covar_module: An OrthogonalAdditiveKernel instance. If None, creates a
                default OrthogonalAdditiveKernel with dim inferred from train_X.
                The kernel's default base kernel uses per-dimension lengthscales,
                so each additive component can adapt its smoothness independently.
            second_order: If True and covar_module is None, enables second-order
                interactions in the default kernel. Ignored if covar_module is
                provided.
            likelihood: Optional likelihood (defaults to GaussianLikelihood).
            mean_module: Optional mean module (defaults to ConstantMean).
            outcome_transform: Optional outcome transform.
            input_transform: Optional input transform.

        Raises:
            TypeError: If covar_module is provided but is not an
                OrthogonalAdditiveKernel.
        """
        if covar_module is None:
            # Build the default kernel on the same dtype/device as the data so the
            # kernel parameters match the training tensors.
            covar_module = OrthogonalAdditiveKernel(
                dim=train_X.shape[-1],
                second_order=second_order,
                dtype=train_X.dtype,
                device=train_X.device,
            )
        elif not isinstance(covar_module, OrthogonalAdditiveKernel):
            # The per-component machinery below relies on kernel-specific members
            # (`num_components`, `_non_reduced_forward`, ...), so reject other
            # kernels up front.
            raise TypeError(
                "covar_module must be an OrthogonalAdditiveKernel, "
                f"got {type(covar_module).__name__}"
            )

        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            likelihood=likelihood,
            covar_module=covar_module,
            mean_module=mean_module,
            outcome_transform=outcome_transform,
            input_transform=input_transform,
        )

    @contextmanager
    def _component_inference_context(
        self, infer_all_components: bool = False
    ) -> Iterator[None]:
        """Context manager that temporarily sets component inference mode.

        Args:
            infer_all_components: If True, enables per-component posterior
                inference.

        Yields:
            None. The context manager sets internal state that is checked by
            `_get_test_prior_mean_and_covariances` to dispatch to the appropriate
            covariance computation.
        """
        prev_state = self._infer_all_components
        self._infer_all_components = infer_all_components
        try:
            yield
        finally:
            # Always restore the previous mode, even if the posterior call raised,
            # so a failed call cannot leave the model stuck in component mode.
            self._infer_all_components = prev_state

    def posterior(
        self,
        X: Tensor,
        output_indices: list[int] | None = None,
        observation_noise: bool = False,
        posterior_transform: PosteriorTransform | None = None,
        infer_all_components: bool = False,
    ) -> GPyTorchPosterior:
        """Posterior inference of the additive Gaussian process.

        Args:
            X: The input tensor of shape (batch_shape x n x d).
            output_indices: Not supported for this model. The value is forwarded
                unchanged to the parent implementation.
            observation_noise: Whether to add observation noise to the posterior.
            posterior_transform: Optional posterior transform.
            infer_all_components: If True, returns a posterior with a batch
                dimension corresponding to each additive component (bias,
                first-order effects, and optionally second-order interactions).
                The number of components is 1 + d (first-order only) or
                1 + d + d*(d-1)/2 (with second-order interactions).

        Returns:
            The posterior distribution at X.
        """
        # NOTE(review): any outcome transform configured on the model is presumably
        # un-applied by the parent `posterior` to every component alike -- confirm
        # that this is the intended behavior for non-bias components.
        # The flag is only active for the duration of the parent call; the parent
        # is expected to reach `_get_test_prior_mean_and_covariances`, which reads it.
        with self._component_inference_context(infer_all_components):
            return super().posterior(
                X=X,
                output_indices=output_indices,
                observation_noise=observation_noise,
                posterior_transform=posterior_transform,
            )

    def _get_test_prior_mean_and_covariances(
        self,
        train_inputs: list[Tensor],
        test_inputs: list[Tensor],
        **kwargs,
    ) -> tuple[
        Tensor,
        LinearOperator,
        LinearOperator,
        torch.Size,
        torch.Size,
        type[Distribution],
    ]:
        """Dispatches to appropriate covariance computation based on inference mode.

        This method is called by GPyTorch's ExactGP.__call__ during posterior
        computation. When `_infer_all_components` is True (set via the context
        manager), it returns per-component covariances with an extra batch
        dimension. Otherwise, it uses the standard kernel forward pass which sums
        over components.

        Args:
            train_inputs: List of training input tensors.
            test_inputs: List of test input tensors.
            **kwargs: Forwarded unchanged to whichever implementation is selected.

        Returns:
            The 6-tuple (test mean, test-test covariance, test-train covariance,
            batch shape, test shape, posterior class) produced by the selected
            implementation.
        """
        if self._infer_all_components:
            return self._get_test_prior_mean_and_covariances_per_component(
                train_inputs, test_inputs, **kwargs
            )
        return super()._get_test_prior_mean_and_covariances(
            train_inputs=train_inputs, test_inputs=test_inputs, **kwargs
        )

    def _get_test_prior_mean_and_covariances_per_component(
        self,
        train_inputs: list[Tensor],
        test_inputs: list[Tensor],
        **kwargs,
    ) -> tuple[
        Tensor,
        LinearOperator,
        LinearOperator,
        torch.Size,
        torch.Size,
        type[Distribution],
    ]:
        """Computes mean and covariances with a batch dimension for each component.

        This enables posterior inference of individual additive components by
        returning covariance matrices with an extra leading batch dimension for
        each component. The component dimension is placed BEFORE any input batch
        dimensions, which is required for GPyTorch's prediction strategy to
        broadcast correctly against the cached training solve.

        Args:
            train_inputs: Single-element list holding the training inputs.
            test_inputs: Single-element list holding the test inputs.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            A tuple containing:
                - test_mean: (num_components, *input_batch_shape, n_test)
                - test_test_covar: (num_components, *input_batch_shape, n_test,
                  n_test)
                - test_train_covar: (num_components, *input_batch_shape, n_test,
                  n_train)
                - batch_shape: (num_components, *input_batch_shape)
                - test_shape: Shape (n_test,)
                - posterior_class: MultivariateNormal

        Raises:
            ValueError: If more than one train or test input tensor is supplied.
        """
        if len(train_inputs) != 1 or len(test_inputs) != 1:
            raise ValueError(
                "OrthogonalAdditiveGP expects a single input X, but received "
                f"{len(train_inputs)=}, and {len(test_inputs)=}."
            )

        X_train = train_inputs[0]
        X_test = test_inputs[0]

        # Component dim must be the LEADING batch dimension so that GPyTorch's
        # prediction strategy (which has mean_cache of shape (*input_batch, n))
        # can broadcast correctly: (num_components, *input_batch, m, n) @
        # (*input_batch, n, 1) broadcasts on the input_batch dims.
        num_components = self.covar_module.num_components
        # Use the broadcast of both batch shapes rather than X_train's alone, so
        # the mean buffer below is large enough when X_test carries batch dims
        # that X_train lacks. Identical to X_train's batch shape when they agree.
        input_batch_shape = torch.broadcast_shapes(
            X_train.shape[:-2], X_test.shape[:-2]
        )
        batch_shape = torch.Size([num_components]) + input_batch_shape

        # The kernel's _non_reduced_forward returns tensors with shape
        # (*input_batch_shape, num_components, n, n), but GPyTorch needs the
        # component dim first: (num_components, *input_batch_shape, n, n).
        test_test_covar = self.covar_module._non_reduced_forward(X_test, X_test)
        test_train_covar = self.covar_module._non_reduced_forward(X_test, X_train)

        # The component dim always sits directly before the two matrix dims, so
        # address it as -3 on each tensor instead of deriving its position from
        # X_train's batch rank (which is wrong whenever a covariance tensor has a
        # different number of batch dims). This is a no-op without batch dims.
        test_test_covar = torch.movedim(test_test_covar, -3, 0)
        test_train_covar = torch.movedim(test_train_covar, -3, 0)

        # Prior mean: Only the bias component (index 0) should have the prior
        # mean. All other components represent deviations from the mean, so their
        # prior mean should be zero. This ensures that when we sum over all
        # components, we get the correct total posterior mean (prior mean added
        # once).
        n_test = X_test.shape[-2]
        test_mean = torch.zeros(
            *batch_shape, n_test, dtype=X_test.dtype, device=X_test.device
        )
        # Component dim is first, so [0, ...] selects bias for all batch elements;
        # the assignment broadcasts the prior mean over any extra batch dims.
        test_mean[0, ...] = self.mean_module(X_test)
        test_shape = torch.Size([n_test])

        return (
            test_mean,
            test_test_covar,
            test_train_covar,
            batch_shape,
            test_shape,
            MultivariateNormal,
        )

    @property
    def component_indices(self) -> dict[str, Tensor]:
        """Returns component indices from the OrthogonalAdditiveKernel."""
        return self.covar_module.component_indices

    def evaluate_first_order_on_grid(
        self,
        grid_1d: Tensor,
    ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
        r"""Evaluate first-order component posteriors on 1D grids.

        Uses diagonal test inputs with the existing GPyTorch posterior
        infrastructure. Each first-order component is evaluated on its
        own independent 1D grid.

        Supports models with batch-shaped training data. The grid itself should
        be 1D; batch dimensions come from the model's training data.

        Args:
            grid_1d: 1D tensor of m points in [0, 1].

        Returns:
            Tuple of:
                - bias: (mean, variance) — shape ``(*batch_shape,)`` each,
                  averaged over the grid (should be approximately constant).
                - first_order: (means, variances) — shape
                  ``(d, *batch_shape, m)`` each, posterior statistics on 1D grids.

        Raises:
            ValueError: If ``grid_1d`` is not a 1D tensor.

        Example:
            >>> grid = torch.linspace(0, 1, 50)
            >>> (bias_mean, bias_var), (fo_mean, fo_var) = \
            ...     model.evaluate_first_order_on_grid(grid)
            >>> # fo_mean[i] is component i's posterior mean on the 1D grid
        """
        # Fail early with a clear message; otherwise a 0-d or multi-dimensional
        # grid only surfaces as an opaque `len()` / `expand` error below.
        if grid_1d.ndim != 1:
            raise ValueError(
                "grid_1d must be a 1D tensor, but got a tensor of shape "
                f"{tuple(grid_1d.shape)}."
            )

        m = grid_1d.shape[0]
        d = self.covar_module.dim

        # Diagonal test inputs: X[k, :] = [t_k, t_k, ..., t_k]
        # Each first-order component i sees its own 1D grid on dimension i.
        # No batch dims needed here — GPyTorch broadcasts against training data.
        X_diag = grid_1d.unsqueeze(-1).expand(m, d)

        # Use existing posterior with all-components mode
        posterior = self.posterior(X_diag, infer_all_components=True)

        # Squeeze output dimension (last dim) since this is single-output
        # Shape: (num_components, *batch_shape, m)
        mean = posterior.mean.squeeze(-1)
        variance = posterior.variance.squeeze(-1)

        # Extract bias (component 0) - should be constant across grid.
        # Component dim is at position 0, so mean[0] gives (*batch_shape, m).
        # Average over the grid dimension (last dim) only, preserving batch dims.
        bias_mean = mean[0].mean(dim=-1)
        bias_var = variance[0].mean(dim=-1)

        # Extract first-order (components 1 to d)
        first_order_means = mean[1 : d + 1]  # (d, *batch_shape, m)
        first_order_vars = variance[1 : d + 1]  # (d, *batch_shape, m)

        return (bias_mean, bias_var), (first_order_means, first_order_vars)

    @property
    def num_components(self) -> int:
        """Total number of additive components (bias + first-order [+ second-order])."""
        return self.covar_module.num_components

    def get_component_index(
        self,
        component_type: str,
        dim_index: int | tuple[int, int] | None = None,
    ) -> int:
        """Returns the component index for a given component type and dimension.

        See OrthogonalAdditiveKernel.get_component_index for details.

        Args:
            component_type: Component type, forwarded to the kernel.
            dim_index: Input dimension (or pair of dimensions), forwarded to the
                kernel.

        Returns:
            The index reported by the kernel's `get_component_index`.
        """
        return self.covar_module.get_component_index(component_type, dim_index)
0 commit comments