Skip to content

Commit e397819

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
OrthogonalAdditiveGP with component-wise inference (#3187)
Summary: This commit adds a new additive Gaussian process model `OrthogonalAdditiveGP`, which leverages the `OrthogonalAdditiveKernel`, and importantly, posterior inference of individual additive components, conditioned on noisy observations of the sum. This is enabled with the refactored GPyTorch posterior inference stack via `_get_test_prior_mean_and_covariances`. This relies on having an additional batch dimension for the test-test and train-test covariance, corresponding to the kernel matrices of each additive component, while the batch dimension has to be absent on the training set, because we are observing the *sum* of the additive components. Reviewed By: hvarfner Differential Revision: D92461397
1 parent b993a3c commit e397819

7 files changed

Lines changed: 2040 additions & 48 deletions

File tree

botorch/models/additive_gp.py

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from contextlib import contextmanager
10+
from typing import Iterator, TYPE_CHECKING
11+
12+
import torch
13+
from botorch.models.gp_regression import SingleTaskGP
14+
15+
if TYPE_CHECKING:
16+
from botorch.acquisition.objective import PosteriorTransform
17+
from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel
18+
from botorch.models.transforms.input import InputTransform
19+
from botorch.models.transforms.outcome import OutcomeTransform
20+
from botorch.posteriors.gpytorch import GPyTorchPosterior
21+
from botorch.utils.types import _DefaultType, DEFAULT
22+
from gpytorch.distributions import Distribution, MultivariateNormal
23+
from gpytorch.likelihoods import Likelihood
24+
from gpytorch.means import Mean
25+
from linear_operator.operators import LinearOperator
26+
from torch import Tensor
27+
28+
29+
class OrthogonalAdditiveGP(SingleTaskGP):
30+
"""A Gaussian Process with Orthogonal Additive Kernel for interpretable modeling.
31+
32+
This GP model uses an OrthogonalAdditiveKernel which decomposes the function into
33+
interpretable additive components: a bias term, first-order effects for each input
34+
dimension, and optionally second-order interaction terms.
35+
36+
The model supports posterior inference of individual additive components when
37+
`infer_all_components=True` is passed to the `posterior` method.
38+
"""
39+
40+
# Class-level default for inference mode (avoids __init__ override)
41+
_infer_all_components: bool = False
42+
43+
def __init__(
44+
self,
45+
train_X: Tensor,
46+
train_Y: Tensor,
47+
covar_module: OrthogonalAdditiveKernel | None = None,
48+
second_order: bool = False,
49+
likelihood: Likelihood | None = None,
50+
mean_module: Mean | None = None,
51+
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
52+
input_transform: InputTransform | None = None,
53+
) -> None:
54+
"""Initialize the OrthogonalAdditiveGP.
55+
56+
Args:
57+
train_X: Training inputs (batch_shape x n x d) in [0, 1]^d.
58+
train_Y: Training outputs (batch_shape x n x 1).
59+
covar_module: An OrthogonalAdditiveKernel instance. If None, creates a
60+
default OrthogonalAdditiveKernel with dim inferred from train_X.
61+
The kernel's default base kernel uses per-dimension lengthscales,
62+
so each additive component can adapt its smoothness independently.
63+
second_order: If True and covar_module is None, enables second-order
64+
interactions in the default kernel. Ignored if covar_module is provided.
65+
likelihood: Optional likelihood (defaults to GaussianLikelihood).
66+
mean_module: Optional mean module (defaults to ConstantMean).
67+
outcome_transform: Optional outcome transform.
68+
input_transform: Optional input transform.
69+
70+
Raises:
71+
TypeError: If covar_module is provided but is not an
72+
OrthogonalAdditiveKernel.
73+
"""
74+
if covar_module is None:
75+
covar_module = OrthogonalAdditiveKernel(
76+
dim=train_X.shape[-1],
77+
second_order=second_order,
78+
dtype=train_X.dtype,
79+
device=train_X.device,
80+
)
81+
elif not isinstance(covar_module, OrthogonalAdditiveKernel):
82+
raise TypeError(
83+
f"covar_module must be an OrthogonalAdditiveKernel, "
84+
f"got {type(covar_module).__name__}"
85+
)
86+
87+
super().__init__(
88+
train_X=train_X,
89+
train_Y=train_Y,
90+
likelihood=likelihood,
91+
covar_module=covar_module,
92+
mean_module=mean_module,
93+
outcome_transform=outcome_transform,
94+
input_transform=input_transform,
95+
)
96+
97+
@contextmanager
98+
def _component_inference_context(
99+
self, infer_all_components: bool = False
100+
) -> Iterator[None]:
101+
"""Context manager that temporarily sets component inference mode.
102+
103+
Args:
104+
infer_all_components: If True, enables per-component posterior inference.
105+
106+
Yields:
107+
None. The context manager sets internal state that is checked by
108+
`_get_test_prior_mean_and_covariances` to dispatch to the appropriate
109+
covariance computation.
110+
"""
111+
prev_state = self._infer_all_components
112+
self._infer_all_components = infer_all_components
113+
try:
114+
yield
115+
finally:
116+
self._infer_all_components = prev_state
117+
118+
def posterior(
119+
self,
120+
X: Tensor,
121+
output_indices: list[int] | None = None,
122+
observation_noise: bool = False,
123+
posterior_transform: PosteriorTransform | None = None,
124+
infer_all_components: bool = False,
125+
) -> GPyTorchPosterior:
126+
"""Posterior inference of the additive Gaussian process.
127+
128+
Args:
129+
X: The input tensor of shape (batch_shape x n x d).
130+
output_indices: Not supported for this model.
131+
observation_noise: Whether to add observation noise to the posterior.
132+
posterior_transform: Optional posterior transform.
133+
infer_all_components: If True, returns a posterior with a batch
134+
dimension corresponding to each additive component (bias, first-order
135+
effects, and optionally second-order interactions). The number of
136+
components is 1 + d (first-order only) or 1 + d + d*(d-1)/2
137+
(with second-order interactions).
138+
139+
Returns:
140+
The posterior distribution at X.
141+
"""
142+
# Use context manager to set inference mode, then delegate to GPyTorch
143+
with self._component_inference_context(infer_all_components):
144+
return super().posterior(
145+
X=X,
146+
output_indices=output_indices,
147+
observation_noise=observation_noise,
148+
posterior_transform=posterior_transform,
149+
)
150+
151+
def _get_test_prior_mean_and_covariances(
152+
self,
153+
train_inputs: list[Tensor],
154+
test_inputs: list[Tensor],
155+
**kwargs,
156+
) -> tuple[
157+
Tensor,
158+
LinearOperator,
159+
LinearOperator,
160+
torch.Size,
161+
torch.Size,
162+
type[Distribution],
163+
]:
164+
"""Dispatches to appropriate covariance computation based on inference mode.
165+
166+
This method is called by GPyTorch's ExactGP.__call__ during posterior
167+
computation. When `_infer_all_components` is True (set via the context
168+
manager), it returns per-component covariances with an extra batch dimension.
169+
Otherwise, it uses the standard kernel forward pass which sums over components.
170+
"""
171+
if self._infer_all_components:
172+
return self._get_test_prior_mean_and_covariances_per_component(
173+
train_inputs, test_inputs, **kwargs
174+
)
175+
return super()._get_test_prior_mean_and_covariances(
176+
train_inputs=train_inputs, test_inputs=test_inputs, **kwargs
177+
)
178+
179+
def _get_test_prior_mean_and_covariances_per_component(
180+
self,
181+
train_inputs: list[Tensor],
182+
test_inputs: list[Tensor],
183+
**kwargs,
184+
) -> tuple[
185+
Tensor,
186+
LinearOperator,
187+
LinearOperator,
188+
torch.Size,
189+
torch.Size,
190+
type[Distribution],
191+
]:
192+
"""Computes mean and covariances with a batch dimension for each component.
193+
194+
This enables posterior inference of individual additive components by returning
195+
covariance matrices with an extra leading batch dimension for each component.
196+
The component dimension is placed BEFORE any input batch dimensions, which is
197+
required for GPyTorch's prediction strategy to broadcast correctly against
198+
the cached training solve.
199+
200+
Returns:
201+
A tuple containing:
202+
- test_mean: (num_components, *input_batch_shape, n_test)
203+
- test_test_covar: (num_components, *input_batch_shape, n_test, n_test)
204+
- test_train_covar: (num_components, *input_batch_shape, n_test, n_train)
205+
- batch_shape: (num_components, *input_batch_shape)
206+
- test_shape: Shape (n_test,)
207+
- posterior_class: MultivariateNormal
208+
"""
209+
if len(train_inputs) != 1 or len(test_inputs) != 1:
210+
raise ValueError(
211+
"OrthogonalAdditiveGP expects a single input X, but received "
212+
f"{len(train_inputs)=}, and {len(test_inputs)=}."
213+
)
214+
215+
X_train = train_inputs[0]
216+
X_test = test_inputs[0]
217+
218+
# Component dim must be the LEADING batch dimension so that GPyTorch's
219+
# prediction strategy (which has mean_cache of shape (*input_batch, n))
220+
# can broadcast correctly: (num_components, *input_batch, m, n) @
221+
# (*input_batch, n, 1) broadcasts on the input_batch dims.
222+
num_components = self.covar_module.num_components
223+
input_batch_shape = X_train.shape[:-2]
224+
batch_shape = torch.Size([num_components]) + input_batch_shape
225+
226+
# The kernel's _non_reduced_forward returns tensors with shape
227+
# (*input_batch_shape, num_components, n, n), but GPyTorch needs the
228+
# component dim first: (num_components, *input_batch_shape, n, n).
229+
# We permute when there are input batch dims.
230+
test_test_covar = self.covar_module._non_reduced_forward(X_test, X_test)
231+
test_train_covar = self.covar_module._non_reduced_forward(X_test, X_train)
232+
233+
n_input_batch = len(input_batch_shape)
234+
if n_input_batch > 0:
235+
# Move component dim (at position n_input_batch) to position 0
236+
test_test_covar = torch.movedim(test_test_covar, n_input_batch, 0)
237+
test_train_covar = torch.movedim(test_train_covar, n_input_batch, 0)
238+
239+
# Prior mean: Only the bias component (index 0) should have the prior mean.
240+
# All other components represent deviations from the mean, so their prior
241+
# mean should be zero. This ensures that when we sum over all components,
242+
# we get the correct total posterior mean (prior mean added once).
243+
n_test = X_test.shape[-2]
244+
# Create a (num_components x batch_shape x n_test) tensor of zeros
245+
test_mean = torch.zeros(
246+
*batch_shape, n_test, dtype=X_test.dtype, device=X_test.device
247+
)
248+
# Set the bias component's mean to the actual prior mean.
249+
# Component dim is first, so [0, ...] selects bias for all batch elements.
250+
test_mean[0, ...] = self.mean_module(X_test)
251+
test_shape = torch.Size([n_test])
252+
253+
return (
254+
test_mean,
255+
test_test_covar,
256+
test_train_covar,
257+
batch_shape,
258+
test_shape,
259+
MultivariateNormal,
260+
)
261+
262+
@property
263+
def component_indices(self) -> dict[str, Tensor]:
264+
"""Returns component indices from the OrthogonalAdditiveKernel."""
265+
return self.covar_module.component_indices
266+
267+
def evaluate_first_order_on_grid(
268+
self,
269+
grid_1d: Tensor,
270+
) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
271+
r"""Evaluate first-order component posteriors on 1D grids.
272+
273+
Uses diagonal test inputs with the existing GPyTorch posterior
274+
infrastructure. Each first-order component is evaluated on its
275+
own independent 1D grid.
276+
277+
Supports models with batch-shaped training data. The grid itself should
278+
be 1D; batch dimensions come from the model's training data.
279+
280+
Args:
281+
grid_1d: 1D tensor of m points in [0, 1].
282+
283+
Returns:
284+
Tuple of:
285+
- bias: (mean, variance) — shape ``(*batch_shape,)`` each,
286+
averaged over the grid (should be approximately constant).
287+
- first_order: (means, variances) — shape
288+
``(d, *batch_shape, m)`` each, posterior statistics on 1D grids.
289+
290+
Example:
291+
>>> grid = torch.linspace(0, 1, 50)
292+
>>> (bias_mean, bias_var), (fo_mean, fo_var) = \
293+
... model.evaluate_first_order_on_grid(grid)
294+
>>> # fo_mean[i] is component i's posterior mean on the 1D grid
295+
"""
296+
m = len(grid_1d)
297+
d = self.covar_module.dim
298+
299+
# Diagonal test inputs: X[k, :] = [t_k, t_k, ..., t_k]
300+
# Each first-order component i sees its own 1D grid on dimension i.
301+
# No batch dims needed here — GPyTorch broadcasts against training data.
302+
X_diag = grid_1d.unsqueeze(-1).expand(m, d)
303+
304+
# Use existing posterior with all-components mode
305+
posterior = self.posterior(X_diag, infer_all_components=True)
306+
307+
# Squeeze output dimension (last dim) since this is single-output
308+
# Shape: (num_components, *batch_shape, m)
309+
mean = posterior.mean.squeeze(-1)
310+
variance = posterior.variance.squeeze(-1)
311+
312+
# Extract bias (component 0) - should be constant across grid.
313+
# Component dim is at position 0, so mean[0] gives (*batch_shape, m).
314+
# Average over the grid dimension (last dim) only, preserving batch dims.
315+
bias_mean = mean[0].mean(dim=-1)
316+
bias_var = variance[0].mean(dim=-1)
317+
318+
# Extract first-order (components 1 to d)
319+
first_order_means = mean[1 : d + 1] # (d, *batch_shape, m)
320+
first_order_vars = variance[1 : d + 1] # (d, *batch_shape, m)
321+
322+
return (bias_mean, bias_var), (first_order_means, first_order_vars)
323+
324+
@property
325+
def num_components(self) -> int:
326+
"""Total number of additive components (bias + first-order [+ second-order])."""
327+
return self.covar_module.num_components
328+
329+
def get_component_index(
330+
self,
331+
component_type: str,
332+
dim_index: int | tuple[int, int] | None = None,
333+
) -> int:
334+
"""Returns the component index for a given component type and dimension.
335+
336+
See OrthogonalAdditiveKernel.get_component_index for details.
337+
"""
338+
return self.covar_module.get_component_index(component_type, dim_index)

0 commit comments

Comments
 (0)