Skip to content

Commit 6c3e2cc

Browse files
mpolson64facebook-github-bot
authored andcommitted
Back out "Add support for missing tasks in mtgp"
Summary: Original commit changeset: f92a49fb4622 Original Phabricator Diff: D79812024 Same motivation as D81695384, will be cleaned up after Ax 1.1.1 release Rollback Plan: Differential Revision: D81784749
1 parent 116c3c3 commit 6c3e2cc

8 files changed

Lines changed: 33 additions & 87 deletions

File tree

ax/generators/tests/test_botorch_defaults.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,6 @@ def test_get_model(self) -> None:
9696
"sd_prior": GammaPrior(2.0, 0.44),
9797
"eta": 0.6,
9898
}
99-
x[0, 1] = 0
100-
x[1, 1] = 1
10199
model = _get_model(
102100
X=x, Y=y, Yvar=partial_var.clone(), task_feature=1, prior=prior
103101
)
@@ -117,6 +115,7 @@ def test_get_model(self) -> None:
117115
task_covar_module.IndexKernelPrior.correlation_prior.eta,
118116
0.6,
119117
)
118+
120119
model = _get_model(
121120
X=x,
122121
Y=y,

ax/generators/tests/test_botorch_moo_model.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,6 @@ def test_BotorchMOOModel_with_qehvi(
510510
[
511511
[11.0, 2.0],
512512
[9.0, 3.0],
513-
[12.0, 0.0],
514-
[13.0, 0.0],
515513
],
516514
**tkwargs,
517515
)
@@ -561,8 +559,16 @@ def test_BotorchMOOModel_with_qehvi(
561559
ckwargs = _mock_model_infer_objective_thresholds.call_args[1]
562560
X_observed = ckwargs["X_observed"]
563561
sorted_idcs = X_observed[:, 0].argsort()
564-
sorted_idcs2 = Xs[:, 0].argsort()
565-
self.assertTrue(torch.equal(X_observed[sorted_idcs], Xs[sorted_idcs2]))
562+
expected_X_observed = torch.tensor(
563+
[[1.0, 2.0, 3.0], [0.9, 1.9, 2.9]], **tkwargs
564+
)
565+
sorted_idcs2 = expected_X_observed[:, 0].argsort()
566+
self.assertTrue(
567+
torch.equal(
568+
X_observed[sorted_idcs],
569+
expected_X_observed[sorted_idcs2],
570+
)
571+
)
566572
self.assertTrue(
567573
torch.equal(
568574
ckwargs["objective_weights"],
@@ -782,7 +788,6 @@ def test_BotorchMOOModel_with_qehvi_and_outcome_constraints(
782788
feature_names,
783789
_,
784790
) = get_torch_test_data(dtype=dtype, cuda=cuda, constant_noise=True)
785-
bounds[0] = (0.0, 1.0) # make one data point out of bounds
786791
training_data = [
787792
SupervisedDataset(
788793
X=Xs,

ax/generators/torch/botorch_modular/input_constructors/outcome_transform.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
from typing import Any
1212

1313
import torch
14-
from ax.core.search_space import SearchSpaceDigest
15-
from ax.generators.torch.botorch_modular.utils import get_all_task_values_from_ssd
1614

1715
from ax.utils.common.typeutils import _argparse_type_encoder
1816
from botorch.models.transforms.outcome import (
@@ -22,7 +20,7 @@
2220
)
2321
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
2422
from botorch.utils.dispatcher import Dispatcher
25-
from pyre_extensions import assert_is_instance, none_throws
23+
from pyre_extensions import assert_is_instance
2624

2725
outcome_transform_argparse = Dispatcher(
2826
name="outcome_transform_argparse", encoder=_argparse_type_encoder
@@ -34,7 +32,6 @@ def _outcome_transform_argparse_base(
3432
outcome_transform_class: type[OutcomeTransform],
3533
dataset: SupervisedDataset | None = None,
3634
outcome_transform_options: dict[str, Any] | None = None,
37-
search_space_digest: SearchSpaceDigest | None = None,
3835
) -> dict[str, Any]:
3936
"""
4037
Extract the outcome transform kwargs from the given arguments.
@@ -61,7 +58,6 @@ def _outcome_transform_argparse_standardize(
6158
outcome_transform_class: type[Standardize],
6259
dataset: SupervisedDataset,
6360
outcome_transform_options: dict[str, Any] | None = None,
64-
search_space_digest: SearchSpaceDigest | None = None,
6561
) -> dict[str, Any]:
6662
"""Extract the outcome transform kwargs form the given arguments.
6763
@@ -88,7 +84,6 @@ def _outcome_transform_argparse_stratified_standardize(
8884
outcome_transform_class: type[StratifiedStandardize],
8985
dataset: SupervisedDataset,
9086
outcome_transform_options: dict[str, Any] | None = None,
91-
search_space_digest: SearchSpaceDigest | None = None,
9287
) -> dict[str, Any]:
9388
"""Extract the outcome transform kwargs form the given arguments.
9489
@@ -111,20 +106,7 @@ def _outcome_transform_argparse_stratified_standardize(
111106
else:
112107
task_feature_index = dataset.task_feature_index
113108
task_values = dataset.X[..., dataset.task_feature_index].unique().long()
114-
ssd = none_throws(search_space_digest)
115-
if (ssd.target_values is not None) and (
116-
target_value := ssd.target_values.get(none_throws(task_feature_index))
117-
) is not None:
118-
outcome_transform_options.setdefault("default_task_value", int(target_value))
119109
outcome_transform_options.setdefault("stratification_idx", task_feature_index)
120-
outcome_transform_options.setdefault("observed_task_values", task_values)
121-
outcome_transform_options.setdefault(
122-
"all_task_values",
123-
torch.tensor(
124-
get_all_task_values_from_ssd(search_space_digest=ssd),
125-
dtype=torch.long,
126-
device=next(iter(dataset.datasets.values())).X.device,
127-
),
128-
)
110+
outcome_transform_options.setdefault("task_values", task_values)
129111

130112
return outcome_transform_options

ax/generators/torch/botorch_modular/surrogate.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
convert_to_block_design,
3737
copy_model_config_with_default_values,
3838
fit_botorch_model,
39-
get_all_task_values_from_ssd,
4039
get_cv_fold,
4140
ModelConfig,
4241
subset_state_dict,
@@ -272,7 +271,6 @@ def _make_botorch_outcome_transform(
272271
outcome_transform_classes: list[type[OutcomeTransform]],
273272
outcome_transform_options: dict[str, dict[str, Any]],
274273
dataset: SupervisedDataset,
275-
search_space_digest: SearchSpaceDigest,
276274
) -> OutcomeTransform | None:
277275
"""
278276
Makes a BoTorch outcome transform from the provided classes and options.
@@ -292,7 +290,6 @@ def _make_botorch_outcome_transform(
292290
outcome_transform_options.get(transform_class.__name__, {})
293291
),
294292
dataset=dataset,
295-
search_space_digest=search_space_digest,
296293
)
297294
for transform_class in outcome_transform_classes
298295
]
@@ -376,7 +373,6 @@ def _error_if_arg_not_supported(arg_name: str) -> None:
376373
outcome_transform_classes=outcome_transform_classes,
377374
outcome_transform_options=model_config.outcome_transform_options or {},
378375
dataset=dataset,
379-
search_space_digest=search_space_digest,
380376
)
381377
elif "outcome_transform" in botorch_model_class_args:
382378
# This is a temporary solution until all BoTorch models use
@@ -1295,14 +1291,6 @@ def _submodel_input_constructor_mtgp(
12951291
target_value := search_space_digest.target_values.get(task_feature)
12961292
) is not None:
12971293
formatted_model_inputs["output_tasks"] = [int(target_value)]
1298-
# This enables making predictions for inputs at unobserved task values,
1299-
# by making predictions for the target task.
1300-
# This is important for MTGP models that are used in ModelListGPs where
1301-
# some metrics have only been observed for some tasks and not others.
1302-
formatted_model_inputs["validate_task_values"] = False
1303-
formatted_model_inputs["all_tasks"] = get_all_task_values_from_ssd(
1304-
search_space_digest=search_space_digest
1305-
)
13061294
else:
13071295
raise UserInputError(
13081296
"output_tasks or target task value must be provided for MultiTaskGP."

ax/generators/torch/botorch_modular/utils.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -663,17 +663,3 @@ def get_cv_fold(
663663
test_X=X[idcs],
664664
test_Y=Y[idcs],
665665
)
666-
667-
668-
def get_all_task_values_from_ssd(search_space_digest: SearchSpaceDigest) -> list[int]:
669-
"""Get all task values from a search space digest.
670-
671-
Args:
672-
search_space_digest: The search space digest.
673-
674-
Returns:
675-
A list of all task values.
676-
"""
677-
task_feature = search_space_digest.task_features[0]
678-
task_bounds = search_space_digest.bounds[task_feature]
679-
return list(range(int(task_bounds[0]), int(task_bounds[1] + 1)))

ax/generators/torch/tests/test_outcome_transform_argparse.py

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# pyre-strict
77

88
import torch
9-
from ax.core.search_space import SearchSpaceDigest
109
from ax.generators.torch.botorch_modular.input_constructors.outcome_transform import (
1110
outcome_transform_argparse,
1211
)
@@ -18,7 +17,6 @@
1817
)
1918
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
2019
from pyre_extensions import assert_is_instance
21-
from torch import Tensor
2220

2321

2422
class DummyOutcomeTransform(OutcomeTransform):
@@ -72,52 +70,39 @@ def test_argparse_stratified_standardize(self) -> None:
7270
X = self.dataset.X
7371
X[:5, 3] = 0
7472
X[5:, 3] = 1
75-
ssd = SearchSpaceDigest(
76-
feature_names=self.dataset.feature_names,
77-
bounds=[(0.0, 1.0)] * 3 + [(0.0, 2.0)],
78-
task_features=[3],
79-
target_values={3: 1},
80-
)
8173
mt_dataset = MultiTaskDataset.from_joint_dataset(
8274
dataset=self.dataset,
8375
task_feature_index=3,
8476
target_task_value=1,
8577
)
8678
outcome_transform_kwargs_a = outcome_transform_argparse(
87-
StratifiedStandardize,
88-
dataset=mt_dataset,
89-
search_space_digest=ssd,
79+
StratifiedStandardize, dataset=mt_dataset
9080
)
91-
options_b = {"stratification_idx": 2, "default_task_value": 4}
81+
options_b = {
82+
"stratification_idx": 2,
83+
"task_values": torch.tensor([0, 3]),
84+
}
9285
outcome_transform_kwargs_b = outcome_transform_argparse(
9386
StratifiedStandardize,
9487
dataset=mt_dataset,
9588
outcome_transform_options=options_b,
96-
search_space_digest=ssd,
9789
)
9890
expected_options_a = {
9991
"stratification_idx": 3,
100-
"observed_task_values": torch.tensor([0, 1], dtype=torch.long),
101-
"all_task_values": torch.tensor([0, 1, 2], dtype=torch.long),
102-
"default_task_value": 1,
103-
}
104-
expected_options_b = {
105-
"stratification_idx": 2,
106-
"observed_task_values": torch.tensor([0, 1], dtype=torch.long),
107-
"all_task_values": torch.tensor([0, 1, 2], dtype=torch.long),
108-
"default_task_value": 4,
92+
"task_values": torch.tensor([0, 1]),
10993
}
11094
for expected_options, actual_options in zip(
111-
(expected_options_a, expected_options_b),
95+
(expected_options_a, options_b),
11296
(outcome_transform_kwargs_a, outcome_transform_kwargs_b),
11397
):
114-
self.assertEqual(len(actual_options), 4)
115-
for k in ("stratification_idx", "stratification_idx"):
116-
self.assertEqual(actual_options[k], expected_options[k])
117-
for k in ("observed_task_values", "all_task_values"):
118-
self.assertTrue(
119-
torch.equal(
120-
actual_options[k],
121-
assert_is_instance(expected_options[k], Tensor),
122-
)
98+
self.assertEqual(len(actual_options), 2)
99+
self.assertEqual(
100+
actual_options["stratification_idx"],
101+
expected_options["stratification_idx"],
102+
)
103+
self.assertTrue(
104+
torch.equal(
105+
actual_options["task_values"],
106+
assert_is_instance(expected_options["task_values"], torch.Tensor),
123107
)
108+
)

ax/generators/torch/tests/test_surrogate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2047,7 +2047,8 @@ def test_fit(self) -> None:
20472047
),
20482048
}
20492049

2050-
Xs, Ys, Yvars, _, _, _, _ = get_torch_test_data(dtype=self.dtype)
2050+
# offset makes task feature point to valid outcome indices
2051+
Xs, Ys, Yvars, _, _, _, _ = get_torch_test_data(dtype=self.dtype, offset=-1)
20512052
ds1 = SupervisedDataset(
20522053
X=Xs,
20532054
Y=Ys,

ax/utils/testing/torch_stubs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get_torch_test_data(
4646
Yvar = torch.tensor([[0.0 + offset], [2.0 + offset]], **tkwargs)
4747

4848
bounds = [
49-
(0.0 + offset, 2.0 + offset),
49+
(0.0 + offset, 1.0 + offset),
5050
(1.0 + offset, 4.0 + offset),
5151
(2.0 + offset, 5.0 + offset),
5252
]

0 commit comments

Comments
 (0)