Skip to content

Commit e5dae52

Browse files
Carl Hvarfner authored and facebook-github-bot committed
Flatten LearnedFeatureImputation raw parameter for scipy fitting compatibility (#3299)
Summary: The `raw_imputation_values` parameter was shape `(num_tasks, d+1)`, but its Interval constraint bounds were only `(d+1,)`. When the scipy fitting path in `get_bounds_as_ndarray` flattens the parameter to `num_tasks*(d+1)` elements, it tries to assign `(d+1,)` bounds into a slice of that length, causing a ValueError. Flatten `raw_imputation_values` to 1-D with shape `(num_tasks*(d+1),)` and repeat the Interval bounds to match. The `imputation_values` property reshapes back to `(num_tasks, d+1)` for use in `transform()`. Adds a `fit_gpytorch_mll_with_bounds` subtest that fits a MultiTaskGP with Normalize + LearnedFeatureImputation (with bounds) through the scipy optimizer path, verifying no shape mismatch occurs. Differential Revision: D102789747
1 parent ea00d74 commit e5dae52

2 files changed

Lines changed: 71 additions & 34 deletions

File tree

botorch/models/transforms/input.py

Lines changed: 11 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -2042,38 +2042,36 @@ def __init__(
20422042
missing_mask[task_pos, feature_indices[task_value]] = False
20432043
self.register_buffer("missing_mask", missing_mask)
20442044

2045-
# Learnable imputation values, shape (num_tasks, d+1). The task column
2046-
# slot is unused but kept for index alignment with X columns.
2045+
# Learnable imputation values stored as 1-D so that gpytorch's scipy
2046+
# fitting path (which flattens parameters) sees a bound tensor with
2047+
# matching numel. Reshaped to (num_tasks, d+1) in `imputation_values`.
20472048
self.register_parameter(
20482049
"raw_imputation_values",
20492050
nn.Parameter(
2050-
torch.zeros(self.num_tasks, d + 1, dtype=dtype, device=device)
2051+
torch.zeros(self.num_tasks * (d + 1), dtype=dtype, device=device)
20512052
),
20522053
)
20532054
if bounds is not None:
2054-
# Pad bounds with dummy [0, 1] for the task column so the Interval
2055-
# constraint has shape (d+1,) matching raw_imputation_values.
20562055
padded_lower = torch.zeros(d + 1, dtype=dtype, device=device)
20572056
padded_upper = torch.ones(d + 1, dtype=dtype, device=device)
20582057
padded_lower[:d] = bounds[0]
20592058
padded_upper[:d] = bounds[1]
20602059
self.register_constraint(
20612060
"raw_imputation_values",
20622061
Interval(
2063-
lower_bound=padded_lower,
2064-
upper_bound=padded_upper,
2062+
lower_bound=padded_lower.repeat(self.num_tasks),
2063+
upper_bound=padded_upper.repeat(self.num_tasks),
20652064
),
20662065
)
20672066

20682067
@property
20692068
def imputation_values(self) -> Tensor:
2070-
r"""The imputation values, mapped through the Interval constraint when
2071-
bounds are present, or the raw values otherwise."""
2069+
r"""The imputation values reshaped to ``(num_tasks, d+1)``, mapped
2070+
through the Interval constraint when bounds are present."""
2071+
raw = self.raw_imputation_values
20722072
if self.bounds is not None:
2073-
return self.raw_imputation_values_constraint.transform(
2074-
self.raw_imputation_values
2075-
)
2076-
return self.raw_imputation_values
2073+
raw = self.raw_imputation_values_constraint.transform(raw)
2074+
return raw.view(self.num_tasks, self.d + 1)
20772075

20782076
def transform(self, X: Tensor) -> Tensor:
20792077
r"""Impute missing features with learned values.

test/models/transforms/test_input.py

Lines changed: 60 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@
1515
import torch
1616
from botorch.exceptions.errors import BotorchTensorDimensionError
1717
from botorch.exceptions.warnings import UserInputWarning
18+
from botorch.fit import fit_gpytorch_mll
19+
from botorch.models.multitask import MultiTaskGP
1820
from botorch.models.transforms.input import (
1921
AffineInputTransform,
2022
AppendFeatures,
@@ -36,8 +38,10 @@
3638
)
3739
from botorch.models.transforms.utils import expand_and_copy_tensor
3840
from botorch.models.utils import fantasize
41+
from botorch.test_utils.mock import mock_optimize_context_manager
3942
from botorch.utils.testing import BotorchTestCase
4043
from gpytorch import Module as GPyTorchModule
44+
from gpytorch.mlls import ExactMarginalLogLikelihood
4145
from gpytorch.priors import LogNormalPrior
4246
from torch import Tensor
4347
from torch.distributions import Kumaraswamy
@@ -1657,7 +1661,9 @@ def test_learned_feature_imputation(self) -> None:
16571661
**tkwargs,
16581662
)
16591663
self.assertEqual(tf.num_tasks, 2)
1660-
self.assertEqual(tf.raw_imputation_values.shape, torch.Size([2, d + 1]))
1664+
self.assertEqual(
1665+
tf.raw_imputation_values.shape, torch.Size([2 * (d + 1)])
1666+
)
16611667
# missing_mask: shape (num_tasks, d+1), task col always False.
16621668
self.assertTrue(tf.missing_mask[0, 3].item())
16631669
self.assertFalse(tf.missing_mask[0, 0].item())
@@ -1671,10 +1677,7 @@ def test_learned_feature_imputation(self) -> None:
16711677
**tkwargs,
16721678
)
16731679
tf.raw_imputation_values.data = torch.tensor(
1674-
[
1675-
[0.0, 0.0, 0.0, 0.5, 0.0],
1676-
[0.0, 0.0, 0.7, 0.0, 0.0],
1677-
],
1680+
[0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0],
16781681
**tkwargs,
16791682
)
16801683
X = torch.tensor(
@@ -1746,10 +1749,11 @@ def test_learned_feature_imputation(self) -> None:
17461749
tf(X_grad).sum().backward()
17471750
grad = tf.raw_imputation_values.grad
17481751
self.assertIsNotNone(grad)
1749-
self.assertNotEqual(grad[1, 1].item(), 0.0)
1752+
# d=2 → stride is d+1=3. Task 1, feature 1 → index 4.
1753+
self.assertNotEqual(grad[4].item(), 0.0)
17501754
# Task 0 observes both features → no imputation → no grad.
1751-
self.assertEqual(grad[0, 0].item(), 0.0)
1752-
self.assertEqual(grad[0, 1].item(), 0.0)
1755+
self.assertEqual(grad[0].item(), 0.0)
1756+
self.assertEqual(grad[1].item(), 0.0)
17531757

17541758
with self.subTest("untransform_raises", dtype=dtype):
17551759
tf = LearnedFeatureImputation(feature_indices={0: [0]}, d=1, **tkwargs)
@@ -1769,10 +1773,7 @@ def test_learned_feature_imputation(self) -> None:
17691773
**tkwargs,
17701774
)
17711775
tf.raw_imputation_values.data = torch.tensor(
1772-
[
1773-
[0.0, 0.0, 0.0, 0.5, 0.6, 0.0],
1774-
[0.0, 0.0, 0.7, 0.0, 0.0, 0.0],
1775-
],
1776+
[0.0, 0.0, 0.0, 0.5, 0.6, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 0.0],
17761777
**tkwargs,
17771778
)
17781779
# Imputation with asymmetric observed feature counts.
@@ -1853,7 +1854,8 @@ def test_learned_feature_imputation(self) -> None:
18531854
tf_g(
18541855
torch.tensor([[1.0, 2.0, 0.0], [3.0, 9.0, 1.0]], **tkwargs)
18551856
).sum().backward()
1856-
self.assertNotEqual(tf_g.raw_imputation_values.grad[1, 1].item(), 0.0)
1857+
# d=2 → stride is 3. Task 1, feature 1 → index 4.
1858+
self.assertNotEqual(tf_g.raw_imputation_values.grad[4].item(), 0.0)
18571859

18581860
with self.subTest("three_tasks", dtype=dtype):
18591861
tf = LearnedFeatureImputation(
@@ -1862,11 +1864,7 @@ def test_learned_feature_imputation(self) -> None:
18621864
**tkwargs,
18631865
)
18641866
tf.raw_imputation_values.data = torch.tensor(
1865-
[
1866-
[0.0, 0.0, 0.3, 0.0],
1867-
[0.4, 0.0, 0.0, 0.0],
1868-
[0.0, 0.5, 0.0, 0.0],
1869-
],
1867+
[0.0, 0.0, 0.3, 0.0, 0.4, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0],
18701868
**tkwargs,
18711869
)
18721870
X_three_tasks = torch.tensor(
@@ -1895,7 +1893,7 @@ def test_learned_feature_imputation(self) -> None:
18951893
)
18961894
)
18971895
tf.raw_imputation_values.data = torch.tensor(
1898-
[[0.0, 0.0, 0.0, 0.5, 0.0], [0.0, 0.0, 0.7, 0.0, 0.0]],
1896+
[0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0],
18991897
**tkwargs,
19001898
)
19011899
X_noncontig = torch.tensor(
@@ -1917,7 +1915,7 @@ def test_learned_feature_imputation(self) -> None:
19171915
**tkwargs,
19181916
)
19191917
tf.raw_imputation_values.data = torch.tensor(
1920-
[[0.0, 0.0, 0.0, 0.5, 0.0], [0.0, 0.0, 0.7, 0.0, 0.0]],
1918+
[0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0],
19211919
**tkwargs,
19221920
)
19231921
X_batch = torch.tensor(
@@ -1940,7 +1938,7 @@ def test_learned_feature_imputation(self) -> None:
19401938
**tkwargs,
19411939
)
19421940
tf.raw_imputation_values.data = torch.tensor(
1943-
[[0.0, 0.0, 0.0, 0.5, 0.0], [0.0, 0.0, 0.7, 0.0, 0.0]],
1941+
[0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0],
19441942
**tkwargs,
19451943
)
19461944
X_no_task = torch.tensor(
@@ -1984,6 +1982,47 @@ def test_learned_feature_imputation(self) -> None:
19841982
with self.assertRaisesRegex(ValueError, "Expected X.shape"):
19851983
tf(torch.zeros(2, d + 3, **tkwargs))
19861984

1985+
with self.subTest("fit_gpytorch_mll_with_bounds", dtype=dtype):
1986+
n = 5
1987+
X = torch.cat(
1988+
[
1989+
torch.cat(
1990+
[
1991+
torch.rand(n, d, **tkwargs),
1992+
i * torch.ones(n, 1, **tkwargs),
1993+
],
1994+
dim=-1,
1995+
)
1996+
for i in range(len(feature_indices))
1997+
]
1998+
)
1999+
Y = torch.randn(len(feature_indices) * n, 1, **tkwargs)
2000+
bounds = torch.stack(
2001+
[torch.zeros(d, **tkwargs), torch.ones(d, **tkwargs)]
2002+
)
2003+
lfi = LearnedFeatureImputation(
2004+
feature_indices=feature_indices, d=d, bounds=bounds, **tkwargs
2005+
)
2006+
model = MultiTaskGP(
2007+
train_X=X,
2008+
train_Y=Y,
2009+
task_feature=-1,
2010+
input_transform=ChainedInputTransform(
2011+
tf0=Normalize(d=d + 1, indices=list(range(d))),
2012+
tf1=lfi,
2013+
),
2014+
)
2015+
mll = ExactMarginalLogLikelihood(model.likelihood, model)
2016+
with mock_optimize_context_manager():
2017+
fit_gpytorch_mll(mll, max_attempts=1)
2018+
imp = lfi.imputation_values
2019+
# These are the two learnable imputation values under the current
2020+
# setup, so these should be the only non-zero values.
2021+
self.assertNotEqual(imp[0, 3].item(), 0.0)
2022+
self.assertNotEqual(imp[1, 2].item(), 0.0)
2023+
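# NOTE(review): this entry is not one of the learnable imputation values, so
# its raw parameter presumably stays at 0. The check below is nevertheless
# assertNotEqual: with bounds set, the Interval constraint likely maps a raw 0
# to the interior of the interval rather than to 0, so the transformed value
# is non-zero. Confirm the intended check (e.g. assert on the raw parameter).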
2024+
self.assertNotEqual(imp[0, 1].item(), 0.0)
2025+
19872026

19882027
class TestAppendFeatures(BotorchTestCase):
19892028
def test_append_features(self) -> None:

0 commit comments

Comments
 (0)