Skip to content

Commit 09370a6

Browse files
Carl Hvarfnermeta-codesync[bot]
authored andcommitted
Enable unobserved task support in MultiTaskGP
Summary: Permits an MTGP to predict on an unobserved task, addressing these issues: meta-pytorch#2360 meta-pytorch#3085 To do this, we assume that the unobserved task is maximally correlated with the target tasks (equally with each, by averaging the elements). Exact heuristic on correlation is definitely up for discussion, but this seems like a decent default assumption. Will come in handy for TL initialization. Differential Revision: D90769576 D90769576
1 parent 7471077 commit 09370a6

4 files changed

Lines changed: 234 additions & 26 deletions

File tree

botorch/models/fully_bayesian_multitask.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
r"""Multi-task Gaussian Process Regression models with fully Bayesian inference."""
88

99
from collections.abc import Mapping
10-
from typing import Any, NoReturn, TypeVar
10+
from typing import Any, NoReturn, Self, TypeVar
1111

1212
import pyro
1313
import torch
@@ -19,7 +19,10 @@
1919
reshape_and_detach,
2020
SaasPyroModel,
2121
)
22-
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
22+
from botorch.models.gpytorch import (
23+
BatchedMultiOutputGPyTorchModel,
24+
MultiTaskGPyTorchModel,
25+
)
2326
from botorch.models.multitask import MultiTaskGP
2427
from botorch.models.transforms.input import InputTransform
2528
from botorch.models.transforms.outcome import OutcomeTransform
@@ -55,6 +58,7 @@ def set_inputs(
5558
train_Yvar: Tensor | None,
5659
task_feature: int,
5760
task_rank: int | None = None,
61+
all_tasks: list[int] | None = None,
5862
) -> None:
5963
"""Set the training data.
6064
@@ -73,7 +77,11 @@ def set_inputs(
7377
task_feature = task_feature % train_X.shape[-1]
7478
super().set_inputs(train_X, train_Y, train_Yvar)
7579
# obtain a list of task indicies
76-
all_tasks = train_X[:, task_feature].unique().to(dtype=torch.long).tolist()
80+
all_tasks = (
81+
train_X[:, task_feature].unique().to(dtype=torch.long).tolist()
82+
if all_tasks is None
83+
else all_tasks
84+
)
7785
self.task_feature = task_feature
7886
self.num_tasks = len(all_tasks)
7987
self.task_rank = task_rank or self.num_tasks
@@ -242,7 +250,10 @@ def __init__(
242250
outputs for. If omitted, return outputs for all task indices.
243251
rank: The num of learned task embeddings to be used in the task kernel.
244252
If omitted, use a full rank (i.e. number of tasks) kernel.
245-
all_tasks: NOT SUPPORTED!
253+
all_tasks: A list of all task indices. If omitted, all tasks will be
254+
inferred from the task feature column of the training data. Used to
255+
inform the model about the total number of tasks, including any
256+
unobserved tasks.
246257
outcome_transform: An outcome transform that is applied to the
247258
training data during instantiation and to the posterior during
248259
inference (that is, the ``Posterior`` obtained by calling
@@ -310,6 +321,7 @@ def __init__(
310321
train_Yvar=train_Yvar,
311322
task_feature=task_feature,
312323
task_rank=self._rank,
324+
all_tasks=all_tasks,
313325
)
314326
self.pyro_model: MultitaskSaasPyroModel = pyro_model
315327
if outcome_transform is not None:
@@ -383,6 +395,20 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
383395
_,
384396
) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
385397

398+
def eval(self) -> Self:
399+
r"""Puts the model in eval mode.
400+
401+
Circumvents the need to call MultiTaskGP.eval(), which computes the
402+
task_covar_matrix for non-observed tasks. This is not needed for fully
403+
Bayesian models, since the non-observed tasks' covar factors are instead
404+
sampled.
405+
406+
Returns:
407+
The model itself.
408+
"""
409+
self._check_if_fitted()
410+
return MultiTaskGPyTorchModel.eval(self)
411+
386412
def posterior(
387413
self,
388414
X: Tensor,

botorch/models/multitask.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from __future__ import annotations
3131

3232
import math
33-
from typing import Any
33+
from typing import Any, Self
3434

3535
import torch
3636
from botorch.acquisition.objective import PosteriorTransform
@@ -238,7 +238,11 @@ def __init__(
238238
"This is not allowed as it will lead to errors during model training."
239239
)
240240
all_tasks = all_tasks or all_tasks_inferred
241-
self.num_tasks = len(all_tasks_inferred)
241+
sorted_all_tasks = sorted(all_tasks)
242+
self.num_tasks = len(all_tasks)
243+
# Store for later buffer registration (after super().__init__)
244+
self._all_tasks_inferred = all_tasks_inferred
245+
self._sorted_all_tasks = sorted_all_tasks
242246
if outcome_transform == DEFAULT:
243247
outcome_transform = Standardize(m=1, batch_shape=train_X.shape[:-2])
244248
if outcome_transform is not None:
@@ -321,14 +325,55 @@ def __init__(
321325
default_task_value=None if output_tasks is None else output_tasks[0],
322326
)
323327
self.register_buffer("_task_mapper", task_mapper)
324-
self._expected_task_values = set(all_tasks_inferred)
328+
self._expected_task_values = set(all_tasks)
325329
if input_transform is not None:
326330
self.input_transform = input_transform
327331
if outcome_transform is not None:
328332
self.outcome_transform = outcome_transform
329333
self._validate_task_values = validate_task_values
330334
self.to(train_X)
331335

336+
# Register observed/unobserved task indices as buffers (must be after
337+
# super().__init__()). Compute observed and unobserved task indices when
338+
# all_tasks includes unobserved tasks
339+
if set(self._sorted_all_tasks) != set(self._all_tasks_inferred):
340+
observed_set = set(self._all_tasks_inferred)
341+
self.register_buffer(
342+
"_observed_task_indices",
343+
torch.tensor(
344+
[
345+
i
346+
for i, t in enumerate(self._sorted_all_tasks)
347+
if t in observed_set
348+
],
349+
dtype=torch.long,
350+
),
351+
)
352+
self.register_buffer(
353+
"_unobserved_task_indices",
354+
torch.tensor(
355+
[
356+
i
357+
for i, t in enumerate(self._sorted_all_tasks)
358+
if t not in observed_set
359+
],
360+
dtype=torch.long,
361+
),
362+
)
363+
else:
364+
# All tasks are observed - set observed indices to all tasks
365+
self.register_buffer(
366+
"_observed_task_indices",
367+
torch.arange(len(self._sorted_all_tasks), dtype=torch.long),
368+
)
369+
self.register_buffer(
370+
"_unobserved_task_indices",
371+
torch.tensor([], dtype=torch.long),
372+
)
373+
# Clean up temporary attributes
374+
del self._all_tasks_inferred
375+
del self._sorted_all_tasks
376+
332377
def _map_tasks(self, task_values: Tensor) -> Tensor:
333378
"""Map raw task values to the task indices used by the model.
334379
@@ -407,6 +452,28 @@ def forward(self, x: Tensor) -> MultivariateNormal:
407452
covar_x = self.covar_module(x_covar)
408453
return MultivariateNormal(mean_x, covar_x)
409454

455+
def eval(self) -> Self:
456+
r"""Puts the model in ``eval`` mode.
457+
458+
When unobserved tasks are present (i.e., ``all_tasks`` includes tasks not in
459+
the training data), this method sets the covariance factor for unobserved tasks
460+
to the mean of the observed tasks' covariance factors. This provides a
461+
reasonable initialization for prediction on unobserved tasks.
462+
"""
463+
if len(self._unobserved_task_indices) > 0:
464+
task_covar_module = self.covar_module.kernels[1]
465+
# Get the current covar_factor (transformed from raw_covar_factor)
466+
covar_factor = task_covar_module.covar_factor
467+
# Compute mean of observed tasks' covar_factor rows
468+
observed_covar_factor = covar_factor[self._observed_task_indices]
469+
mean_covar_factor = observed_covar_factor.mean(dim=0)
470+
# Create new covar_factor with unobserved tasks set to mean
471+
new_covar_factor = covar_factor.clone()
472+
new_covar_factor[self._unobserved_task_indices] = mean_covar_factor
473+
# Set the new covar_factor (this applies inverse_transform internally)
474+
task_covar_module._set_covar_factor(new_covar_factor)
475+
return super().eval()
476+
410477
@classmethod
411478
def get_all_tasks(
412479
cls,

test/models/test_fully_bayesian_multitask.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -417,14 +417,19 @@ def test_fit_model(
417417

418418
# Check the keys in the state dict
419419
true_keys = EXPECTED_KEYS_NOISE if infer_noise else EXPECTED_KEYS
420-
extra_keys = []
420+
extra_keys = [
421+
"_observed_task_indices",
422+
"_unobserved_task_indices",
423+
]
421424
if use_outcome_transform:
422-
extra_keys = [
423-
"outcome_transform.stdvs",
424-
"outcome_transform._is_trained",
425-
"outcome_transform._stdvs_sq",
426-
"outcome_transform.means",
427-
]
425+
extra_keys.extend(
426+
[
427+
"outcome_transform.stdvs",
428+
"outcome_transform._is_trained",
429+
"outcome_transform._stdvs_sq",
430+
"outcome_transform.means",
431+
]
432+
)
428433
if model._task_mapper is not None:
429434
extra_keys.append("_task_mapper")
430435
self.assertEqual(set(model.state_dict().keys()), {*true_keys, *extra_keys})
@@ -514,6 +519,8 @@ def test_fit_model_with_outcome_transform(self):
514519
def test_fit_model_with_task_mapper(self) -> None:
515520
dtype = torch.double
516521
tkwargs = {"device": self.device, "dtype": dtype}
522+
# Test with contiguous all_tasks that includes an unobserved task
523+
# all_tasks=[0, 1, 2] is contiguous from 0, so no mapper is needed
517524
all_tasks = [0, 1, 2]
518525
observed_task_values = [0, 2]
519526
output_tasks = [2]
@@ -526,16 +533,36 @@ def test_fit_model_with_task_mapper(self) -> None:
526533
validate_task_values=False,
527534
**tkwargs,
528535
)
529-
self.assertTrue(
530-
torch.equal(model._task_mapper, torch.tensor([0, 1, 1], **tkwargs))
531-
)
532-
self.test_fit_model(
536+
# With contiguous all_tasks=[0, 1, 2], no task mapper is needed
537+
# because task values are already contiguous integers starting from 0
538+
self.assertIsNone(model._task_mapper)
539+
# Verify the pyro_model has the correct number of tasks (3, not 2)
540+
self.assertEqual(model.pyro_model.num_tasks, 3)
541+
542+
# Also test non-contiguous all_tasks to ensure mapper is created
543+
all_tasks_noncontig = [0, 2, 5]
544+
observed_task_values_noncontig = [0, 5]
545+
output_tasks_noncontig = [5]
546+
_, _, _, model_noncontig = self._get_data_and_model(
547+
infer_noise=True,
533548
use_outcome_transform=True,
534-
all_tasks=all_tasks,
535-
observed_task_values=observed_task_values,
536-
output_tasks=output_tasks,
549+
output_tasks=output_tasks_noncontig,
550+
observed_task_values=observed_task_values_noncontig,
551+
all_tasks=all_tasks_noncontig,
537552
validate_task_values=False,
553+
**tkwargs,
538554
)
555+
# With non-contiguous all_tasks=[0, 2, 5], a mapper is required
556+
# Mapper maps: 0→0, 2→1, 5→2; other indices map to NaN
557+
self.assertIsNotNone(model_noncontig._task_mapper)
558+
self.assertEqual(model_noncontig._task_mapper[0].item(), 0.0)
559+
self.assertEqual(model_noncontig._task_mapper[2].item(), 1.0)
560+
self.assertEqual(model_noncontig._task_mapper[5].item(), 2.0)
561+
self.assertTrue(torch.isnan(model_noncontig._task_mapper[1]))
562+
self.assertTrue(torch.isnan(model_noncontig._task_mapper[3]))
563+
self.assertTrue(torch.isnan(model_noncontig._task_mapper[4]))
564+
# Verify pyro_model has correct number of tasks
565+
self.assertEqual(model_noncontig.pyro_model.num_tasks, 3)
539566

540567
def test_transforms(self, infer_noise: bool = False):
541568
tkwargs = {"device": self.device, "dtype": torch.double}

test/models/test_multitask.py

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,12 @@ def test_all_tasks_input(self) -> None:
445445
model = MultiTaskGP(
446446
train_X=train_X, train_Y=train_Y, task_feature=0, all_tasks=[0, 1, 2, 3]
447447
)
448-
self.assertEqual(model.num_tasks, 2)
448+
self.assertEqual(model.num_tasks, 4)
449449
# Check that PositiveIndexKernel knows of all tasks.
450-
self.assertEqual(model.covar_module.kernels[1].raw_covar_factor.shape[0], 2)
450+
self.assertEqual(model.covar_module.kernels[1].raw_covar_factor.shape[0], 4)
451+
# Check that observed and unobserved task indices are computed correctly.
452+
self.assertEqual(model._observed_task_indices.tolist(), [0, 1])
453+
self.assertEqual(model._unobserved_task_indices.tolist(), [2, 3])
451454

452455
def test_MultiTaskGP_construct_inputs(self) -> None:
453456
for dtype, fixed_noise, skip_task_features_in_datasets in zip(
@@ -540,13 +543,98 @@ def test_validatation_of_task_values(self) -> None:
540543
validate_task_values=True,
541544
)
542545

546+
# Task 2 is in all_tasks, so it should be valid even with validation enabled
547+
self.assertTrue(
548+
torch.equal(
549+
torch.tensor([1], **tkwargs),
550+
model._map_tasks(task_values=torch.tensor([2], **tkwargs)),
551+
)
552+
)
553+
554+
# Task 3 is NOT in all_tasks, so it should raise an error
543555
with self.assertRaisesRegex(
544556
ValueError,
545557
"Received invalid raw task values. Expected raw value to be in"
546-
r" \{0, 1\}, but got unexpected task"
547-
r" values: \{2\}.",
558+
r" \{0, 1, 2\}, but got unexpected task"
559+
r" values: \{3\}.",
548560
):
549-
model._map_tasks(task_values=torch.tensor([2], **tkwargs))
561+
model._map_tasks(task_values=torch.tensor([3], **tkwargs))
562+
563+
def test_multitask_gp_unobserved_tasks(self) -> None:
564+
"""Test MultiTaskGP with unobserved tasks.
565+
566+
This test verifies that:
567+
1. Creating a model with all_tasks including unobserved tasks works
568+
2. In train mode, unobserved task covar_factor is at random initialization
569+
3. In eval mode, unobserved task covar_factor is set to mean of observed
570+
4. Predictions work for the unobserved task
571+
"""
572+
tkwargs = {"device": self.device, "dtype": torch.double}
573+
574+
# Create data for tasks 0 and 2 only (task 1 is unobserved)
575+
_, (train_X, train_Y, _) = gen_multi_task_dataset(task_values=[0, 2], **tkwargs)
576+
577+
# Create model with all_tasks=[0, 1, 2] including unobserved task 1
578+
model = MultiTaskGP(
579+
train_X=train_X,
580+
train_Y=train_Y,
581+
task_feature=0,
582+
all_tasks=[0, 1, 2],
583+
)
584+
model.to(**tkwargs)
585+
586+
# Verify model.num_tasks == 3
587+
self.assertEqual(model.num_tasks, 3)
588+
589+
# Verify observed and unobserved task indices are correctly set
590+
self.assertEqual(model._observed_task_indices.tolist(), [0, 2])
591+
self.assertEqual(model._unobserved_task_indices.tolist(), [1])
592+
593+
# Get the task covariance module
594+
task_covar_module = model.covar_module.kernels[1]
595+
596+
# In train mode, get the covar_factor for unobserved task (index 1)
597+
model.train()
598+
train_covar_factor = task_covar_module.covar_factor.clone()
599+
unobserved_train_covar = train_covar_factor[1]
600+
observed_train_covar = train_covar_factor[[0, 2]]
601+
mean_observed_train = observed_train_covar.mean(dim=0)
602+
603+
# Unobserved task covar_factor should be at random init in train mode
604+
# (very unlikely to be exactly equal to mean of observed)
605+
self.assertFalse(
606+
torch.allclose(unobserved_train_covar, mean_observed_train, atol=1e-6)
607+
)
608+
609+
# Switch to eval mode
610+
model.eval()
611+
612+
# In eval mode, get the covar_factor for unobserved task
613+
eval_covar_factor = task_covar_module.covar_factor.clone()
614+
unobserved_eval_covar = eval_covar_factor[1]
615+
observed_eval_covar = eval_covar_factor[[0, 2]]
616+
mean_observed_eval = observed_eval_covar.mean(dim=0)
617+
618+
# Unobserved task covar_factor should equal mean of observed in eval mode
619+
self.assertTrue(
620+
torch.allclose(unobserved_eval_covar, mean_observed_eval, atol=1e-6)
621+
)
622+
623+
# Verify predictions work for the unobserved task
624+
# Create test input for unobserved task (task 1)
625+
test_X = torch.rand(3, 2, **tkwargs)
626+
test_X[:, 0] = 1.0 # Set task feature to 1 (unobserved task)
627+
628+
with torch.no_grad():
629+
posterior = model.posterior(X=test_X)
630+
631+
# Verify posterior has expected shape
632+
self.assertEqual(posterior.mean.shape, torch.Size([3, 1]))
633+
self.assertEqual(posterior.variance.shape, torch.Size([3, 1]))
634+
635+
# Verify we can sample from the posterior
636+
samples = posterior.rsample(sample_shape=torch.Size([2]))
637+
self.assertEqual(samples.shape, torch.Size([2, 3, 1]))
550638

551639

552640
class TestKroneckerMultiTaskGP(BotorchTestCase):

0 commit comments

Comments
 (0)