From df363bea44fab82ce27d4c81f170e477577b6aa4 Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Thu, 14 May 2026 09:21:30 -0700 Subject: [PATCH] Wire LearnedFeatureImputation and map_heterogeneous_to_full for MultiTaskGP (#3296) Summary: X-link: https://github.com/facebook/Ax/pull/5192 Automatically configures learned feature imputation for models that pad heterogeneous per-task data to the full joint feature space. Models with native heterogeneity support are excluded from this automatic configuration. Reviewed By: saitcakmak Differential Revision: D101841497 --- botorch/models/heterogeneous_mtgp.py | 5 +++++ test/models/test_heterogeneous_mtgp.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/botorch/models/heterogeneous_mtgp.py b/botorch/models/heterogeneous_mtgp.py index c05c0670a3..7f1663c1fe 100644 --- a/botorch/models/heterogeneous_mtgp.py +++ b/botorch/models/heterogeneous_mtgp.py @@ -327,6 +327,7 @@ def construct_inputs( rank: int | None = None, use_saas_prior: bool = True, use_combinatorial_kernel: bool = True, + map_heterogeneous_to_full: bool = False, ) -> dict[str, Any]: r"""Construct ``Model`` keyword arguments from a given ``MultiTaskDataset``. @@ -341,6 +342,10 @@ def construct_inputs( ``MultiTaskConditionalKernel``. use_combinatorial_kernel: Whether to use a combinatorial kernel over the binary embedding of task features in ``MultiTaskConditionalKernel``. + map_heterogeneous_to_full: Accepted for compatibility with + ``MultiTaskGP.construct_inputs`` but unused. + ``HeterogeneousMTGP`` handles heterogeneous features via + ``MultiTaskConditionalKernel``. """ if training_data.task_feature_index != -1: raise NotImplementedError( diff --git a/test/models/test_heterogeneous_mtgp.py b/test/models/test_heterogeneous_mtgp.py index 1767dca64e..520ba8883d 100644 --- a/test/models/test_heterogeneous_mtgp.py +++ b/test/models/test_heterogeneous_mtgp.py @@ -97,6 +97,12 @@ def test_input_constructor(self) -> None: self.assertEqual(model_inputs["full_feature_dim"], 5) self.assertIsNone(model_inputs["rank"]) + with self.subTest("map_heterogeneous_to_full accepted and ignored"): + model_inputs = HeterogeneousMTGP.construct_inputs( + training_data=self.mtds, map_heterogeneous_to_full=True + ) + self.assertNotIn("map_heterogeneous_to_full", model_inputs) + def test_standard_heterogeneous_mtgp(self) -> None: # Construct the model (inferred noise: train_Yvars is None). model_inputs = HeterogeneousMTGP.construct_inputs(training_data=self.mtds)