@@ -445,9 +445,12 @@ def test_all_tasks_input(self) -> None:
445445 model = MultiTaskGP (
446446 train_X = train_X , train_Y = train_Y , task_feature = 0 , all_tasks = [0 , 1 , 2 , 3 ]
447447 )
448- self .assertEqual (model .num_tasks , 2 )
448+ self .assertEqual (model .num_tasks , 4 )
449449 # Check that PositiveIndexKernel knows of all tasks.
450- self .assertEqual (model .covar_module .kernels [1 ].raw_covar_factor .shape [0 ], 2 )
450+ self .assertEqual (model .covar_module .kernels [1 ].raw_covar_factor .shape [0 ], 4 )
451+ # Check that observed and unobserved task indices are computed correctly.
452+ self .assertEqual (model ._observed_task_indices .tolist (), [0 , 1 ])
453+ self .assertEqual (model ._unobserved_task_indices .tolist (), [2 , 3 ])
451454
452455 def test_MultiTaskGP_construct_inputs (self ) -> None :
453456 for dtype , fixed_noise , skip_task_features_in_datasets in zip (
@@ -540,13 +543,98 @@ def test_validatation_of_task_values(self) -> None:
540543 validate_task_values = True ,
541544 )
542545
546+ # Task 2 is in all_tasks, so it should be valid even with validation enabled
547+ self .assertTrue (
548+ torch .equal (
549+ torch .tensor ([1 ], ** tkwargs ),
550+ model ._map_tasks (task_values = torch .tensor ([2 ], ** tkwargs )),
551+ )
552+ )
553+
554+ # Task 3 is NOT in all_tasks, so it should raise an error
543555 with self .assertRaisesRegex (
544556 ValueError ,
545557 "Received invalid raw task values. Expected raw value to be in"
546- r" \{0, 1\}, but got unexpected task"
547- r" values: \{2 \}." ,
558+ r" \{0, 1, 2 \}, but got unexpected task"
559+ r" values: \{3 \}." ,
548560 ):
549- model ._map_tasks (task_values = torch .tensor ([2 ], ** tkwargs ))
561+ model ._map_tasks (task_values = torch .tensor ([3 ], ** tkwargs ))
562+
563+ def test_multitask_gp_unobserved_tasks (self ) -> None :
564+ """Test MultiTaskGP with unobserved tasks.
565+
566+ This test verifies that:
567+ 1. Creating a model with all_tasks including unobserved tasks works
568+ 2. In train mode, unobserved task covar_factor is at random initialization
569+ 3. In eval mode, unobserved task covar_factor is set to mean of observed
570+ 4. Predictions work for the unobserved task
571+ """
572+ tkwargs = {"device" : self .device , "dtype" : torch .double }
573+
574+ # Create data for tasks 0 and 2 only (task 1 is unobserved)
575+ _ , (train_X , train_Y , _ ) = gen_multi_task_dataset (task_values = [0 , 2 ], ** tkwargs )
576+
577+ # Create model with all_tasks=[0, 1, 2] including unobserved task 1
578+ model = MultiTaskGP (
579+ train_X = train_X ,
580+ train_Y = train_Y ,
581+ task_feature = 0 ,
582+ all_tasks = [0 , 1 , 2 ],
583+ )
584+ model .to (** tkwargs )
585+
586+ # Verify model.num_tasks == 3
587+ self .assertEqual (model .num_tasks , 3 )
588+
589+ # Verify observed and unobserved task indices are correctly set
590+ self .assertEqual (model ._observed_task_indices .tolist (), [0 , 2 ])
591+ self .assertEqual (model ._unobserved_task_indices .tolist (), [1 ])
592+
593+ # Get the task covariance module
594+ task_covar_module = model .covar_module .kernels [1 ]
595+
596+ # In train mode, get the covar_factor for unobserved task (index 1)
597+ model .train ()
598+ train_covar_factor = task_covar_module .covar_factor .clone ()
599+ unobserved_train_covar = train_covar_factor [1 ]
600+ observed_train_covar = train_covar_factor [[0 , 2 ]]
601+ mean_observed_train = observed_train_covar .mean (dim = 0 )
602+
603+ # Unobserved task covar_factor should be at random init in train mode
604+ # (very unlikely to be exactly equal to mean of observed)
605+ self .assertFalse (
606+ torch .allclose (unobserved_train_covar , mean_observed_train , atol = 1e-6 )
607+ )
608+
609+ # Switch to eval mode
610+ model .eval ()
611+
612+ # In eval mode, get the covar_factor for unobserved task
613+ eval_covar_factor = task_covar_module .covar_factor .clone ()
614+ unobserved_eval_covar = eval_covar_factor [1 ]
615+ observed_eval_covar = eval_covar_factor [[0 , 2 ]]
616+ mean_observed_eval = observed_eval_covar .mean (dim = 0 )
617+
618+ # Unobserved task covar_factor should equal mean of observed in eval mode
619+ self .assertTrue (
620+ torch .allclose (unobserved_eval_covar , mean_observed_eval , atol = 1e-6 )
621+ )
622+
623+ # Verify predictions work for the unobserved task
624+ # Create test input for unobserved task (task 1)
625+ test_X = torch .rand (3 , 2 , ** tkwargs )
626+ test_X [:, 0 ] = 1.0 # Set task feature to 1 (unobserved task)
627+
628+ with torch .no_grad ():
629+ posterior = model .posterior (X = test_X )
630+
631+ # Verify posterior has expected shape
632+ self .assertEqual (posterior .mean .shape , torch .Size ([3 , 1 ]))
633+ self .assertEqual (posterior .variance .shape , torch .Size ([3 , 1 ]))
634+
635+ # Verify we can sample from the posterior
636+ samples = posterior .rsample (sample_shape = torch .Size ([2 ]))
637+ self .assertEqual (samples .shape , torch .Size ([2 , 3 , 1 ]))
550638
551639
552640class TestKroneckerMultiTaskGP (BotorchTestCase ):
0 commit comments