Skip to content

Commit 9cc1864

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
OrthogonalAdditiveGP with component-wise inference (meta-pytorch#3187)
Summary: This commit adds a new additive Gaussian process model `OrthogonalAdditiveGP`, which leverages the `OrthogonalAdditiveKernel`, and importantly, posterior inference of individual additive components, conditioned on noisy observations of the sum. This is enabled with the refactored GPyTorch posterior inference stack via `_get_test_prior_mean_and_covariances`. This relies on having an additional batch dimension for the test-test and train-test covariance, corresponding to the kernel matrices of each additive component, while the batch dimension has to be absent on the training set, because we are observing the *sum* of the additive components. Reviewed By: hvarfner Differential Revision: D92461397
1 parent 1855320 commit 9cc1864

6 files changed

Lines changed: 1001 additions & 21 deletions

File tree

botorch/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the MIT license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from botorch.models.additive_gp import OrthogonalAdditiveGP
78
from botorch.models.approximate_gp import (
89
ApproximateGPyTorchModel,
910
SingleTaskVariationalGP,
@@ -44,6 +45,7 @@
4445
"ModelList",
4546
"ModelListGP",
4647
"MultiTaskGP",
48+
"OrthogonalAdditiveGP",
4749
"PairwiseGP",
4850
"PairwiseLaplaceMarginalLogLikelihood",
4951
"PosteriorMeanModel",

botorch/models/additive_gp.py

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
from contextlib import contextmanager
9+
from typing import Iterator
10+
11+
import torch
12+
from botorch.models.gp_regression import SingleTaskGP
13+
from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel
14+
from botorch.models.transforms.input import InputTransform
15+
from botorch.models.transforms.outcome import OutcomeTransform
16+
from botorch.posteriors.gpytorch import GPyTorchPosterior
17+
from botorch.utils.types import _DefaultType, DEFAULT
18+
from gpytorch.distributions import Distribution, MultivariateNormal
19+
from gpytorch.kernels import RBFKernel
20+
from gpytorch.likelihoods import Likelihood
21+
from gpytorch.means import Mean
22+
from linear_operator.operators import LinearOperator
23+
from torch import Tensor
24+
25+
26+
class OrthogonalAdditiveGP(SingleTaskGP):
27+
"""A Gaussian Process with Orthogonal Additive Kernel for interpretable modeling.
28+
29+
This GP model uses an OrthogonalAdditiveKernel which decomposes the function into
30+
interpretable additive components: a bias term, first-order effects for each input
31+
dimension, and optionally second-order interaction terms.
32+
33+
The model supports posterior inference of individual additive components when
34+
`infer_all_components=True` is passed to the `posterior` method.
35+
"""
36+
37+
# Class-level default for inference mode (avoids __init__ override)
38+
_infer_all_components: bool = False
39+
40+
def __init__(
41+
self,
42+
train_X: Tensor,
43+
train_Y: Tensor,
44+
covar_module: OrthogonalAdditiveKernel | None = None,
45+
second_order: bool = False,
46+
likelihood: Likelihood | None = None,
47+
mean_module: Mean | None = None,
48+
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
49+
input_transform: InputTransform | None = None,
50+
) -> None:
51+
"""Initialize the OrthogonalAdditiveGP.
52+
53+
Args:
54+
train_X: Training inputs (batch_shape x n x d) in [0, 1]^d.
55+
train_Y: Training outputs (batch_shape x n x 1).
56+
covar_module: An OrthogonalAdditiveKernel instance. If None, creates a
57+
default kernel with dim inferred from train_X.
58+
second_order: If True and covar_module is None, enables second-order
59+
interactions in the default kernel. Ignored if covar_module is provided.
60+
likelihood: Optional likelihood (defaults to GaussianLikelihood).
61+
mean_module: Optional mean module (defaults to ConstantMean).
62+
outcome_transform: Optional outcome transform.
63+
input_transform: Optional input transform.
64+
65+
Raises:
66+
TypeError: If covar_module is provided but is not an
67+
OrthogonalAdditiveKernel.
68+
"""
69+
if covar_module is None:
70+
covar_module = OrthogonalAdditiveKernel(
71+
base_kernel=RBFKernel(),
72+
dim=train_X.shape[-1],
73+
second_order=second_order,
74+
dtype=train_X.dtype,
75+
device=train_X.device,
76+
)
77+
elif not isinstance(covar_module, OrthogonalAdditiveKernel):
78+
raise TypeError(
79+
f"covar_module must be an OrthogonalAdditiveKernel, "
80+
f"got {type(covar_module).__name__}"
81+
)
82+
83+
super().__init__(
84+
train_X=train_X,
85+
train_Y=train_Y,
86+
likelihood=likelihood,
87+
covar_module=covar_module,
88+
mean_module=mean_module,
89+
outcome_transform=outcome_transform,
90+
input_transform=input_transform,
91+
)
92+
93+
@contextmanager
94+
def _component_inference_context(
95+
self, infer_all_components: bool = False
96+
) -> Iterator[None]:
97+
"""Context manager that temporarily sets component inference mode.
98+
99+
Args:
100+
infer_all_components: If True, enables per-component posterior inference.
101+
102+
Yields:
103+
None. The context manager sets internal state that is checked by
104+
`_get_test_prior_mean_and_covariances` to dispatch to the appropriate
105+
covariance computation.
106+
"""
107+
prev_state = self._infer_all_components
108+
self._infer_all_components = infer_all_components
109+
try:
110+
yield
111+
finally:
112+
self._infer_all_components = prev_state
113+
114+
def posterior(
115+
self,
116+
X: Tensor,
117+
output_indices: list[int] | None = None,
118+
observation_noise: bool = False,
119+
posterior_transform=None,
120+
infer_all_components: bool = False,
121+
) -> GPyTorchPosterior:
122+
"""Posterior inference of the additive Gaussian process.
123+
124+
Args:
125+
X: The input tensor of shape (batch_shape x n x d).
126+
output_indices: Not supported for this model.
127+
observation_noise: Whether to add observation noise to the posterior.
128+
posterior_transform: Optional posterior transform.
129+
infer_all_components: If True, returns a posterior with a batch
130+
dimension corresponding to each additive component (bias, first-order
131+
effects, and optionally second-order interactions). The number of
132+
components is 1 + d (first-order only) or 1 + d + d*(d-1)/2
133+
(with second-order interactions).
134+
135+
Returns:
136+
The posterior distribution at X.
137+
"""
138+
# Use context manager to set inference mode, then delegate to GPyTorch
139+
with self._component_inference_context(infer_all_components):
140+
return super().posterior(
141+
X=X,
142+
output_indices=output_indices,
143+
observation_noise=observation_noise,
144+
posterior_transform=posterior_transform,
145+
)
146+
147+
def _get_test_prior_mean_and_covariances(
148+
self,
149+
train_inputs: list[Tensor],
150+
test_inputs: list[Tensor],
151+
**kwargs,
152+
) -> tuple[
153+
Tensor,
154+
LinearOperator,
155+
LinearOperator,
156+
torch.Size,
157+
torch.Size,
158+
type[Distribution],
159+
]:
160+
"""Dispatches to appropriate covariance computation based on inference mode.
161+
162+
This method is called by GPyTorch's ExactGP.__call__ during posterior
163+
computation. When `_infer_all_components` is True (set via the context
164+
manager), it returns per-component covariances with an extra batch dimension.
165+
Otherwise, it uses the standard kernel forward pass which sums over components.
166+
"""
167+
if self._infer_all_components:
168+
return self._get_test_prior_mean_and_covariances_per_component(
169+
train_inputs, test_inputs, **kwargs
170+
)
171+
return super()._get_test_prior_mean_and_covariances(
172+
train_inputs=train_inputs, test_inputs=test_inputs, **kwargs
173+
)
174+
175+
def _get_test_prior_mean_and_covariances_per_component(
176+
self,
177+
train_inputs: list[Tensor],
178+
test_inputs: list[Tensor],
179+
**kwargs,
180+
) -> tuple[
181+
Tensor,
182+
LinearOperator,
183+
LinearOperator,
184+
torch.Size,
185+
torch.Size,
186+
type[Distribution],
187+
]:
188+
"""Computes mean and covariances with a batch dimension for each component.
189+
190+
This enables posterior inference of individual additive components by returning
191+
covariance matrices with an extra leading batch dimension for each component.
192+
193+
Returns:
194+
A tuple containing:
195+
- test_mean: The mean evaluated on the test set (batch_shape x n_test)
196+
- test_test_covar: Covariance between test points
197+
(num_components x batch_shape x n_test x n_test)
198+
- test_train_covar: Covariance between test and train points
199+
(num_components x batch_shape x n_test x n_train)
200+
- batch_shape: The batch shape of the model
201+
- test_shape: Shape (n_test,)
202+
- posterior_class: The class of the posterior to be instantiated
203+
"""
204+
if len(train_inputs) != 1 or len(test_inputs) != 1:
205+
raise ValueError(
206+
"OrthogonalAdditiveGP expects a single input X, but received "
207+
f"{len(train_inputs)=}, and {len(test_inputs)=}."
208+
)
209+
210+
X_train = train_inputs[0]
211+
X_test = test_inputs[0]
212+
213+
# Batch shape includes the component dimension as the leading dimension
214+
# This is needed so GPyTorch correctly reshapes the predictive mean
215+
num_components = self.covar_module.num_components
216+
batch_shape = torch.Size([num_components]) + X_train.shape[:-2]
217+
218+
# Get component-wise covariances using _non_reduced_forward
219+
# Shape: (num_components x batch_shape x n_test x n_test)
220+
test_test_covar = self.covar_module._non_reduced_forward(X_test, X_test)
221+
# Shape: (num_components x batch_shape x n_test x n_train)
222+
test_train_covar = self.covar_module._non_reduced_forward(X_test, X_train)
223+
224+
# Prior mean: Only the bias component (index 0) should have the prior mean.
225+
# All other components represent deviations from the mean, so their prior
226+
# mean should be zero. This ensures that when we sum over all components,
227+
# we get the correct total posterior mean (prior mean added once).
228+
n_test = X_test.shape[-2]
229+
# Create a (num_components, n_test) tensor of zeros
230+
test_mean = torch.zeros(
231+
num_components, n_test, dtype=X_test.dtype, device=X_test.device
232+
)
233+
# Set the bias component's mean to the actual prior mean
234+
test_mean[0, :] = self.mean_module(X_test)
235+
test_shape = torch.Size([n_test])
236+
237+
return (
238+
test_mean,
239+
test_test_covar,
240+
test_train_covar,
241+
batch_shape,
242+
test_shape,
243+
MultivariateNormal,
244+
)
245+
246+
@property
247+
def component_indices(self) -> dict[str, Tensor]:
248+
"""Returns component indices from the OrthogonalAdditiveKernel."""
249+
return self.covar_module.component_indices
250+
251+
def evaluate_first_order_on_grid(
252+
self,
253+
grid_1d: Tensor,
254+
) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
255+
r"""Evaluate first-order component posteriors on 1D grids.
256+
257+
Uses diagonal test inputs with the existing GPyTorch posterior
258+
infrastructure. Each first-order component is evaluated on its
259+
own independent 1D grid.
260+
261+
Args:
262+
grid_1d: 1D tensor of m points in [0, 1].
263+
264+
Returns:
265+
Tuple of:
266+
- bias: (mean, variance) - scalar values
267+
- first_order: ((d, m) means, (d, m) variances) on 1D grids
268+
269+
Example:
270+
>>> grid = torch.linspace(0, 1, 50)
271+
>>> (bias_mean, bias_var), (fo_mean, fo_var) = \\
272+
... model.evaluate_first_order_on_grid(grid)
273+
>>> # fo_mean[i, :] is component i's posterior mean on the 1D grid
274+
"""
275+
self.eval()
276+
m = len(grid_1d)
277+
d = self.covar_module.dim
278+
279+
# Diagonal test inputs: X[k, :] = [t_k, t_k, ..., t_k]
280+
# Each first-order component i sees its own 1D grid on dimension i
281+
X_diag = grid_1d.unsqueeze(-1).expand(m, d)
282+
283+
# Use existing posterior with all-components mode
284+
posterior = self.posterior(X_diag, infer_all_components=True)
285+
286+
# Squeeze output dimension (last dim) since this is single-output
287+
mean = posterior.mean.squeeze(-1) # (num_components, m)
288+
variance = posterior.variance.squeeze(-1) # (num_components, m)
289+
290+
# Extract bias (component 0) - should be constant across grid
291+
bias_mean = mean[0, :].mean()
292+
bias_var = variance[0, :].mean()
293+
294+
# Extract first-order (components 1 to d)
295+
first_order_means = mean[1 : d + 1, :] # (d, m)
296+
first_order_vars = variance[1 : d + 1, :] # (d, m)
297+
298+
return (bias_mean, bias_var), (first_order_means, first_order_vars)
299+
300+
@property
301+
def num_components(self) -> int:
302+
"""Total number of additive components (bias + first-order [+ second-order])."""
303+
return self.covar_module.num_components
304+
305+
def get_component_index(
306+
self,
307+
component_type: str,
308+
dim_index: int | tuple[int, int] | None = None,
309+
) -> int:
310+
"""Returns the component index for a given component type and dimension.
311+
312+
See OrthogonalAdditiveKernel.get_component_index for details.
313+
"""
314+
return self.covar_module.get_component_index(component_type, dim_index)

0 commit comments

Comments
 (0)