|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# |
| 4 | +# This source code is licensed under the MIT license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-strict |
| 8 | +from contextlib import contextmanager |
| 9 | +from typing import Iterator |
| 10 | + |
| 11 | +import torch |
| 12 | +from botorch.models.gp_regression import SingleTaskGP |
| 13 | +from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel |
| 14 | +from botorch.models.transforms.input import InputTransform |
| 15 | +from botorch.models.transforms.outcome import OutcomeTransform |
| 16 | +from botorch.posteriors.gpytorch import GPyTorchPosterior |
| 17 | +from botorch.utils.types import _DefaultType, DEFAULT |
| 18 | +from gpytorch.distributions import Distribution, MultivariateNormal |
| 19 | +from gpytorch.kernels import RBFKernel |
| 20 | +from gpytorch.likelihoods import Likelihood |
| 21 | +from gpytorch.means import Mean |
| 22 | +from linear_operator.operators import LinearOperator |
| 23 | +from torch import Tensor |
| 24 | + |
| 25 | + |
class OrthogonalAdditiveGP(SingleTaskGP):
    """A Gaussian Process with Orthogonal Additive Kernel for interpretable modeling.

    This GP model uses an OrthogonalAdditiveKernel which decomposes the function into
    interpretable additive components: a bias term, first-order effects for each input
    dimension, and optionally second-order interaction terms.

    The model supports posterior inference of individual additive components when
    `infer_all_components=True` is passed to the `posterior` method.
    """

    # Class-level default for the per-component inference flag. `posterior`
    # temporarily shadows it with an instance attribute (through
    # `_component_inference_context`) and restores the previous value on exit.
    _infer_all_components: bool = False

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        covar_module: OrthogonalAdditiveKernel | None = None,
        second_order: bool = False,
        likelihood: Likelihood | None = None,
        mean_module: Mean | None = None,
        outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
        input_transform: InputTransform | None = None,
    ) -> None:
        """Initialize the OrthogonalAdditiveGP.

        Args:
            train_X: Training inputs (batch_shape x n x d) in [0, 1]^d.
            train_Y: Training outputs (batch_shape x n x 1).
            covar_module: An OrthogonalAdditiveKernel instance. If None, creates a
                default kernel (RBF base kernel) with dim, dtype and device
                inferred from train_X.
            second_order: If True and covar_module is None, enables second-order
                interactions in the default kernel. Ignored if covar_module is
                provided.
            likelihood: Optional likelihood, forwarded to SingleTaskGP.
            mean_module: Optional mean module, forwarded to SingleTaskGP.
            outcome_transform: Optional outcome transform, forwarded to
                SingleTaskGP.
            input_transform: Optional input transform, forwarded to SingleTaskGP.

        Raises:
            TypeError: If covar_module is provided but is not an
                OrthogonalAdditiveKernel.
        """
        if covar_module is None:
            covar_module = OrthogonalAdditiveKernel(
                base_kernel=RBFKernel(),
                dim=train_X.shape[-1],
                second_order=second_order,
                dtype=train_X.dtype,
                device=train_X.device,
            )
        elif not isinstance(covar_module, OrthogonalAdditiveKernel):
            # The per-component code paths below rely on kernel attributes
            # (`num_components`, `_non_reduced_forward`, ...), so reject any
            # other kernel type up front.
            raise TypeError(
                f"covar_module must be an OrthogonalAdditiveKernel, "
                f"got {type(covar_module).__name__}"
            )

        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            likelihood=likelihood,
            covar_module=covar_module,
            mean_module=mean_module,
            outcome_transform=outcome_transform,
            input_transform=input_transform,
        )

    @contextmanager
    def _component_inference_context(
        self, infer_all_components: bool = False
    ) -> Iterator[None]:
        """Context manager that temporarily sets component inference mode.

        The flag is stored on the instance, so concurrent `posterior` calls on
        the same model object from several threads could observe each other's
        setting.

        Args:
            infer_all_components: If True, enables per-component posterior inference.

        Yields:
            None. The context manager sets internal state that is checked by
            `_get_test_prior_mean_and_covariances` to dispatch to the appropriate
            covariance computation.
        """
        prev_state = self._infer_all_components
        self._infer_all_components = infer_all_components
        try:
            yield
        finally:
            # Always restore, even if the posterior computation raised.
            self._infer_all_components = prev_state

    def posterior(
        self,
        X: Tensor,
        output_indices: list[int] | None = None,
        observation_noise: bool = False,
        posterior_transform=None,
        infer_all_components: bool = False,
    ) -> GPyTorchPosterior:
        """Posterior inference of the additive Gaussian process.

        Args:
            X: The input tensor of shape (batch_shape x n x d).
            output_indices: Not supported for this model.
            observation_noise: Whether to add observation noise to the posterior.
            posterior_transform: Optional posterior transform.
            infer_all_components: If True, returns a posterior with a batch
                dimension corresponding to each additive component (bias, first-order
                effects, and optionally second-order interactions). The number of
                components is 1 + d (first-order only) or 1 + d + d*(d-1)/2
                (with second-order interactions).

        Returns:
            The posterior distribution at X.

        NOTE(review): all arguments other than `infer_all_components` are handed
        to the parent `posterior` unchanged. Presumably that means an outcome
        transform or `observation_noise=True` is applied to every component of
        a per-component posterior -- confirm that this is the intended
        behavior (e.g. that an un-standardization offset is not added to each
        component).
        """
        # Toggle the inference mode for the duration of the parent call only.
        with self._component_inference_context(infer_all_components):
            return super().posterior(
                X=X,
                output_indices=output_indices,
                observation_noise=observation_noise,
                posterior_transform=posterior_transform,
            )

    def _get_test_prior_mean_and_covariances(
        self,
        train_inputs: list[Tensor],
        test_inputs: list[Tensor],
        **kwargs,
    ) -> tuple[
        Tensor,
        LinearOperator,
        LinearOperator,
        torch.Size,
        torch.Size,
        type[Distribution],
    ]:
        """Dispatches to appropriate covariance computation based on inference mode.

        This method is called by GPyTorch's ExactGP.__call__ during posterior
        computation. When `_infer_all_components` is True (set via the context
        manager), it returns per-component covariances with an extra batch dimension.
        Otherwise, it defers to the parent implementation.

        Args:
            train_inputs: The training inputs, one tensor per model input.
            test_inputs: The test inputs, one tensor per model input.
            **kwargs: Forwarded unchanged to the selected implementation.

        Returns:
            The tuple produced by the selected implementation; see
            `_get_test_prior_mean_and_covariances_per_component` for the layout.
        """
        if self._infer_all_components:
            return self._get_test_prior_mean_and_covariances_per_component(
                train_inputs, test_inputs, **kwargs
            )
        return super()._get_test_prior_mean_and_covariances(
            train_inputs=train_inputs, test_inputs=test_inputs, **kwargs
        )

    def _get_test_prior_mean_and_covariances_per_component(
        self,
        train_inputs: list[Tensor],
        test_inputs: list[Tensor],
        **kwargs,
    ) -> tuple[
        Tensor,
        LinearOperator,
        LinearOperator,
        torch.Size,
        torch.Size,
        type[Distribution],
    ]:
        """Computes mean and covariances with a batch dimension for each component.

        This enables posterior inference of individual additive components by returning
        covariance matrices with an extra leading batch dimension for each component.

        Args:
            train_inputs: A single-element list holding the training inputs.
            test_inputs: A single-element list holding the test inputs.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            A tuple containing:
            - test_mean: The prior mean on the test set
              (num_components x batch_shape x n_test); only the bias component
              (index 0) is non-zero.
            - test_test_covar: Covariance between test points
              (num_components x batch_shape x n_test x n_test)
            - test_train_covar: Covariance between test and train points
              (num_components x batch_shape x n_test x n_train)
            - batch_shape: The component dimension followed by the batch shape
              of the training inputs
            - test_shape: Shape (n_test,)
            - posterior_class: The class of the posterior to be instantiated

        Raises:
            ValueError: If more than one train or test input tensor is passed.
        """
        if len(train_inputs) != 1 or len(test_inputs) != 1:
            raise ValueError(
                "OrthogonalAdditiveGP expects a single input X, but received "
                f"{len(train_inputs)=}, and {len(test_inputs)=}."
            )

        X_train = train_inputs[0]
        X_test = test_inputs[0]

        # Batch shape includes the component dimension as the leading dimension
        # This is needed so GPyTorch correctly reshapes the predictive mean
        num_components = self.covar_module.num_components
        batch_shape = torch.Size([num_components]) + X_train.shape[:-2]

        # Get component-wise covariances using _non_reduced_forward
        # Expected shape: (num_components x batch_shape x n_test x n_test)
        test_test_covar = self.covar_module._non_reduced_forward(X_test, X_test)
        # Expected shape: (num_components x batch_shape x n_test x n_train)
        test_train_covar = self.covar_module._non_reduced_forward(X_test, X_train)

        # Prior mean: Only the bias component (index 0) should have the prior mean.
        # All other components represent deviations from the mean, so their prior
        # mean should be zero. This ensures that when we sum over all components,
        # we get the correct total posterior mean (prior mean added once).
        #
        # The buffer keeps every leading (batch) dimension of X_test so that it
        # lines up with the batched covariances above; for unbatched inputs this
        # is simply (num_components, n_test).
        test_mean = torch.zeros(
            num_components,
            *X_test.shape[:-1],
            dtype=X_test.dtype,
            device=X_test.device,
        )
        # Set the bias component's mean to the actual prior mean
        test_mean[0] = self.mean_module(X_test)
        test_shape = torch.Size([X_test.shape[-2]])

        return (
            test_mean,
            test_test_covar,
            test_train_covar,
            batch_shape,
            test_shape,
            MultivariateNormal,
        )

    @property
    def component_indices(self) -> dict[str, Tensor]:
        """Returns component indices from the OrthogonalAdditiveKernel."""
        return self.covar_module.component_indices

    def evaluate_first_order_on_grid(
        self,
        grid_1d: Tensor,
    ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
        """Evaluate first-order component posteriors on 1D grids.

        Uses diagonal test inputs with the existing GPyTorch posterior
        infrastructure. Each first-order component is evaluated on its
        own independent 1D grid.

        Calling this method puts the model into eval mode.

        Args:
            grid_1d: 1D tensor of m points in [0, 1]. Presumably it should share
                dtype and device with the training inputs -- confirm with callers.

        Returns:
            Tuple of:
            - bias: (mean, variance) - scalar values
            - first_order: ((d, m) means, (d, m) variances) on 1D grids

        Raises:
            ValueError: If grid_1d is not one-dimensional.

        Example:
            >>> grid = torch.linspace(0, 1, 50)
            >>> bias, first_order = model.evaluate_first_order_on_grid(grid)
            >>> (bias_mean, bias_var), (fo_mean, fo_var) = bias, first_order
            >>> # fo_mean[i, :] is component i's posterior mean on the 1D grid
        """
        if grid_1d.ndim != 1:
            raise ValueError(
                "grid_1d must be a 1D tensor, but got a tensor of shape "
                f"{tuple(grid_1d.shape)}."
            )

        self.eval()
        m = grid_1d.shape[0]
        d = self.covar_module.dim

        # Diagonal test inputs: X[k, :] = [t_k, t_k, ..., t_k]
        # Each first-order component i sees its own 1D grid on dimension i
        X_diag = grid_1d.unsqueeze(-1).expand(m, d)

        # Use existing posterior with all-components mode
        posterior = self.posterior(X_diag, infer_all_components=True)

        # Squeeze output dimension (last dim) since this is single-output
        mean = posterior.mean.squeeze(-1)  # (num_components, m)
        variance = posterior.variance.squeeze(-1)  # (num_components, m)

        # Extract bias (component 0) - should be constant across grid, so the
        # average over the grid collapses it to a scalar.
        bias_mean = mean[0, :].mean()
        bias_var = variance[0, :].mean()

        # Extract first-order (components 1 to d); assumes the kernel orders its
        # components as bias, then first-order terms -- TODO confirm against
        # `component_indices`.
        first_order_means = mean[1 : d + 1, :]  # (d, m)
        first_order_vars = variance[1 : d + 1, :]  # (d, m)

        return (bias_mean, bias_var), (first_order_means, first_order_vars)

    @property
    def num_components(self) -> int:
        """Total number of additive components (bias + first-order [+ second-order])."""
        return self.covar_module.num_components

    def get_component_index(
        self,
        component_type: str,
        dim_index: int | tuple[int, int] | None = None,
    ) -> int:
        """Returns the component index for a given component type and dimension.

        Thin wrapper that forwards both arguments to the kernel.
        See OrthogonalAdditiveKernel.get_component_index for details.
        """
        return self.covar_module.get_component_index(component_type, dim_index)
0 commit comments