Skip to content

Commit eb9ab4d

Browse files
mpolson64facebook-github-bot
authored andcommitted
Back out "Enabled posterior samples as SurrogateBenchmarks"
Summary: Original commit changeset: b5be2d69d998 Original Phabricator Diff: D80347570 Same motivation as D81695384, will be cleaned up after Ax 1.1.1 release Differential Revision: D81799749
1 parent 5c56f51 commit eb9ab4d

File tree

3 files changed

+15
-299
lines changed

3 files changed

+15
-299
lines changed

ax/benchmark/benchmark_test_functions/surrogate.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,9 @@
1515
from ax.core.types import TParamValue
1616
from ax.utils.common.base import Base
1717
from ax.utils.common.equality import equality_typechecker
18-
from botorch.models.deterministic import (
19-
DeterministicModel,
20-
MatheronPathModel,
21-
PosteriorMeanModel,
22-
)
23-
from botorch.utils.transforms import is_ensemble
2418
from pyre_extensions import none_throws
2519
from torch import Tensor
2620

27-
RANDOM_SURROGATE_TYPES: list[type[DeterministicModel]] = [MatheronPathModel]
28-
2921

3022
@dataclass(kw_only=True)
3123
class SurrogateTestFunction(BenchmarkTestFunction):
@@ -43,21 +35,12 @@ class SurrogateTestFunction(BenchmarkTestFunction):
4335
get_surrogate: Function that returns the surrogate, to allow for lazy
4436
construction. If `get_surrogate` is not provided, `surrogate` must
4537
be provided and vice versa.
46-
surrogate_model_type: The type of surrogate model to use. We either pass
47-
in a type of deterministic model, (e.g. PosteriorMeanModel) or
48-
a function that returns a deterministic model
49-
(e.g. get_matheron_path_model).
50-
sample_from_ensemble: If True, when the surrogate is an ensemble model,
51-
we sample from the ensemble instead of averaging over it.
5238
"""
5339

5440
name: str
5541
outcome_names: Sequence[str]
5642
_surrogate: TorchAdapter | None = None
5743
get_surrogate: None | Callable[[], TorchAdapter] = None
58-
surrogate_model_type: type[DeterministicModel] = PosteriorMeanModel
59-
sample_from_ensemble: bool = False
60-
seed: int = 0
6144

6245
def __post_init__(self) -> None:
6346
if self.get_surrogate is None and self._surrogate is None:
@@ -66,50 +49,10 @@ def __post_init__(self) -> None:
6649
" vice versa."
6750
)
6851

69-
def wrap_surrogate_in_deterministic_model(self) -> None:
70-
"""Substitute the surrogate model for a deterministic model used for
71-
benchmarking."""
72-
# pyre-ignore[16]: `ax.generators.torch_base.TorchGenerator` has no attribute
73-
# `surrogate`.
74-
surrogate_model = none_throws(self._surrogate).generator.surrogate
75-
76-
if isinstance(surrogate_model.model, DeterministicModel):
77-
return # Already wrapped
78-
79-
base_model = surrogate_model.model
80-
81-
# Check if surrogate_model_type accepts a 'seed' argument
82-
if self.surrogate_model_type in RANDOM_SURROGATE_TYPES:
83-
# pyre-ignore[28]: Unexpected keyword argument `seed` to call
84-
# `botorch.models.ensemble.EnsembleModel.__init__`.
85-
wrapped_model = self.surrogate_model_type(base_model, seed=self.seed)
86-
else:
87-
wrapped_model = self.surrogate_model_type(base_model)
88-
89-
# If the model is an ensemble, we want to choose one of the models instead
90-
# of averging over them.
91-
if is_ensemble(base_model) and self.sample_from_ensemble:
92-
num_models = base_model.batch_shape.numel()
93-
# Sample a single index from a uniform multinomial distribution
94-
sampled_idx = torch.multinomial(torch.ones(num_models), 1)
95-
# Set weights: one 1 at sampled_idx, rest 0
96-
wrapped_model.ensemble_weights = torch.zeros(num_models, 1)
97-
none_throws(wrapped_model.ensemble_weights)[sampled_idx] = 1.0
98-
else:
99-
wrapped_model.ensemble_weights = None
100-
surrogate_model._model = wrapped_model
101-
10252
@property
10353
def surrogate(self) -> TorchAdapter:
10454
if self._surrogate is None:
10555
self._surrogate = none_throws(self.get_surrogate)()
106-
if not isinstance(
107-
# pyre-ignore[16]: `ax.generators.torch_base.TorchGenerator` has no
108-
# attribute `surrogate`.
109-
self._surrogate.generator.surrogate.model,
110-
DeterministicModel,
111-
):
112-
self.wrap_surrogate_in_deterministic_model()
11356
return none_throws(self._surrogate)
11457

11558
def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:

ax/benchmark/testing/benchmark_stubs.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,10 @@
4646
from ax.generation_strategy.external_generation_node import ExternalGenerationNode
4747
from ax.generation_strategy.generation_strategy import GenerationStrategy
4848
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
49-
from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
5049
from ax.utils.testing.core_stubs import (
5150
get_branin_experiment,
5251
get_branin_experiment_with_multi_objective,
5352
)
54-
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
5553
from botorch.test_functions.multi_objective import BraninCurrin
5654
from botorch.test_functions.synthetic import Branin
5755

@@ -373,33 +371,3 @@ def get_mock_lcbench_data() -> LCBenchData:
373371
metric_series=metric_series,
374372
timestamp_series=timestamp_series,
375373
)
376-
377-
378-
def get_adapter(experiment: Experiment) -> TorchAdapter:
379-
"""Create a generic adapter for testing different surrogate model types."""
380-
adapter = TorchAdapter(
381-
experiment=experiment,
382-
generator=BoTorchGenerator(),
383-
)
384-
return adapter
385-
386-
387-
def get_saas_adapter(experiment: Experiment) -> TorchAdapter:
388-
"""Create an adapter with SaasFullyBayesianSingleTaskGP model."""
389-
return TorchAdapter(
390-
experiment=experiment,
391-
generator=BoTorchGenerator(
392-
surrogate_spec=SurrogateSpec(
393-
model_configs=[
394-
ModelConfig(
395-
botorch_model_class=SaasFullyBayesianSingleTaskGP,
396-
mll_options={
397-
"warmup_steps": 2,
398-
"num_samples": 4,
399-
"thinning": 1,
400-
},
401-
),
402-
]
403-
),
404-
),
405-
)

ax/benchmark/tests/benchmark_test_functions/test_surrogate_test_function.py

Lines changed: 15 additions & 210 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,11 @@
77

88
from unittest.mock import MagicMock, patch
99

10-
import numpy as np
11-
1210
import torch
1311
from ax.adapter.torch import TorchAdapter
1412
from ax.benchmark.benchmark_test_functions.surrogate import SurrogateTestFunction
15-
from ax.benchmark.testing.benchmark_stubs import (
16-
get_adapter,
17-
get_saas_adapter,
18-
get_soo_surrogate_test_function,
19-
)
20-
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
13+
from ax.benchmark.testing.benchmark_stubs import get_soo_surrogate_test_function
2114
from ax.utils.common.testutils import TestCase
22-
from ax.utils.testing.core_stubs import (
23-
get_branin_experiment,
24-
get_branin_experiment_with_multi_objective,
25-
)
26-
from botorch.models.deterministic import PosteriorMeanModel
27-
from botorch.sampling.pathwise.posterior_samplers import MatheronPathModel
2815

2916

3017
class TestSurrogateTestFunction(TestCase):
@@ -45,160 +32,6 @@ def test_surrogate_test_function(self) -> None:
4532
self.assertEqual(test_function.name, "test test function")
4633
self.assertIs(test_function.surrogate, surrogate)
4734

48-
def test_equality(self) -> None:
49-
def _construct_test_function(name: str) -> SurrogateTestFunction:
50-
return SurrogateTestFunction(
51-
name=name,
52-
_surrogate=MagicMock(),
53-
outcome_names=["dummy_metric"],
54-
)
55-
56-
runner_1 = _construct_test_function("test 1")
57-
runner_2 = _construct_test_function("test 2")
58-
runner_1a = _construct_test_function("test 1")
59-
self.assertEqual(runner_1, runner_1a)
60-
self.assertNotEqual(runner_1, runner_2)
61-
self.assertNotEqual(runner_1, 1)
62-
self.assertNotEqual(runner_1, None)
63-
64-
def test_surrogate_model_types(self) -> None:
65-
"""Test different surrogate model types: sample and mean."""
66-
experiment = get_branin_experiment(with_completed_trial=True)
67-
68-
for surrogate_model_type in [MatheronPathModel, PosteriorMeanModel]:
69-
with self.subTest(surrogate_model_type=surrogate_model_type):
70-
adapter = get_adapter(experiment)
71-
72-
test_function = SurrogateTestFunction(
73-
name=f"test_{surrogate_model_type}_surrogate",
74-
outcome_names=["branin"],
75-
_surrogate=adapter,
76-
surrogate_model_type=surrogate_model_type,
77-
seed=42,
78-
)
79-
80-
# Verify the surrogate type is set correctly
81-
self.assertEqual(
82-
test_function.surrogate_model_type, surrogate_model_type
83-
)
84-
self.assertEqual(test_function.seed, 42)
85-
86-
# Test evaluation
87-
test_params = {"x1": 0.5, "x2": 0.5}
88-
result = test_function.evaluate_true(test_params)
89-
90-
# Ensure result is a tensor
91-
self.assertIsInstance(result, torch.Tensor)
92-
self.assertEqual(result.dtype, torch.double)
93-
self.assertEqual(result.shape, torch.Size([1])) # One outcome
94-
95-
def test_surrogate_model_types_with_random_seeds(self) -> None:
96-
"""Test that different random seeds produce different results for samples."""
97-
experiment = get_branin_experiment(with_completed_trial=True)
98-
test_params = {"x1": 0.5, "x2": 0.5}
99-
100-
results = []
101-
for seed in [0, 1, 2]:
102-
adapter = get_adapter(experiment)
103-
test_function = SurrogateTestFunction(
104-
name=f"test_sample_surrogate_seed_{seed}",
105-
outcome_names=["branin"],
106-
_surrogate=adapter,
107-
surrogate_model_type=MatheronPathModel,
108-
seed=seed,
109-
)
110-
111-
result = test_function.evaluate_true(test_params)
112-
results.append(result.item())
113-
114-
# Different seeds should produce different results for sample type
115-
self.assertFalse(
116-
all(r == results[0] for r in results[1:]),
117-
"Different random seeds should produce different sample results",
118-
)
119-
120-
def test_mean_surrogate_consistency(self) -> None:
121-
"""Test that mean surrogate type produces consistent results."""
122-
experiment = get_branin_experiment(with_completed_trial=True)
123-
test_params = {"x1": 0.5, "x2": 0.5}
124-
125-
results = []
126-
# outcomes should be consistent since seed is fixed
127-
for i in range(3):
128-
adapter = get_adapter(experiment)
129-
test_function = SurrogateTestFunction(
130-
name=f"test_mean_surrogate_{i}",
131-
outcome_names=["branin"],
132-
_surrogate=adapter,
133-
surrogate_model_type=MatheronPathModel,
134-
seed=42,
135-
)
136-
137-
result = test_function.evaluate_true(test_params)
138-
results.append(result.item())
139-
140-
# Mean type should produce consistent results regardless of seed
141-
self.assertTrue(np.all(results[0] == np.array(results)))
142-
143-
def test_surrogate_model_with_multiple_outcomes(self) -> None:
144-
"""Test surrogate models with multiple outcome names."""
145-
experiment = get_branin_experiment_with_multi_objective(
146-
with_completed_trial=True
147-
)
148-
adapter = TorchAdapter(
149-
experiment=experiment,
150-
search_space=experiment.search_space,
151-
generator=BoTorchGenerator(),
152-
data=experiment.lookup_data(),
153-
transforms=[],
154-
)
155-
156-
for surrogate_model_type in [MatheronPathModel, PosteriorMeanModel]:
157-
with self.subTest(surrogate_model_type=surrogate_model_type):
158-
test_function = SurrogateTestFunction(
159-
name=f"test_multi_outcome_{surrogate_model_type}",
160-
outcome_names=["branin_a", "branin_b"],
161-
_surrogate=adapter,
162-
surrogate_model_type=surrogate_model_type,
163-
)
164-
test_params = {"x1": 0.5, "x2": 0.5}
165-
result = test_function.evaluate_true(test_params)
166-
167-
# Should return 2 outcomes
168-
self.assertEqual(result.shape, torch.Size([2]))
169-
170-
def test_saas_surrogate_model(self) -> None:
171-
"""Test surrogate test function with SaasFullyBayesianSingleTaskGP model."""
172-
experiment = get_branin_experiment(with_completed_trial=True)
173-
174-
# Create adapter with SaasFullyBayesianSingleTaskGP model
175-
adapter = get_saas_adapter(experiment)
176-
177-
for surrogate_model_type in [MatheronPathModel, PosteriorMeanModel]:
178-
with self.subTest(surrogate_model_type=surrogate_model_type):
179-
test_function = SurrogateTestFunction(
180-
name=f"test_saas_surrogate_{surrogate_model_type}",
181-
outcome_names=["branin"],
182-
_surrogate=adapter,
183-
surrogate_model_type=surrogate_model_type,
184-
seed=123,
185-
)
186-
187-
# Verify the surrogate type is set correctly
188-
self.assertEqual(
189-
test_function.surrogate_model_type, surrogate_model_type
190-
)
191-
self.assertEqual(test_function.seed, 123)
192-
193-
# Test evaluation
194-
test_params = {"x1": 0.5, "x2": 0.5}
195-
result = test_function.evaluate_true(test_params)
196-
197-
# Ensure result is a tensor with correct properties
198-
self.assertIsInstance(result, torch.Tensor)
199-
self.assertEqual(result.dtype, torch.double)
200-
self.assertEqual(result.shape, torch.Size([1])) # One outcome
201-
20235
def test_lazy_instantiation(self) -> None:
20336
test_function = get_soo_surrogate_test_function()
20437

@@ -222,46 +55,18 @@ def test_instantiation_raises_with_missing_args(self) -> None:
22255
):
22356
SurrogateTestFunction(name="test runner", outcome_names=[])
22457

225-
def test_ensemble_sampling(self) -> None:
226-
"""Test that ensemble sampling works correctly."""
227-
experiment = get_branin_experiment(with_completed_trial=True)
228-
adapter = get_saas_adapter(experiment) # Creates ensemble model
229-
230-
# Test with ensemble sampling enabled (default)
231-
test_function = SurrogateTestFunction(
232-
name="test_ensemble_sampling_enabled",
233-
outcome_names=["branin"],
234-
_surrogate=adapter,
235-
surrogate_model_type=PosteriorMeanModel,
236-
sample_from_ensemble=True,
237-
)
238-
239-
# Access surrogate to trigger wrapping
240-
surrogate = test_function.surrogate
241-
# pyre-ignore[16]: Access base_model through deterministic wrapper
242-
wrapped_model = surrogate.generator.surrogate.model
243-
244-
# Check that exactly one model has weight 1.0 and others have weight 0.0
245-
weights = wrapped_model.ensemble_weights
246-
self.assertEqual(weights.sum().item(), 1.0)
247-
self.assertEqual((weights == 1.0).sum().item(), 1)
248-
self.assertEqual((weights == 0.0).sum().item(), len(weights) - 1)
249-
250-
def test_ensemble_no_sampling(self) -> None:
251-
"""Test that ensemble weights remain unchanged when sampling is disabled."""
252-
experiment = get_branin_experiment(with_completed_trial=True)
253-
adapter = get_saas_adapter(experiment) # Creates ensemble model
254-
255-
# Test with ensemble sampling disabled
256-
test_function = SurrogateTestFunction(
257-
name="test_ensemble_sampling_disabled",
258-
outcome_names=["branin"],
259-
_surrogate=adapter,
260-
surrogate_model_type=PosteriorMeanModel,
261-
sample_from_ensemble=False,
262-
)
58+
def test_equality(self) -> None:
59+
def _construct_test_function(name: str) -> SurrogateTestFunction:
60+
return SurrogateTestFunction(
61+
name=name,
62+
_surrogate=MagicMock(),
63+
outcome_names=["dummy_metric"],
64+
)
26365

264-
# Access surrogate to trigger wrapping
265-
surrogate = test_function.surrogate
266-
# pyre-ignore[16]: Access base_model through deterministic wrapper
267-
self.assertIsNone(surrogate.generator.surrogate.model.ensemble_weights)
66+
runner_1 = _construct_test_function("test 1")
67+
runner_2 = _construct_test_function("test 2")
68+
runner_1a = _construct_test_function("test 1")
69+
self.assertEqual(runner_1, runner_1a)
70+
self.assertNotEqual(runner_1, runner_2)
71+
self.assertNotEqual(runner_1, 1)
72+
self.assertNotEqual(runner_1, None)

0 commit comments

Comments
 (0)