|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the MIT license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +# pyre-strict |
| 7 | +from itertools import product |
| 8 | +from math import log, pi |
| 9 | + |
| 10 | +import torch |
| 11 | +from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll |
| 12 | +from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP |
| 13 | +from botorch.models.gp_regression import SingleTaskGP |
| 14 | +from botorch.test_utils.mock import mock_optimize |
| 15 | +from botorch.utils.evaluation import AIC, BIC, compute_in_sample_model_fit_metric, MLL |
| 16 | +from botorch.utils.testing import BotorchTestCase |
| 17 | +from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood |
| 18 | + |
| 19 | + |
| 20 | +class TestEvaluation(BotorchTestCase): |
| 21 | + @mock_optimize |
| 22 | + def test_compute_in_sample_model_fit_metric(self): |
| 23 | + torch.manual_seed(0) |
| 24 | + for dtype, model_cls in product( |
| 25 | + (torch.float, torch.double), (SingleTaskGP, SaasFullyBayesianSingleTaskGP) |
| 26 | + ): |
| 27 | + train_X = torch.linspace( |
| 28 | + 0, 1, 10, dtype=dtype, device=self.device |
| 29 | + ).unsqueeze(-1) |
| 30 | + train_Y = torch.sin(2 * pi * train_X) |
| 31 | + model = model_cls(train_X=train_X, train_Y=train_Y) |
| 32 | + if model_cls is SingleTaskGP: |
| 33 | + fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model)) |
| 34 | + else: |
| 35 | + fit_fully_bayesian_model_nuts( |
| 36 | + model, |
| 37 | + warmup_steps=8, |
| 38 | + num_samples=6, |
| 39 | + thinning=2, |
| 40 | + disable_progbar=True, |
| 41 | + ) |
| 42 | + num_params = sum(p.numel() for p in model.parameters()) |
| 43 | + if model_cls is SaasFullyBayesianSingleTaskGP: |
| 44 | + num_params /= 3 # divide by number of MCMC samples |
| 45 | + mll = compute_in_sample_model_fit_metric(model=model, criterion=MLL) |
| 46 | + aic = compute_in_sample_model_fit_metric(model=model, criterion=AIC) |
| 47 | + bic = compute_in_sample_model_fit_metric(model=model, criterion=BIC) |
| 48 | + self.assertEqual(aic, 2 * num_params - 2 * mll) |
| 49 | + self.assertEqual(bic, log(10) * num_params - 2 * mll) |
0 commit comments