Skip to content

Commit 60aff2b

Browse files
Carl Hvarfner authored and facebook-github-bot committed
Simplify task value remapping API (meta-pytorch#3163)
Summary:
X-link: facebook/Ax#4860

Simplifies the get_task_value_remapping() API from 4 parameters to 2, addressing confusion reported in meta-pytorch#3085.

The observed_task_values parameter is removed because the parent diff (D90769576) now makes MultiTaskGP track observed/unobserved tasks internally via _observed_task_indices and _unobserved_task_indices.

The default_task_value parameter is removed because the previous behavior—silently mapping unknown tasks to an arbitrary fallback—was confusing and error-prone; instead, unrecognized tasks now map to NaN, providing an explicit error sentinel with a clear warning message.

Differential Revision: D90998243
1 parent 9554db1 commit 60aff2b

6 files changed

Lines changed: 293 additions & 111 deletions

File tree

botorch/models/multitask.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,14 +336,10 @@ def __init__(
336336

337337
self.covar_module = data_covar_module * task_covar_module
338338
task_mapper = get_task_value_remapping(
339-
observed_task_values=torch.tensor(
340-
all_tasks_inferred, dtype=torch.long, device=train_X.device
341-
),
342339
all_task_values=torch.tensor(
343340
sorted(all_tasks), dtype=torch.long, device=train_X.device
344341
),
345342
dtype=train_X.dtype,
346-
default_task_value=None if output_tasks is None else output_tasks[0],
347343
)
348344
self.register_buffer("_task_mapper", task_mapper)
349345
self._expected_task_values = set(all_tasks)

botorch/models/transforms/outcome.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from __future__ import annotations
2424

25+
import warnings
2526
from abc import ABC, abstractmethod
2627
from collections import OrderedDict
2728

@@ -511,12 +512,10 @@ class StratifiedStandardize(Standardize):
511512
def __init__(
512513
self,
513514
stratification_idx: int,
514-
observed_task_values: Tensor,
515515
all_task_values: Tensor,
516516
batch_shape: torch.Size = torch.Size(), # noqa: B008
517517
min_stdv: float = 1e-8,
518518
dtype: torch.dtype = torch.double,
519-
default_task_value: int | None = None,
520519
) -> None:
521520
r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
522521
@@ -526,28 +525,22 @@ def __init__(
526525
Args:
527526
stratification_idx: The index of the stratification dimension in the
528527
input tensor X.
529-
observed_task_values: ``t``-dim tensor of task values that were actually
530-
observed in the training data.
531528
all_task_values: ``t``-dim tensor of all possible task values that could
532529
appear in the dataset.
533530
batch_shape: The batch_shape of the training targets.
534531
min_stdv: The minimum standard deviation for which to perform
535532
standardization (if lower, only de-mean the data).
536533
dtype: The data type for internal computations.
537-
default_task_value: The default task value that unexpected tasks are
538-
mapped to. This is used in ``get_task_value_remapping``.
539534
"""
540535
OutcomeTransform.__init__(self)
541536
self._stratification_idx = stratification_idx
542-
observed_task_values = observed_task_values.unique(sorted=True)
537+
all_task_values = all_task_values.unique(sorted=True)
543538
self.strata_mapping = get_task_value_remapping(
544-
observed_task_values=observed_task_values,
545-
all_task_values=all_task_values.unique(sorted=True),
539+
all_task_values=all_task_values,
546540
dtype=dtype,
547-
default_task_value=default_task_value,
548541
)
549542
if self.strata_mapping is None:
550-
self.strata_mapping = observed_task_values
543+
self.strata_mapping = all_task_values
551544
n_strata = self.strata_mapping.shape[0]
552545
self._min_stdv = min_stdv
553546
self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))
@@ -629,7 +622,20 @@ def _get_per_input_means_stdvs(
629622
- The per-input stdvs squared.
630623
"""
631624
strata = X[..., self._stratification_idx].long()
632-
mapped_strata = self.strata_mapping[strata].unsqueeze(-1).long()
625+
mapped_strata_float = self.strata_mapping[strata]
626+
# Check for unobserved tasks (mapped to NaN) and warn
627+
unobserved_mask = torch.isnan(mapped_strata_float)
628+
if unobserved_mask.any():
629+
warnings.warn(
630+
"Predictions are being made for tasks that were not observed "
631+
"during training. These tasks will use an identity transform "
632+
"(mean=0, stdv=1).",
633+
stacklevel=3,
634+
)
635+
# Map unobserved tasks to index 0 temporarily for gather operation
636+
mapped_strata_float = mapped_strata_float.clone()
637+
mapped_strata_float[unobserved_mask] = 0.0
638+
mapped_strata = mapped_strata_float.unsqueeze(-1).long()
633639
# get means and stdvs for each strata
634640
n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape)
635641
expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape
@@ -643,12 +649,22 @@ def _get_per_input_means_stdvs(
643649
dim=-2,
644650
index=mapped_strata,
645651
)
652+
# Apply identity transform (mean=0, stdv=1) for unobserved tasks
653+
if unobserved_mask.any():
654+
unobserved_mask_expanded = unobserved_mask.unsqueeze(-1)
655+
means = means.clone()
656+
stdvs = stdvs.clone()
657+
means[unobserved_mask_expanded] = 0.0
658+
stdvs[unobserved_mask_expanded] = 1.0
646659
if include_stdvs_sq:
647660
stdvs_sq = torch.gather(
648661
input=self._stdvs_sq.expand(expand_shape),
649662
dim=-2,
650663
index=mapped_strata,
651664
)
665+
if unobserved_mask.any():
666+
stdvs_sq = stdvs_sq.clone()
667+
stdvs_sq[unobserved_mask_expanded] = 1.0
652668
else:
653669
stdvs_sq = None
654670
return means, stdvs, stdvs_sq

botorch/models/utils/assorted.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -412,59 +412,51 @@ class fantasize(_Flag):
412412

413413

414414
def get_task_value_remapping(
415-
observed_task_values: Tensor,
416415
all_task_values: Tensor,
417416
dtype: torch.dtype,
418-
default_task_value: int | None,
419417
) -> Tensor | None:
420-
"""Construct an mapping of observed task values to contiguous int-valued floats.
418+
"""Construct a mapping of task values to contiguous int-valued floats.
419+
420+
This function creates a mapping tensor that remaps task indices. All tasks
421+
in ``all_task_values`` are mapped to contiguous integers starting from 0.
422+
Task values not in ``all_task_values`` are mapped to NaN.
421423
422424
Args:
423-
observed_task_values: A sorted long-valued tensor of task values.
424-
all_task_values: A sorted long-valued tensor of task values.
425+
all_task_values: A sorted long-valued tensor of all possible task values
426+
in the full task space.
425427
dtype: The dtype of the model inputs (e.g. ``X``), which the new
426428
task values should have mapped to (e.g. float, double).
427-
default_task_value: The default task value to use for missing task values.
428429
429430
Returns:
430-
A tensor of shape ``task_values.max() + 1`` that maps task values
431+
A tensor of shape ``all_task_values.max() + 1`` that maps task values
431432
to new task values. The indexing operation ``mapper[task_value]``
432433
will produce a tensor of new task values, of the same shape as
433-
the original. The elements of the ``mapper`` tensor that do not
434-
appear in the original ``task_values`` are mapped to ``nan``. The
435-
return value will be ``None``, when the task values are contiguous
436-
integers starting from zero.
434+
the original. All task values in ``all_task_values`` are mapped to
435+
contiguous integers [0, 1, ..., n-1] where n is the number of tasks.
436+
Task values not in ``all_task_values`` are mapped to NaN. Returns
437+
``None`` when ``all_task_values`` equals [0, 1, ..., n-1].
437438
"""
438439
if dtype not in (torch.float, torch.double):
439440
raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
440441
task_range = torch.arange(
441-
len(observed_task_values),
442+
len(all_task_values),
442443
dtype=all_task_values.dtype,
443444
device=all_task_values.device,
444445
)
445446
mapper = None
446447

447-
if default_task_value is None:
448-
fill_value = float("nan")
449-
else:
450-
mask = observed_task_values == default_task_value
451-
if not mask.any():
452-
fill_value = float("nan")
453-
else:
454-
idx = mask.nonzero().item()
455-
fill_value = task_range[idx]
456-
# if not all tasks are observed or they are not contiguous integers
448+
# if task values are not contiguous integers starting from 0,
457449
# then map them to contiguous integers
458450
if not torch.equal(task_range, all_task_values):
459451
# Create a tensor that maps task values to new task values.
460452
# The number of tasks should be small, so this should be quite efficient.
461453
mapper = torch.full(
462454
(int(all_task_values.max().item()) + 1,),
463-
fill_value,
455+
float("nan"),
464456
dtype=dtype,
465457
device=all_task_values.device,
466458
)
467-
mapper[observed_task_values] = task_range.to(dtype=dtype)
459+
mapper[all_task_values] = task_range.to(dtype=dtype)
468460
return mapper
469461

470462

test/models/test_fully_bayesian_multitask.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,43 @@ def test_fit_model_infer_noise(self):
517517
def test_fit_model_with_outcome_transform(self):
518518
self.test_fit_model(use_outcome_transform=True)
519519

520+
def test_fit_model_with_unobserved_tasks(self) -> None:
521+
"""Test fitting and predicting when some tasks have no training data."""
522+
dtype = torch.double
523+
tkwargs = {"device": self.device, "dtype": dtype}
524+
# Tasks 0 and 2 observed; task 1 has no training data
525+
_, _, _, model = self._get_data_and_model(
526+
infer_noise=True,
527+
use_outcome_transform=True,
528+
output_tasks=[2],
529+
observed_task_values=[0, 2],
530+
all_tasks=[0, 1, 2],
531+
validate_task_values=False,
532+
**tkwargs,
533+
)
534+
# Contiguous all_tasks → no mapper needed
535+
self.assertIsNone(model._task_mapper)
536+
self.assertEqual(model.pyro_model.num_tasks, 3)
537+
538+
fit_fully_bayesian_model_nuts(
539+
model, warmup_steps=8, num_samples=5, thinning=2, disable_progbar=True
540+
)
541+
self.assertIsNotNone(model.mean_module)
542+
543+
# Predict for observed tasks
544+
test_X = torch.rand(3, 4, **tkwargs)
545+
posterior = model.posterior(test_X)
546+
self.assertIsInstance(posterior, GaussianMixturePosterior)
547+
# output_tasks=[2] → single output
548+
self.assertEqual(posterior.mean.shape[-1], 1)
549+
550+
# Predict for the UNOBSERVED task (task 1)
551+
test_X_unobs = torch.cat(
552+
[torch.rand(3, 4, **tkwargs), torch.ones(3, 1, **tkwargs)], dim=-1
553+
)
554+
posterior_unobs = model.posterior(test_X_unobs)
555+
self.assertIsInstance(posterior_unobs, GaussianMixturePosterior)
556+
520557
def test_transforms(self, infer_noise: bool = False):
521558
tkwargs = {"device": self.device, "dtype": torch.double}
522559
train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data(**tkwargs)

0 commit comments

Comments
 (0)