77
88from unittest .mock import MagicMock , patch
99
10- import numpy as np
11-
1210import torch
1311from ax .adapter .torch import TorchAdapter
1412from ax .benchmark .benchmark_test_functions .surrogate import SurrogateTestFunction
15- from ax .benchmark .testing .benchmark_stubs import (
16- get_adapter ,
17- get_saas_adapter ,
18- get_soo_surrogate_test_function ,
19- )
20- from ax .generators .torch .botorch_modular .generator import BoTorchGenerator
13+ from ax .benchmark .testing .benchmark_stubs import get_soo_surrogate_test_function
2114from ax .utils .common .testutils import TestCase
22- from ax .utils .testing .core_stubs import (
23- get_branin_experiment ,
24- get_branin_experiment_with_multi_objective ,
25- )
26- from botorch .models .deterministic import PosteriorMeanModel
27- from botorch .sampling .pathwise .posterior_samplers import MatheronPathModel
2815
2916
3017class TestSurrogateTestFunction (TestCase ):
@@ -45,160 +32,6 @@ def test_surrogate_test_function(self) -> None:
4532 self .assertEqual (test_function .name , "test test function" )
4633 self .assertIs (test_function .surrogate , surrogate )
4734
48- def test_equality (self ) -> None :
49- def _construct_test_function (name : str ) -> SurrogateTestFunction :
50- return SurrogateTestFunction (
51- name = name ,
52- _surrogate = MagicMock (),
53- outcome_names = ["dummy_metric" ],
54- )
55-
56- runner_1 = _construct_test_function ("test 1" )
57- runner_2 = _construct_test_function ("test 2" )
58- runner_1a = _construct_test_function ("test 1" )
59- self .assertEqual (runner_1 , runner_1a )
60- self .assertNotEqual (runner_1 , runner_2 )
61- self .assertNotEqual (runner_1 , 1 )
62- self .assertNotEqual (runner_1 , None )
63-
64- def test_surrogate_model_types (self ) -> None :
65- """Test different surrogate model types: sample and mean."""
66- experiment = get_branin_experiment (with_completed_trial = True )
67-
68- for surrogate_model_type in [MatheronPathModel , PosteriorMeanModel ]:
69- with self .subTest (surrogate_model_type = surrogate_model_type ):
70- adapter = get_adapter (experiment )
71-
72- test_function = SurrogateTestFunction (
73- name = f"test_{ surrogate_model_type } _surrogate" ,
74- outcome_names = ["branin" ],
75- _surrogate = adapter ,
76- surrogate_model_type = surrogate_model_type ,
77- seed = 42 ,
78- )
79-
80- # Verify the surrogate type is set correctly
81- self .assertEqual (
82- test_function .surrogate_model_type , surrogate_model_type
83- )
84- self .assertEqual (test_function .seed , 42 )
85-
86- # Test evaluation
87- test_params = {"x1" : 0.5 , "x2" : 0.5 }
88- result = test_function .evaluate_true (test_params )
89-
90- # Ensure result is a tensor
91- self .assertIsInstance (result , torch .Tensor )
92- self .assertEqual (result .dtype , torch .double )
93- self .assertEqual (result .shape , torch .Size ([1 ])) # One outcome
94-
95- def test_surrogate_model_types_with_random_seeds (self ) -> None :
96- """Test that different random seeds produce different results for samples."""
97- experiment = get_branin_experiment (with_completed_trial = True )
98- test_params = {"x1" : 0.5 , "x2" : 0.5 }
99-
100- results = []
101- for seed in [0 , 1 , 2 ]:
102- adapter = get_adapter (experiment )
103- test_function = SurrogateTestFunction (
104- name = f"test_sample_surrogate_seed_{ seed } " ,
105- outcome_names = ["branin" ],
106- _surrogate = adapter ,
107- surrogate_model_type = MatheronPathModel ,
108- seed = seed ,
109- )
110-
111- result = test_function .evaluate_true (test_params )
112- results .append (result .item ())
113-
114- # Different seeds should produce different results for sample type
115- self .assertFalse (
116- all (r == results [0 ] for r in results [1 :]),
117- "Different random seeds should produce different sample results" ,
118- )
119-
120- def test_mean_surrogate_consistency (self ) -> None :
121- """Test that mean surrogate type produces consistent results."""
122- experiment = get_branin_experiment (with_completed_trial = True )
123- test_params = {"x1" : 0.5 , "x2" : 0.5 }
124-
125- results = []
126- # outcomes should be consistent since seed is fixed
127- for i in range (3 ):
128- adapter = get_adapter (experiment )
129- test_function = SurrogateTestFunction (
130- name = f"test_mean_surrogate_{ i } " ,
131- outcome_names = ["branin" ],
132- _surrogate = adapter ,
133- surrogate_model_type = MatheronPathModel ,
134- seed = 42 ,
135- )
136-
137- result = test_function .evaluate_true (test_params )
138- results .append (result .item ())
139-
140- # Mean type should produce consistent results regardless of seed
141- self .assertTrue (np .all (results [0 ] == np .array (results )))
142-
143- def test_surrogate_model_with_multiple_outcomes (self ) -> None :
144- """Test surrogate models with multiple outcome names."""
145- experiment = get_branin_experiment_with_multi_objective (
146- with_completed_trial = True
147- )
148- adapter = TorchAdapter (
149- experiment = experiment ,
150- search_space = experiment .search_space ,
151- generator = BoTorchGenerator (),
152- data = experiment .lookup_data (),
153- transforms = [],
154- )
155-
156- for surrogate_model_type in [MatheronPathModel , PosteriorMeanModel ]:
157- with self .subTest (surrogate_model_type = surrogate_model_type ):
158- test_function = SurrogateTestFunction (
159- name = f"test_multi_outcome_{ surrogate_model_type } " ,
160- outcome_names = ["branin_a" , "branin_b" ],
161- _surrogate = adapter ,
162- surrogate_model_type = surrogate_model_type ,
163- )
164- test_params = {"x1" : 0.5 , "x2" : 0.5 }
165- result = test_function .evaluate_true (test_params )
166-
167- # Should return 2 outcomes
168- self .assertEqual (result .shape , torch .Size ([2 ]))
169-
170- def test_saas_surrogate_model (self ) -> None :
171- """Test surrogate test function with SaasFullyBayesianSingleTaskGP model."""
172- experiment = get_branin_experiment (with_completed_trial = True )
173-
174- # Create adapter with SaasFullyBayesianSingleTaskGP model
175- adapter = get_saas_adapter (experiment )
176-
177- for surrogate_model_type in [MatheronPathModel , PosteriorMeanModel ]:
178- with self .subTest (surrogate_model_type = surrogate_model_type ):
179- test_function = SurrogateTestFunction (
180- name = f"test_saas_surrogate_{ surrogate_model_type } " ,
181- outcome_names = ["branin" ],
182- _surrogate = adapter ,
183- surrogate_model_type = surrogate_model_type ,
184- seed = 123 ,
185- )
186-
187- # Verify the surrogate type is set correctly
188- self .assertEqual (
189- test_function .surrogate_model_type , surrogate_model_type
190- )
191- self .assertEqual (test_function .seed , 123 )
192-
193- # Test evaluation
194- test_params = {"x1" : 0.5 , "x2" : 0.5 }
195- result = test_function .evaluate_true (test_params )
196-
197- # Ensure result is a tensor with correct properties
198- self .assertIsInstance (result , torch .Tensor )
199- self .assertEqual (result .dtype , torch .double )
200- self .assertEqual (result .shape , torch .Size ([1 ])) # One outcome
201-
20235 def test_lazy_instantiation (self ) -> None :
20336 test_function = get_soo_surrogate_test_function ()
20437
@@ -222,46 +55,18 @@ def test_instantiation_raises_with_missing_args(self) -> None:
22255 ):
22356 SurrogateTestFunction (name = "test runner" , outcome_names = [])
22457
225- def test_ensemble_sampling (self ) -> None :
226- """Test that ensemble sampling works correctly."""
227- experiment = get_branin_experiment (with_completed_trial = True )
228- adapter = get_saas_adapter (experiment ) # Creates ensemble model
229-
230- # Test with ensemble sampling enabled (default)
231- test_function = SurrogateTestFunction (
232- name = "test_ensemble_sampling_enabled" ,
233- outcome_names = ["branin" ],
234- _surrogate = adapter ,
235- surrogate_model_type = PosteriorMeanModel ,
236- sample_from_ensemble = True ,
237- )
238-
239- # Access surrogate to trigger wrapping
240- surrogate = test_function .surrogate
241- # pyre-ignore[16]: Access base_model through deterministic wrapper
242- wrapped_model = surrogate .generator .surrogate .model
243-
244- # Check that exactly one model has weight 1.0 and others have weight 0.0
245- weights = wrapped_model .ensemble_weights
246- self .assertEqual (weights .sum ().item (), 1.0 )
247- self .assertEqual ((weights == 1.0 ).sum ().item (), 1 )
248- self .assertEqual ((weights == 0.0 ).sum ().item (), len (weights ) - 1 )
249-
250- def test_ensemble_no_sampling (self ) -> None :
251- """Test that ensemble weights remain unchanged when sampling is disabled."""
252- experiment = get_branin_experiment (with_completed_trial = True )
253- adapter = get_saas_adapter (experiment ) # Creates ensemble model
254-
255- # Test with ensemble sampling disabled
256- test_function = SurrogateTestFunction (
257- name = "test_ensemble_sampling_disabled" ,
258- outcome_names = ["branin" ],
259- _surrogate = adapter ,
260- surrogate_model_type = PosteriorMeanModel ,
261- sample_from_ensemble = False ,
262- )
58+ def test_equality (self ) -> None :
59+ def _construct_test_function (name : str ) -> SurrogateTestFunction :
60+ return SurrogateTestFunction (
61+ name = name ,
62+ _surrogate = MagicMock (),
63+ outcome_names = ["dummy_metric" ],
64+ )
26365
264- # Access surrogate to trigger wrapping
265- surrogate = test_function .surrogate
266- # pyre-ignore[16]: Access base_model through deterministic wrapper
267- self .assertIsNone (surrogate .generator .surrogate .model .ensemble_weights )
66+ runner_1 = _construct_test_function ("test 1" )
67+ runner_2 = _construct_test_function ("test 2" )
68+ runner_1a = _construct_test_function ("test 1" )
69+ self .assertEqual (runner_1 , runner_1a )
70+ self .assertNotEqual (runner_1 , runner_2 )
71+ self .assertNotEqual (runner_1 , 1 )
72+ self .assertNotEqual (runner_1 , None )
0 commit comments