Skip to content

Commit d1e4369

Browse files
sdaulton and facebook-github-bot
authored and committed
add utility for computing AIC/BIC/MLL from a model (pytorch#2785)
Summary: Add utility for computing in-sample model fit metrics. Differential Revision: D71827991
1 parent 83f50f4 commit d1e4369

File tree

3 files changed

+107
-0
lines changed

3 files changed

+107
-0
lines changed

botorch/utils/evaluation.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from math import log
8+
9+
import torch
10+
from botorch.models.model import Model
11+
from botorch.utils.transforms import is_fully_bayesian
12+
13+
# Names of the supported in-sample model fit criteria. These are the only
# values accepted for the `criterion` argument of
# `compute_in_sample_model_fit_metric`; any other string raises a ValueError.
MLL = "MLL"
14+
AIC = "AIC"
15+
BIC = "BIC"
16+
17+
18+
def compute_in_sample_model_fit_metric(model: Model, criterion: str) -> float:
    """Compute an in-sample model fit metric.

    The model is evaluated on its own training data: the likelihood is applied
    to the model output at ``model.train_inputs`` and the log probability of
    ``model.train_targets`` under it is averaged into a single scalar (the
    "MLL" value). AIC and BIC are derived from that scalar as

        AIC = 2 * k - 2 * MLL
        BIC = k * log(n) - 2 * MLL

    where ``k`` is the total number of elements across ``model.parameters()``
    (divided by the leading dimension of the log probability for fully
    Bayesian models) and ``n`` is ``model.train_inputs[0].shape[-2]``.

    Note: as a side effect, the model is always put in eval mode before this
    function returns or propagates an exception from the model evaluation.

    Args:
        model: A fitted model.
        criterion: Evaluation criterion. One of "MLL", "AIC", "BIC".

    Returns:
        The in-sample evaluation metric.

    Raises:
        ValueError: If ``criterion`` is not one of "MLL", "AIC", "BIC".
    """
    if criterion not in (AIC, BIC, MLL):
        raise ValueError(f"Invalid evaluation criterion {criterion}.")
    fully_bayesian = is_fully_bayesian(model=model)
    # The model is evaluated at its training inputs in train mode.
    if fully_bayesian:
        # NOTE(review): `reset=False` presumably keeps the fitted MCMC samples
        # loaded while switching to train mode -- confirm against the model.
        model.train(reset=False)
    else:
        model.train()
    try:
        with torch.no_grad():
            output = model.likelihood(model(*model.train_inputs))
            mll = output.log_prob(model.train_targets)
            # compute average MLL over MCMC samples if the model is fully bayesian
            mll_scalar = mll.mean().item()
    finally:
        # Never leave the model stuck in train mode, even if evaluation fails.
        model.eval()
    if criterion == MLL:
        # The parameter count is only needed for the penalized criteria.
        return mll_scalar
    num_params = sum(p.numel() for p in model.parameters())
    if fully_bayesian:
        # Parameters are stored once per MCMC sample; count them only once.
        num_params /= mll.shape[0]
    if criterion == AIC:
        return 2 * num_params - 2 * mll_scalar
    # Remaining valid criterion is BIC.
    return num_params * log(model.train_inputs[0].shape[-2]) - 2 * mll_scalar

sphinx/source/utils.rst

+5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ Dispatcher
3232
.. automodule:: botorch.utils.dispatcher
3333
:members:
3434

35+
Evaluation
36+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
37+
.. automodule:: botorch.utils.evaluation
38+
:members:
39+
3540
Low-Rank Cholesky Update Utils
3641
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3742
.. automodule:: botorch.utils.low_rank

test/utils/test_evaluation.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
from itertools import product
8+
from math import log, pi
9+
10+
import torch
11+
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
12+
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
13+
from botorch.models.gp_regression import SingleTaskGP
14+
from botorch.test_utils.mock import mock_optimize
15+
from botorch.utils.evaluation import AIC, BIC, compute_in_sample_model_fit_metric, MLL
16+
from botorch.utils.testing import BotorchTestCase
17+
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
18+
19+
20+
class TestEvaluation(BotorchTestCase):
    """Tests for the in-sample model fit metric utility."""

    @mock_optimize
    def test_compute_in_sample_model_fit_metric(self):
        # AIC and BIC must equal the closed-form expressions built from the
        # reported MLL, for both an exact GP and a fully Bayesian GP.
        torch.manual_seed(0)
        for dtype in (torch.float, torch.double):
            for model_cls in (SingleTaskGP, SaasFullyBayesianSingleTaskGP):
                grid = torch.linspace(0, 1, 10, dtype=dtype, device=self.device)
                train_X = grid.unsqueeze(-1)
                train_Y = torch.sin(2 * pi * train_X)
                model = model_cls(train_X=train_X, train_Y=train_Y)
                if model_cls is SingleTaskGP:
                    exact_mll = ExactMarginalLogLikelihood(model.likelihood, model)
                    fit_gpytorch_mll(exact_mll)
                else:
                    fit_fully_bayesian_model_nuts(
                        model,
                        warmup_steps=8,
                        num_samples=6,
                        thinning=2,
                        disable_progbar=True,
                    )
                num_params = sum(p.numel() for p in model.parameters())
                if model_cls is SaasFullyBayesianSingleTaskGP:
                    # divide by number of MCMC samples
                    num_params /= 3
                # Evaluate the metrics in the order MLL, AIC, BIC.
                metrics = {
                    name: compute_in_sample_model_fit_metric(
                        model=model, criterion=name
                    )
                    for name in (MLL, AIC, BIC)
                }
                mll = metrics[MLL]
                self.assertEqual(metrics[AIC], 2 * num_params - 2 * mll)
                self.assertEqual(metrics[BIC], log(10) * num_params - 2 * mll)
                # An unknown criterion name must be rejected.
                with self.assertRaisesRegex(
                    ValueError, "Invalid evaluation criterion invalid."
                ):
                    compute_in_sample_model_fit_metric(
                        model=model, criterion="invalid"
                    )

0 commit comments

Comments
 (0)