
Commit 1417189

saitcakmak authored and facebook-github-bot committed Sep 5, 2024
Update the remaining models to use new default covar & likelihood modules (#2507)
Summary:
X-link: facebook/Ax#2742
Pull Request resolved: #2507

Updates the default covar & likelihood modules of BoTorch models. See #2451 for details on the new defaults. For models that utilize a composite kernel, such as multi-fidelity/task/context, this change only affects the base kernel.

Exceptions / models that do not utilize the new modules:
- Fully Bayesian models.
- Pairwise GP.
- Higher order GP: produced weird division-by-zero errors after the change.
- Fidelity kernels for MF models.
- (likelihood only) Any model that utilizes a likelihood other than `GaussianLikelihood` (e.g., `MultiTaskGaussianLikelihood`).

Reviewed By: esantorella

Differential Revision: D62196414

fbshipit-source-id: e2c8983a49a9f00d878e1fb7cf346212acb895e9
1 parent 3db1a0e commit 1417189

17 files changed: +74, -115 lines
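For orientation, the two helpers that now provide the defaults can be used directly. A minimal sketch with toy tensor shapes (the actual hyperpriors live in `botorch/models/utils/gpytorch_modules.py`, shown in the diff below):

```python
import torch

from botorch.models.utils.gpytorch_modules import (
    get_covar_module_with_dim_scaled_prior,
    get_gaussian_likelihood_with_lognormal_prior,
)

train_X = torch.rand(20, 3, dtype=torch.double)  # toy training inputs

# RBF kernel with ARD and a lengthscale prior that scales with the input dim.
covar_module = get_covar_module_with_dim_scaled_prior(
    ard_num_dims=train_X.shape[-1],
)
# GaussianLikelihood with a LogNormal noise prior (replacing the Gamma prior).
likelihood = get_gaussian_likelihood_with_lognormal_prior()
```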
 

botorch/models/approximate_gp.py (+6, -11)

@@ -40,8 +40,8 @@
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.models.utils import validate_input_scaling
 from botorch.models.utils.gpytorch_modules import (
-    get_gaussian_likelihood_with_gamma_prior,
-    get_matern_kernel_with_gamma_prior,
+    get_covar_module_with_dim_scaled_prior,
+    get_gaussian_likelihood_with_lognormal_prior,
 )
 from botorch.models.utils.inducing_point_allocators import (
     GreedyVarianceReduction,
@@ -193,7 +193,7 @@ def __init__(
                 this does not have to be all of the training inputs).
             train_Y: Not used.
             num_outputs: Number of output responses per input.
-            covar_module: Kernel function. If omitted, uses a `MaternKernel`.
+            covar_module: Kernel function. If omitted, uses an `RBFKernel`.
             mean_module: Mean of GP model. If omitted, uses a `ConstantMean`.
             variational_distribution: Type of variational distribution to use
                 (default: CholeskyVariationalDistribution), the properties of the
@@ -217,15 +217,10 @@ def __init__(
         self._aug_batch_shape = aug_batch_shape

         if covar_module is None:
-            covar_module = get_matern_kernel_with_gamma_prior(
+            covar_module = get_covar_module_with_dim_scaled_prior(
                 ard_num_dims=train_X.shape[-1],
                 batch_shape=self._aug_batch_shape,
             ).to(train_X)
-            self._subset_batch_dict = {
-                "mean_module.constant": -2,
-                "covar_module.raw_outputscale": -1,
-                "covar_module.base_kernel.raw_lengthscale": -3,
-            }

         if inducing_point_allocator is None:
             inducing_point_allocator = GreedyVarianceReduction()
@@ -343,7 +338,7 @@ def __init__(
                 either a `GaussianLikelihood` (if `num_outputs=1`) or a
                 `MultitaskGaussianLikelihood` (if `num_outputs>1`).
             num_outputs: Number of output responses per input (default: 1).
-            covar_module: Kernel function. If omitted, uses a `MaternKernel`.
+            covar_module: Kernel function. If omitted, uses an `RBFKernel`.
             mean_module: Mean of GP model. If omitted, uses a `ConstantMean`.
             variational_distribution: Type of variational distribution to use
                 (default: CholeskyVariationalDistribution), the properties of the
@@ -378,7 +373,7 @@ def __init__(

         if likelihood is None:
             if num_outputs == 1:
-                likelihood = get_gaussian_likelihood_with_gamma_prior(
+                likelihood = get_gaussian_likelihood_with_lognormal_prior(
                     batch_shape=self._aug_batch_shape
                 )
             else:

botorch/models/contextual_multioutput.py (+1, -1)

@@ -64,7 +64,7 @@ def __init__(
                 is common across all tasks.
             mean_module: The mean function to be used. Defaults to `ConstantMean`.
             covar_module: The module for computing the covariance matrix between
-                the non-task features. Defaults to `MaternKernel`.
+                the non-task features. Defaults to `RBFKernel`.
             likelihood: A likelihood. The default is selected based on `train_Yvar`.
                 If `train_Yvar` is None, a standard `GaussianLikelihood` with inferred
                 noise level is used. Otherwise, a FixedNoiseGaussianLikelihood is used.

botorch/models/gp_regression.py (+2, -1)

@@ -149,7 +149,7 @@ def __init__(
                 is None, and a `FixedNoiseGaussianLikelihood` with the given
                 noise observations if `train_Yvar` is not None.
             covar_module: The module computing the covariance (Kernel) matrix.
-                If omitted, use a `MaternKernel`.
+                If omitted, uses an `RBFKernel`.
             mean_module: The mean function to be used. If omitted, use a
                 `ConstantMean`.
             outcome_transform: An outcome transform that is applied to the
@@ -207,6 +207,7 @@ def __init__(
                 ard_num_dims=transformed_X.shape[-1],
                 batch_shape=self._aug_batch_shape,
             )
+        # Used for subsetting along the output dimension. See Model.subset_output.
         self._subset_batch_dict = {
             "mean_module.raw_constant": -1,
             "covar_module.raw_lengthscale": -3,

botorch/models/gp_regression_fidelity.py (+3, -4)

@@ -26,7 +26,6 @@
 from __future__ import annotations

 import warnings
-
 from typing import Any, Optional, Union

 import torch
@@ -39,9 +38,9 @@
 )
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
+from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
 from botorch.utils.datasets import SupervisedDataset
 from gpytorch.kernels.kernel import ProductKernel
-from gpytorch.kernels.rbf_kernel import RBFKernel
 from gpytorch.kernels.scale_kernel import ScaleKernel
 from gpytorch.likelihoods.likelihood import Likelihood
 from gpytorch.priors.torch_priors import GammaPrior
@@ -153,6 +152,7 @@ def __init__(
             outcome_transform=outcome_transform,
             input_transform=input_transform,
         )
+        # Used for subsetting along the output dimension. See Model.subset_output.
         self._subset_batch_dict = {
             "mean_module.raw_constant": -1,
             "covar_module.raw_outputscale": -1,
@@ -273,10 +273,9 @@ def _setup_multifidelity_covar_module(
         non_active_dims.add(iteration_fidelity)
     active_dimsX = sorted(set(range(dim)) - non_active_dims)
     kernels.append(
-        RBFKernel(
+        get_covar_module_with_dim_scaled_prior(
            ard_num_dims=len(active_dimsX),
            batch_shape=aug_batch_shape,
-            lengthscale_prior=GammaPrior(3.0, 6.0),
            active_dims=active_dimsX,
        )
    )

botorch/models/gp_regression_mixed.py (+3, -28)

@@ -13,15 +13,13 @@
 from botorch.models.kernels.categorical import CategoricalKernel
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
+from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
 from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.transforms import normalize_indices
 from gpytorch.constraints import GreaterThan
 from gpytorch.kernels.kernel import Kernel
-from gpytorch.kernels.matern_kernel import MaternKernel
 from gpytorch.kernels.scale_kernel import ScaleKernel
-from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
 from gpytorch.likelihoods.likelihood import Likelihood
-from gpytorch.priors import GammaPrior
 from torch import Tensor

@@ -82,7 +80,7 @@ def __init__(
             cont_kernel_factory: A method that accepts `batch_shape`, `ard_num_dims`,
                 and `active_dims` arguments and returns an instantiated GPyTorch
                 `Kernel` object to be used as the base kernel for the continuous
-                dimensions. If omitted, this model uses a Matern-2.5 kernel as
+                dimensions. If omitted, this model uses an `RBFKernel` as
                 the kernel for the ordinal parameters.
             likelihood: A likelihood. If omitted, use a standard
                 GaussianLikelihood with inferred noise level.
@@ -105,30 +103,7 @@ def __init__(
         _, aug_batch_shape = self.get_batch_dimensions(train_X=train_X, train_Y=train_Y)

         if cont_kernel_factory is None:
-
-            def cont_kernel_factory(
-                batch_shape: torch.Size,
-                ard_num_dims: int,
-                active_dims: list[int],
-            ) -> MaternKernel:
-                return MaternKernel(
-                    nu=2.5,
-                    batch_shape=batch_shape,
-                    ard_num_dims=ard_num_dims,
-                    active_dims=active_dims,
-                    lengthscale_constraint=GreaterThan(1e-04),
-                )
-
-        if likelihood is None and train_Yvar is None:
-            # This Gamma prior is quite close to the Horseshoe prior
-            min_noise = 1e-5 if train_X.dtype == torch.float else 1e-6
-            likelihood = GaussianLikelihood(
-                batch_shape=aug_batch_shape,
-                noise_constraint=GreaterThan(
-                    min_noise, transform=None, initial_value=1e-3
-                ),
-                noise_prior=GammaPrior(0.9, 10.0),
-            )
+            cont_kernel_factory = get_covar_module_with_dim_scaled_prior

         d = train_X.shape[-1]
         cat_dims = normalize_indices(indices=cat_dims, d=d)
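Since the inline Matérn factory above was removed, callers who want the previous behavior can pass their own factory to `MixedSingleTaskGP`. A sketch that simply repackages the deleted default (the factory name is hypothetical):

```python
import torch
from gpytorch.constraints import GreaterThan
from gpytorch.kernels.matern_kernel import MaternKernel


def matern_cont_kernel_factory(
    batch_shape: torch.Size,
    ard_num_dims: int,
    active_dims: list[int],
) -> MaternKernel:
    # Reproduces the removed default: a Matern-2.5 kernel with ARD and a
    # lower bound on the lengthscale.
    return MaternKernel(
        nu=2.5,
        batch_shape=batch_shape,
        ard_num_dims=ard_num_dims,
        active_dims=active_dims,
        lengthscale_constraint=GreaterThan(1e-04),
    )


# Usage: MixedSingleTaskGP(..., cont_kernel_factory=matern_cont_kernel_factory)
```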

botorch/models/gpytorch.py (+3, -1)

@@ -543,7 +543,9 @@ def subset_output(self, idcs: list[int]) -> BatchedMultiOutputGPyTorchModel:
             subset_batch_dict = self._subset_batch_dict
         except AttributeError:
             raise NotImplementedError(
-                "subset_output requires the model to define a `_subset_dict` attribute"
+                "`subset_output` requires the model to define a `_subset_batch_dict` "
+                "attribute that lists the indices of the output dimensions in each "
+                "model parameter that needs to be subset."
             )

         m = len(idcs)
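For context, `subset_output` comes into play when a batched multi-output model is reduced to a subset of its outputs. A minimal sketch with toy data (standard `SingleTaskGP` usage assumed):

```python
import torch
from botorch.models.gp_regression import SingleTaskGP

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.rand(10, 2, dtype=torch.double)  # two outputs -> batched model

model = SingleTaskGP(train_X, train_Y)
# Uses the model's `_subset_batch_dict` to subset the output/batch dimension
# of each parameter; raises NotImplementedError if it is not defined.
model_0 = model.subset_output(idcs=[0])
```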

botorch/models/kernels/contextual_lcea.py (+3, -7)

@@ -7,9 +7,9 @@
 from typing import Any, Optional

 import torch
+from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
 from gpytorch.constraints import Positive
 from gpytorch.kernels.kernel import Kernel
-from gpytorch.kernels.matern_kernel import MaternKernel
 from gpytorch.priors.torch_priors import GammaPrior
 from linear_operator.operators import DiagLinearOperator
 from linear_operator.operators.dense_linear_operator import DenseLinearOperator
@@ -158,18 +158,14 @@ def __init__(
         if train_embedding:
             self._set_emb_layers()
         # task covariance matrix
-        self.task_covar_module = MaternKernel(
-            nu=2.5,
+        self.task_covar_module = get_covar_module_with_dim_scaled_prior(
             ard_num_dims=self.n_embs,
             batch_shape=batch_shape,
-            lengthscale_prior=GammaPrior(3.0, 6.0),
         )
         # base kernel
-        self.base_kernel = MaternKernel(
-            nu=2.5,
+        self.base_kernel = get_covar_module_with_dim_scaled_prior(
             ard_num_dims=self.num_param,
             batch_shape=batch_shape,
-            lengthscale_prior=GammaPrior(3.0, 6.0),
         )
         # outputscales for each context (note this is like sqrt of outputscale)
         self.context_weight = None

botorch/models/kernels/contextual_sac.py (+3, -5)

@@ -7,8 +7,8 @@
 from typing import Any, Optional

 import torch
+from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
 from gpytorch.kernels.kernel import Kernel
-from gpytorch.kernels.matern_kernel import MaternKernel
 from gpytorch.kernels.scale_kernel import ScaleKernel
 from gpytorch.priors.torch_priors import GammaPrior
 from linear_operator.operators.sum_linear_operator import SumLinearOperator
@@ -36,7 +36,7 @@ class SACKernel(Kernel):
     where
     * :math: M is the number of partitions of parameter space. Each partition contains
         same number of parameters d. Each kernel `k_i` acts only on d parameters of ith
-        partition i.e. `\mathbf{x}_(i)`. Each kernel `k_i` is a scaled Matern kernel
+        partition i.e. `\mathbf{x}_(i)`. Each kernel `k_i` is a scaled RBF kernel
         with same lengthscales but different outputscales.
     """

@@ -72,11 +72,9 @@ def __init__(
             for context, active_params in self.decomposition.items()
         }

-        self.base_kernel = MaternKernel(
-            nu=2.5,
+        self.base_kernel = get_covar_module_with_dim_scaled_prior(
             ard_num_dims=num_param,
             batch_shape=batch_shape,
-            lengthscale_prior=GammaPrior(3.0, 6.0),
         )

         self.kernel_dict = {}  # scaled kernel for each parameter space partition

botorch/models/multitask.py (+8, -13)

@@ -40,7 +40,8 @@
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.models.utils.gpytorch_modules import (
-    get_matern_kernel_with_gamma_prior,
+    get_covar_module_with_dim_scaled_prior,
+    get_gaussian_likelihood_with_lognormal_prior,
     MIN_INFERRED_NOISE_LEVEL,
 )
 from botorch.posteriors.multitask import MultitaskGPPosterior
@@ -51,12 +52,8 @@
 )
 from gpytorch.distributions.multivariate_normal import MultivariateNormal
 from gpytorch.kernels.index_kernel import IndexKernel
-from gpytorch.kernels.matern_kernel import MaternKernel
 from gpytorch.kernels.multitask_kernel import MultitaskKernel
-from gpytorch.likelihoods.gaussian_likelihood import (
-    FixedNoiseGaussianLikelihood,
-    GaussianLikelihood,
-)
+from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
 from gpytorch.likelihoods.likelihood import Likelihood
 from gpytorch.likelihoods.multitask_gaussian_likelihood import (
     MultitaskGaussianLikelihood,
@@ -167,7 +164,7 @@ def __init__(
                 Note that the inferred noise is common across all tasks.
             mean_module: The mean function to be used. Defaults to `ConstantMean`.
             covar_module: The module for computing the covariance matrix between
-                the non-task features. Defaults to `MaternKernel`.
+                the non-task features. Defaults to `RBFKernel`.
             likelihood: A likelihood. The default is selected based on `train_Yvar`.
                 If `train_Yvar` is None, a standard `GaussianLikelihood` with inferred
                 noise level is used. Otherwise, a FixedNoiseGaussianLikelihood is used.
@@ -233,7 +230,7 @@ def __init__(
         # TODO (T41270962): Support task-specific noise levels in likelihood
         if likelihood is None:
             if train_Yvar is None:
-                likelihood = GaussianLikelihood(noise_prior=GammaPrior(1.1, 0.05))
+                likelihood = get_gaussian_likelihood_with_lognormal_prior()
             else:
                 likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar.squeeze(-1))

@@ -247,7 +244,7 @@ def __init__(
         )
         self.mean_module = mean_module or ConstantMean()
         if covar_module is None:
-            self.covar_module = get_matern_kernel_with_gamma_prior(
+            self.covar_module = get_covar_module_with_dim_scaled_prior(
                 ard_num_dims=self.num_non_task_features
             )
         else:
@@ -442,7 +439,7 @@ def __init__(
                 `MultitaskGaussianLikelihood` with a `GammaPrior(1.1, 0.05)`
                 noise prior.
             data_covar_module: The module computing the covariance (Kernel) matrix
-                in data space. If omitted, use a `MaternKernel`.
+                in data space. If omitted, uses an `RBFKernel`.
             task_covar_prior : A Prior on the task covariance matrix. Must operate
                 on p.s.d. matrices. A common prior for this is the `LKJ` prior. If
                 omitted, uses `LKJCovariancePrior` with `eta` parameter as specified
@@ -500,10 +497,8 @@ def __init__(
             base_means=ConstantMean(batch_shape=batch_shape), num_tasks=num_tasks
         )
         if data_covar_module is None:
-            data_covar_module = MaternKernel(
-                nu=2.5,
+            data_covar_module = get_covar_module_with_dim_scaled_prior(
                 ard_num_dims=ard_num_dims,
-                lengthscale_prior=GammaPrior(3.0, 6.0),
                 batch_shape=batch_shape,
             )
         else:

botorch/models/utils/gpytorch_modules.py (+8, -2)

@@ -18,7 +18,7 @@
 """

 from math import log, sqrt
-from typing import Optional, Union
+from typing import Optional, Sequence, Union

 import torch
 from gpytorch.constraints.constraints import GreaterThan
@@ -101,14 +101,18 @@ def get_covar_module_with_dim_scaled_prior(
     ard_num_dims: int,
     batch_shape: Optional[torch.Size] = None,
     use_rbf_kernel: bool = True,
-) -> Union[MaternKernel, RBFKernel, ScaleKernel]:
+    active_dims: Optional[Sequence[int]] = None,
+) -> Union[MaternKernel, RBFKernel]:
     """Returns an RBF or Matern kernel with priors
     from [Hvarfner2024vanilla]_.

     Args:
         ard_num_dims: Number of feature dimensions for ARD.
         batch_shape: Batch shape for the covariance module.
         use_rbf_kernel: Whether to use an RBF kernel. If False, uses a Matern kernel.
+        active_dims: The set of input dimensions to compute the covariances on.
+            By default, the covariance is computed using the full input tensor.
+            Set this if you'd like to ignore certain dimensions.

     Returns:
         A Kernel constructed according to the given arguments. The prior is constrained
@@ -123,5 +127,7 @@ def get_covar_module_with_dim_scaled_prior(
         lengthscale_constraint=GreaterThan(
             2.5e-2, transform=None, initial_value=lengthscale_prior.mode
         ),
+        # pyre-ignore[6] GPyTorch type is unnecessarily restrictive.
+        active_dims=active_dims,
     )
     return base_kernel
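A sketch of the new `active_dims` argument (toy shapes assumed), restricting the kernel to a subset of input dimensions the way the multi-fidelity models restrict the base kernel to non-fidelity dimensions:

```python
import torch

from botorch.models.utils.gpytorch_modules import (
    get_covar_module_with_dim_scaled_prior,
)

active_dims = [0, 2]  # ignore dims 1 and 3 of a 4-d input
kernel = get_covar_module_with_dim_scaled_prior(
    ard_num_dims=len(active_dims),
    active_dims=active_dims,
)
X = torch.rand(5, 4, dtype=torch.double)
K = kernel(X).to_dense()  # 5 x 5 covariance computed on dims 0 and 2 only
```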

docs/models.md (+9, -5)

@@ -121,10 +121,14 @@ instead.
   a fully Bayesian multi-task GP using an ICM kernel. The data kernel uses the
   SAAS prior to model high-dimensional parameter spaces.

-All of the above models use Matérn 5/2 kernels with Automatic Relevance
-Discovery (ARD), and have reasonable priors on hyperparameters that make them
-work well in settings where the **input features are normalized to the unit
-cube** and the **observations are standardized** (zero mean, unit variance).
+All of the above models use RBF kernels with Automatic Relevance Discovery
+(ARD), and have reasonable priors on hyperparameters that make them work well in
+settings where the **input features are normalized to the unit cube** and the
+**observations are standardized** (zero mean, unit variance). The lengthscale
+priors scale with the input dimension, which makes them adaptable to both low
+and high dimensional problems. See
+[this discussion](https://github.com/pytorch/botorch/discussions/2451) for
+additional context on the default hyperparameters.

 ## Other useful models

@@ -182,6 +186,6 @@ model. If you wish to use gradient-based optimization algorithms, the model
 should allow back-propagating gradients through the samples to the model input.

 If you happen to implement a model that would be useful for other researchers as
-well (and involves more than just swapping out the Matérn kernel for an RBF
+well (and involves more than just swapping out the RBF kernel for a Matérn
 kernel), please consider [contributing](getting_started#contributing) this model
 to BoTorch.
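The normalization/standardization assumption called out in the docs change above is commonly satisfied via BoTorch's built-in transforms. A minimal sketch with toy data:

```python
import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize

train_X = 5.0 * torch.rand(20, 3, dtype=torch.double)  # raw, un-normalized inputs
train_Y = torch.randn(20, 1, dtype=torch.double)

# Normalize inputs to the unit cube and standardize observations so that the
# default dimension-scaled priors are well matched to the data.
model = SingleTaskGP(
    train_X,
    train_Y,
    input_transform=Normalize(d=train_X.shape[-1]),
    outcome_transform=Standardize(m=1),
)
```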

test/acquisition/multi_objective/test_monte_carlo.py (+1, -1)

@@ -1841,7 +1841,7 @@ def test_with_multitask(self):
     def _test_with_multitask(self, acqf_class: type[AcquisitionFunction]):
         # Verify that _set_sampler works with MTGP, KroneckerMTGP and HOGP.
         torch.manual_seed(1234)
-        tkwargs = {"device": self.device, "dtype": torch.double}
+        tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
         train_x = torch.rand(6, 2, **tkwargs)
         train_y = torch.randn(6, 2, **tkwargs)
         mtgp_task = torch.cat(
There was a problem loading the remainder of the diff; the remaining 5 of the 17 changed files are not shown above.
