Skip to content

Commit 118f984

Browse files
Carl Hvarfnermeta-codesync[bot]
authored andcommitted
Simplify task value remapping API (#4860)
Summary: X-link: meta-pytorch/botorch#3163 Pull Request resolved: #4860 X-link: meta-pytorch/botorch#3146 Simplifies the get_task_value_remapping() API from 4 parameters to 2, addressing confusion reported in #3085. The observed_task_values parameter is removed because the parent diff (D90769576) now makes MultiTaskGP track observed/unobserved tasks internally via _observed_task_indices and _unobserved_task_indices. The default_task_value parameter is removed because the previous behavior—silently mapping unknown tasks to an arbitrary fallback—was confusing and error-prone; instead, unrecognized tasks now map to NaN, providing an explicit error sentinel with a clear warning message. Differential Revision: D90998243
1 parent ce4dc42 commit 118f984

2 files changed

Lines changed: 11 additions & 21 deletions

File tree

ax/generators/torch/botorch_modular/input_constructors/outcome_transform.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,17 +106,10 @@ def _outcome_transform_argparse_stratified_standardize(
106106
dataset = assert_is_instance(dataset, MultiTaskDataset)
107107
if dataset.has_heterogeneous_features:
108108
task_feature_index = dataset.task_feature_index or -1
109-
task_values = torch.arange(len(dataset.datasets), dtype=torch.long)
110109
else:
111110
task_feature_index = dataset.task_feature_index
112-
task_values = dataset.X[..., dataset.task_feature_index].unique().long()
113111
ssd = none_throws(search_space_digest)
114-
if (ssd.target_values is not None) and (
115-
target_value := ssd.target_values.get(none_throws(task_feature_index))
116-
) is not None:
117-
outcome_transform_options.setdefault("default_task_value", int(target_value))
118112
outcome_transform_options.setdefault("stratification_idx", task_feature_index)
119-
outcome_transform_options.setdefault("observed_task_values", task_values)
120113
outcome_transform_options.setdefault(
121114
"all_task_values",
122115
torch.tensor(

ax/generators/torch/tests/test_outcome_transform_argparse.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_argparse_stratified_standardize(self) -> None:
8888
dataset=mt_dataset,
8989
search_space_digest=ssd,
9090
)
91-
options_b = {"stratification_idx": 2, "default_task_value": 4}
91+
options_b = {"stratification_idx": 2}
9292
outcome_transform_kwargs_b = outcome_transform_argparse(
9393
StratifiedStandardize,
9494
dataset=mt_dataset,
@@ -97,27 +97,24 @@ def test_argparse_stratified_standardize(self) -> None:
9797
)
9898
expected_options_a = {
9999
"stratification_idx": 3,
100-
"observed_task_values": torch.tensor([0, 1], dtype=torch.long),
101100
"all_task_values": torch.tensor([0, 1, 2], dtype=torch.long),
102-
"default_task_value": 1,
103101
}
104102
expected_options_b = {
105103
"stratification_idx": 2,
106-
"observed_task_values": torch.tensor([0, 1], dtype=torch.long),
107104
"all_task_values": torch.tensor([0, 1, 2], dtype=torch.long),
108-
"default_task_value": 4,
109105
}
110106
for expected_options, actual_options in zip(
111107
(expected_options_a, expected_options_b),
112108
(outcome_transform_kwargs_a, outcome_transform_kwargs_b),
113109
):
114-
self.assertEqual(len(actual_options), 4)
115-
for k in ("stratification_idx", "stratification_idx"):
116-
self.assertEqual(actual_options[k], expected_options[k])
117-
for k in ("observed_task_values", "all_task_values"):
118-
self.assertTrue(
119-
torch.equal(
120-
actual_options[k],
121-
assert_is_instance(expected_options[k], Tensor),
122-
)
110+
self.assertEqual(len(actual_options), 2)
111+
self.assertEqual(
112+
actual_options["stratification_idx"],
113+
expected_options["stratification_idx"],
114+
)
115+
self.assertTrue(
116+
torch.equal(
117+
actual_options["all_task_values"],
118+
assert_is_instance(expected_options["all_task_values"], Tensor),
123119
)
120+
)

0 commit comments

Comments
 (0)