1515import torch
1616from botorch .exceptions .errors import BotorchTensorDimensionError
1717from botorch .exceptions .warnings import UserInputWarning
18+ from botorch .fit import fit_gpytorch_mll
19+ from botorch .models .multitask import MultiTaskGP
1820from botorch .models .transforms .input import (
1921 AffineInputTransform ,
2022 AppendFeatures ,
3638)
3739from botorch .models .transforms .utils import expand_and_copy_tensor
3840from botorch .models .utils import fantasize
41+ from botorch .test_utils .mock import mock_optimize_context_manager
3942from botorch .utils .testing import BotorchTestCase
4043from gpytorch import Module as GPyTorchModule
44+ from gpytorch .mlls import ExactMarginalLogLikelihood
4145from gpytorch .priors import LogNormalPrior
4246from torch import Tensor
4347from torch .distributions import Kumaraswamy
@@ -1657,7 +1661,9 @@ def test_learned_feature_imputation(self) -> None:
16571661 ** tkwargs ,
16581662 )
16591663 self .assertEqual (tf .num_tasks , 2 )
1660- self .assertEqual (tf .raw_imputation_values .shape , torch .Size ([2 , d + 1 ]))
1664+ self .assertEqual (
1665+ tf .raw_imputation_values .shape , torch .Size ([2 * (d + 1 )])
1666+ )
16611667 # missing_mask: shape (num_tasks, d+1), task col always False.
16621668 self .assertTrue (tf .missing_mask [0 , 3 ].item ())
16631669 self .assertFalse (tf .missing_mask [0 , 0 ].item ())
@@ -1671,10 +1677,7 @@ def test_learned_feature_imputation(self) -> None:
16711677 ** tkwargs ,
16721678 )
16731679 tf .raw_imputation_values .data = torch .tensor (
1674- [
1675- [0.0 , 0.0 , 0.0 , 0.5 , 0.0 ],
1676- [0.0 , 0.0 , 0.7 , 0.0 , 0.0 ],
1677- ],
1680+ [0.0 , 0.0 , 0.0 , 0.5 , 0.0 , 0.0 , 0.0 , 0.7 , 0.0 , 0.0 ],
16781681 ** tkwargs ,
16791682 )
16801683 X = torch .tensor (
@@ -1746,10 +1749,11 @@ def test_learned_feature_imputation(self) -> None:
17461749 tf (X_grad ).sum ().backward ()
17471750 grad = tf .raw_imputation_values .grad
17481751 self .assertIsNotNone (grad )
1749- self .assertNotEqual (grad [1 , 1 ].item (), 0.0 )
1752+ # d=2 → stride is d+1=3. Task 1, feature 1 → index 4.
1753+ self .assertNotEqual (grad [4 ].item (), 0.0 )
17501754 # Task 0 observes both features → no imputation → no grad.
1751- self .assertEqual (grad [0 , 0 ].item (), 0.0 )
1752- self .assertEqual (grad [0 , 1 ].item (), 0.0 )
1755+ self .assertEqual (grad [0 ].item (), 0.0 )
1756+ self .assertEqual (grad [1 ].item (), 0.0 )
17531757
17541758 with self .subTest ("untransform_raises" , dtype = dtype ):
17551759 tf = LearnedFeatureImputation (feature_indices = {0 : [0 ]}, d = 1 , ** tkwargs )
@@ -1769,10 +1773,7 @@ def test_learned_feature_imputation(self) -> None:
17691773 ** tkwargs ,
17701774 )
17711775 tf .raw_imputation_values .data = torch .tensor (
1772- [
1773- [0.0 , 0.0 , 0.0 , 0.5 , 0.6 , 0.0 ],
1774- [0.0 , 0.0 , 0.7 , 0.0 , 0.0 , 0.0 ],
1775- ],
1776+ [0.0 , 0.0 , 0.0 , 0.5 , 0.6 , 0.0 , 0.0 , 0.0 , 0.7 , 0.0 , 0.0 , 0.0 ],
17761777 ** tkwargs ,
17771778 )
17781779 # Imputation with asymmetric observed feature counts.
@@ -1853,7 +1854,8 @@ def test_learned_feature_imputation(self) -> None:
18531854 tf_g (
18541855 torch .tensor ([[1.0 , 2.0 , 0.0 ], [3.0 , 9.0 , 1.0 ]], ** tkwargs )
18551856 ).sum ().backward ()
1856- self .assertNotEqual (tf_g .raw_imputation_values .grad [1 , 1 ].item (), 0.0 )
1857+ # d=2 → stride is 3. Task 1, feature 1 → index 4.
1858+ self .assertNotEqual (tf_g .raw_imputation_values .grad [4 ].item (), 0.0 )
18571859
18581860 with self .subTest ("three_tasks" , dtype = dtype ):
18591861 tf = LearnedFeatureImputation (
@@ -1862,11 +1864,7 @@ def test_learned_feature_imputation(self) -> None:
18621864 ** tkwargs ,
18631865 )
18641866 tf .raw_imputation_values .data = torch .tensor (
1865- [
1866- [0.0 , 0.0 , 0.3 , 0.0 ],
1867- [0.4 , 0.0 , 0.0 , 0.0 ],
1868- [0.0 , 0.5 , 0.0 , 0.0 ],
1869- ],
1867+ [0.0 , 0.0 , 0.3 , 0.0 , 0.4 , 0.0 , 0.0 , 0.0 , 0.0 , 0.5 , 0.0 , 0.0 ],
18701868 ** tkwargs ,
18711869 )
18721870 X_three_tasks = torch .tensor (
@@ -1895,7 +1893,7 @@ def test_learned_feature_imputation(self) -> None:
18951893 )
18961894 )
18971895 tf .raw_imputation_values .data = torch .tensor (
1898- [[ 0.0 , 0.0 , 0.0 , 0.5 , 0.0 ], [ 0.0 , 0.0 , 0.7 , 0.0 , 0.0 ] ],
1896+ [0.0 , 0.0 , 0.0 , 0.5 , 0.0 , 0.0 , 0.0 , 0.7 , 0.0 , 0.0 ],
18991897 ** tkwargs ,
19001898 )
19011899 X_noncontig = torch .tensor (
@@ -1917,7 +1915,7 @@ def test_learned_feature_imputation(self) -> None:
19171915 ** tkwargs ,
19181916 )
19191917 tf .raw_imputation_values .data = torch .tensor (
1920- [[ 0.0 , 0.0 , 0.0 , 0.5 , 0.0 ], [ 0.0 , 0.0 , 0.7 , 0.0 , 0.0 ] ],
1918+ [0.0 , 0.0 , 0.0 , 0.5 , 0.0 , 0.0 , 0.0 , 0.7 , 0.0 , 0.0 ],
19211919 ** tkwargs ,
19221920 )
19231921 X_batch = torch .tensor (
@@ -1940,7 +1938,7 @@ def test_learned_feature_imputation(self) -> None:
19401938 ** tkwargs ,
19411939 )
19421940 tf .raw_imputation_values .data = torch .tensor (
1943- [[ 0.0 , 0.0 , 0.0 , 0.5 , 0.0 ], [ 0.0 , 0.0 , 0.7 , 0.0 , 0.0 ] ],
1941+ [0.0 , 0.0 , 0.0 , 0.5 , 0.0 , 0.0 , 0.0 , 0.7 , 0.0 , 0.0 ],
19441942 ** tkwargs ,
19451943 )
19461944 X_no_task = torch .tensor (
@@ -1984,6 +1982,47 @@ def test_learned_feature_imputation(self) -> None:
19841982 with self .assertRaisesRegex (ValueError , "Expected X.shape" ):
19851983 tf (torch .zeros (2 , d + 3 , ** tkwargs ))
19861984
1985+ with self .subTest ("fit_gpytorch_mll_with_bounds" , dtype = dtype ):
1986+ n = 5
1987+ X = torch .cat (
1988+ [
1989+ torch .cat (
1990+ [
1991+ torch .rand (n , d , ** tkwargs ),
1992+ i * torch .ones (n , 1 , ** tkwargs ),
1993+ ],
1994+ dim = - 1 ,
1995+ )
1996+ for i in range (len (feature_indices ))
1997+ ]
1998+ )
1999+ Y = torch .randn (len (feature_indices ) * n , 1 , ** tkwargs )
2000+ bounds = torch .stack (
2001+ [torch .zeros (d , ** tkwargs ), torch .ones (d , ** tkwargs )]
2002+ )
2003+ lfi = LearnedFeatureImputation (
2004+ feature_indices = feature_indices , d = d , bounds = bounds , ** tkwargs
2005+ )
2006+ model = MultiTaskGP (
2007+ train_X = X ,
2008+ train_Y = Y ,
2009+ task_feature = - 1 ,
2010+ input_transform = ChainedInputTransform (
2011+ tf0 = Normalize (d = d + 1 , indices = list (range (d ))),
2012+ tf1 = lfi ,
2013+ ),
2014+ )
2015+ mll = ExactMarginalLogLikelihood (model .likelihood , model )
2016+ with mock_optimize_context_manager ():
2017+ fit_gpytorch_mll (mll , max_attempts = 1 )
2018+ imp = lfi .imputation_values
2019+ # These are the two learnable imputation values under the current
2020+ # setup, so these should be the only non-zero values.
2021+ self .assertNotEqual (imp [0 , 3 ].item (), 0.0 )
2022+ self .assertNotEqual (imp [1 , 2 ].item (), 0.0 )
2023+ # and here is one that should be zero.
2024+ self .assertNotEqual (imp [0 , 1 ].item (), 0.0 )
2025+
19872026
19882027class TestAppendFeatures (BotorchTestCase ):
19892028 def test_append_features (self ) -> None :
0 commit comments