Skip to content

Commit b685ca7

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
Wire LearnedFeatureImputation and map_heterogeneous_to_full for MultiTaskGP
Summary: X-link: facebook/Ax#5192 Automatically configures learned feature imputation for models that pad heterogeneous per-task data to the full joint feature space. Models with native heterogeneity support are excluded from this automatic configuration. Differential Revision: D101841497
1 parent 1007867 commit b685ca7

3 files changed

Lines changed: 21 additions & 6 deletions

File tree

botorch/models/heterogeneous_mtgp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def construct_inputs(
312312
rank: int | None = None,
313313
use_saas_prior: bool = True,
314314
use_combinatorial_kernel: bool = True,
315+
map_heterogeneous_to_full: bool = False,
315316
) -> dict[str, Any]:
316317
r"""Construct ``Model`` keyword arguments from a given ``MultiTaskDataset``.
317318
@@ -326,6 +327,10 @@ def construct_inputs(
326327
``MultiTaskConditionalKernel``.
327328
use_combinatorial_kernel: Whether to use a combinatorial kernel over the
328329
binary embedding of task features in ``MultiTaskConditionalKernel``.
330+
map_heterogeneous_to_full: Accepted for compatibility with
331+
``MultiTaskGP.construct_inputs`` but unused.
332+
``HeterogeneousMTGP`` handles heterogeneous features via
333+
``MultiTaskConditionalKernel``.
329334
"""
330335
if training_data.task_feature_index != -1:
331336
raise NotImplementedError(

botorch/models/transforms/input.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2051,12 +2051,16 @@ def __init__(
20512051
),
20522052
)
20532053
if bounds is not None:
2054-
# Pad bounds with dummy [0, 1] for the task column so the Interval
2055-
# constraint has shape (d+1,) matching raw_imputation_values.
2056-
padded_lower = torch.zeros(d + 1, dtype=dtype, device=device)
2057-
padded_upper = torch.ones(d + 1, dtype=dtype, device=device)
2058-
padded_lower[:d] = bounds[0]
2059-
padded_upper[:d] = bounds[1]
2054+
# Constraint bounds must match raw_imputation_values' (num_tasks,
2055+
# d+1) shape — gpytorch's scipy fitting path flattens the parameter
2056+
# and expects per-element bounds, so a (d+1,) bound would fail to
2057+
# broadcast. Pad the task column slot with dummy [0, 1].
2058+
padded_lower = torch.zeros(
2059+
self.num_tasks, d + 1, dtype=dtype, device=device
2060+
)
2061+
padded_upper = torch.ones(self.num_tasks, d + 1, dtype=dtype, device=device)
2062+
padded_lower[:, :d] = bounds[0]
2063+
padded_upper[:, :d] = bounds[1]
20602064
self.register_constraint(
20612065
"raw_imputation_values",
20622066
Interval(

test/models/test_heterogeneous_mtgp.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ def test_input_constructor(self) -> None:
9797
self.assertEqual(model_inputs["full_feature_dim"], 5)
9898
self.assertIsNone(model_inputs["rank"])
9999

100+
with self.subTest("map_heterogeneous_to_full accepted and ignored"):
101+
model_inputs = HeterogeneousMTGP.construct_inputs(
102+
training_data=self.mtds, map_heterogeneous_to_full=True
103+
)
104+
self.assertNotIn("map_heterogeneous_to_full", model_inputs)
105+
100106
def test_standard_heterogeneous_mtgp(self) -> None:
101107
# Construct the model (inferred noise: train_Yvars is None).
102108
model_inputs = HeterogeneousMTGP.construct_inputs(training_data=self.mtds)

0 commit comments

Comments
 (0)