Skip to content

Commit c7ac346

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
Wire LearnedFeatureImputation and map_heterogeneous_to_full for MultiTaskGP (meta-pytorch#3296)
Summary: X-link: facebook/Ax#5192 Automatically configures learned feature imputation for models that pad heterogeneous per-task data to the full joint feature space. Models with native heterogeneity support are excluded from this automatic configuration. Differential Revision: D101841497
1 parent 35e0e81 commit c7ac346

2 files changed

Lines changed: 11 additions & 0 deletions

File tree

botorch/models/heterogeneous_mtgp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def construct_inputs(
327327
rank: int | None = None,
328328
use_saas_prior: bool = True,
329329
use_combinatorial_kernel: bool = True,
330+
map_heterogeneous_to_full: bool = False,
330331
) -> dict[str, Any]:
331332
r"""Construct ``Model`` keyword arguments from a given ``MultiTaskDataset``.
332333
@@ -341,6 +342,10 @@ def construct_inputs(
341342
``MultiTaskConditionalKernel``.
342343
use_combinatorial_kernel: Whether to use a combinatorial kernel over the
343344
binary embedding of task features in ``MultiTaskConditionalKernel``.
345+
map_heterogeneous_to_full: Accepted for compatibility with
346+
``MultiTaskGP.construct_inputs`` but unused.
347+
``HeterogeneousMTGP`` handles heterogeneous features via
348+
``MultiTaskConditionalKernel``.
344349
"""
345350
if training_data.task_feature_index != -1:
346351
raise NotImplementedError(

test/models/test_heterogeneous_mtgp.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ def test_input_constructor(self) -> None:
9797
self.assertEqual(model_inputs["full_feature_dim"], 5)
9898
self.assertIsNone(model_inputs["rank"])
9999

100+
with self.subTest("map_heterogeneous_to_full accepted and ignored"):
101+
model_inputs = HeterogeneousMTGP.construct_inputs(
102+
training_data=self.mtds, map_heterogeneous_to_full=True
103+
)
104+
self.assertNotIn("map_heterogeneous_to_full", model_inputs)
105+
100106
def test_standard_heterogeneous_mtgp(self) -> None:
101107
# Construct the model (inferred noise: train_Yvars is None).
102108
model_inputs = HeterogeneousMTGP.construct_inputs(training_data=self.mtds)

0 commit comments

Comments
 (0)