66
77# pyre-strict
88
9- from ax .exceptions .core import OptimizationNotConfiguredError
9+ from ax .exceptions .core import OptimizationNotConfiguredError , UserInputError
1010from ax .service .orchestrator import OrchestratorOptions
11+
1112from ax .utils .common .testutils import TestCase
1213from ax .utils .common .wheelhouse_utils import (
1314 ADVANCED_TIER_MESSAGE ,
15+ check_if_in_wheelhouse ,
1416 format_tier_message ,
17+ OptimizationSummary ,
1518 summarize_ax_experiment_wheelhouse ,
1619 UNSUPPORTED_TIER_MESSAGE ,
1720 WHEELHOUSE_TIER_MESSAGE ,
@@ -41,31 +44,13 @@ def test_basic_experiment_summary(self) -> None:
4144 tier_metadata = self .tier_metadata ,
4245 )
4346
44- # THEN the summary should contain all expected keys with correct values
45- expected_keys = [
46- "max_trials" ,
47- "num_params" ,
48- "num_binary" ,
49- "num_categorical_3_5" ,
50- "num_categorical_6_inf" ,
51- "num_parameter_constraints" ,
52- "num_objectives" ,
53- "num_outcome_constraints" ,
54- "uses_early_stopping" ,
55- "uses_global_stopping" ,
56- "uses_merge_multiple_curves" ,
57- "all_inputs_are_configs" ,
58- "tolerated_trial_failure_rate" ,
59- "max_pending_trials" ,
60- "min_failed_trials_for_failure_rate_check" ,
61- ]
62- for key in expected_keys :
63- self .assertIn (key , summary )
47+ # THEN the summary should be an OptimizationSummary with correct values
48+ self .assertIsInstance (summary , OptimizationSummary )
6449
6550 # Validate specific values for single-objective experiment
66- self .assertEqual (summary [ " num_objectives" ] , 1 )
67- self .assertFalse (summary [ " uses_early_stopping" ] )
68- self .assertFalse (summary [ " uses_global_stopping" ] )
51+ self .assertEqual (summary . num_objectives , 1 )
52+ self .assertFalse (summary . uses_early_stopping )
53+ self .assertFalse (summary . uses_global_stopping )
6954
7055 def test_multi_objective_experiment (self ) -> None :
7156 # GIVEN a multi-objective experiment
@@ -79,7 +64,7 @@ def test_multi_objective_experiment(self) -> None:
7964 )
8065
8166 # THEN num_objectives should be greater than 1
82- self .assertGreater (summary [ " num_objectives" ] , 1 )
67+ self .assertGreater (summary . num_objectives , 1 )
8368
8469 def test_experiment_without_optimization_config_raises (self ) -> None :
8570 # GIVEN an experiment without optimization config
@@ -128,10 +113,8 @@ def test_tier_metadata_extraction(self) -> None:
128113 )
129114
130115 # THEN the summary should reflect tier metadata values
131- self .assertEqual (summary ["max_trials" ], expected_max_trials )
132- self .assertEqual (
133- summary ["all_inputs_are_configs" ], expected_all_configs
134- )
116+ self .assertEqual (summary .max_trials , expected_max_trials )
117+ self .assertEqual (summary .all_inputs_are_configs , expected_all_configs )
135118
136119 def test_orchestrator_options_extraction (self ) -> None :
137120 # GIVEN custom orchestrator options
@@ -149,9 +132,9 @@ def test_orchestrator_options_extraction(self) -> None:
149132 )
150133
151134 # THEN the summary should reflect orchestrator options
152- self .assertEqual (summary [ " tolerated_trial_failure_rate" ] , 0.25 )
153- self .assertEqual (summary [ " max_pending_trials" ] , 5 )
154- self .assertEqual (summary [ " min_failed_trials_for_failure_rate_check" ] , 10 )
135+ self .assertEqual (summary . tolerated_trial_failure_rate , 0.25 )
136+ self .assertEqual (summary . max_pending_trials , 5 )
137+ self .assertEqual (summary . min_failed_trials_for_failure_rate_check , 10 )
155138
156139 def test_parameter_constraints_counted (self ) -> None :
157140 # GIVEN an experiment with parameter constraints
@@ -165,7 +148,7 @@ def test_parameter_constraints_counted(self) -> None:
165148 )
166149
167150 # THEN num_parameter_constraints should be greater than 0
168- self .assertGreater (summary [ " num_parameter_constraints" ] , 0 )
151+ self .assertGreater (summary . num_parameter_constraints , 0 )
169152
170153
171154class TestFormatTierMessage (TestCase ):
@@ -241,3 +224,164 @@ def test_unknown_tier_raises_error(self) -> None:
241224 why_not_is_in_wheelhouse = None ,
242225 why_not_supported = None ,
243226 )
227+
228+
229+ def get_experiment_summary (
230+ max_trials : int | None = 100 ,
231+ num_params : int = 10 ,
232+ num_binary : int = 0 ,
233+ num_categorical_3_5 : int = 0 ,
234+ num_categorical_6_inf : int = 0 ,
235+ num_parameter_constraints : int = 0 ,
236+ num_objectives : int = 1 ,
237+ num_outcome_constraints : int = 0 ,
238+ uses_early_stopping : bool = False ,
239+ uses_global_stopping : bool = False ,
240+ all_inputs_are_configs : bool = True ,
241+ tolerated_trial_failure_rate : float | None = 0.5 ,
242+ max_pending_trials : int | None = 5 ,
243+ min_failed_trials_for_failure_rate_check : int | None = 5 ,
244+ non_default_advanced_options : bool | None = None ,
245+ uses_merge_multiple_curves : bool | None = None ,
246+ ) -> OptimizationSummary :
247+ """Create an OptimizationSummary for testing."""
248+ return OptimizationSummary (
249+ max_trials = max_trials ,
250+ num_params = num_params ,
251+ num_binary = num_binary ,
252+ num_categorical_3_5 = num_categorical_3_5 ,
253+ num_categorical_6_inf = num_categorical_6_inf ,
254+ num_parameter_constraints = num_parameter_constraints ,
255+ num_objectives = num_objectives ,
256+ num_outcome_constraints = num_outcome_constraints ,
257+ uses_early_stopping = uses_early_stopping ,
258+ uses_global_stopping = uses_global_stopping ,
259+ all_inputs_are_configs = all_inputs_are_configs ,
260+ tolerated_trial_failure_rate = tolerated_trial_failure_rate ,
261+ max_pending_trials = max_pending_trials ,
262+ min_failed_trials_for_failure_rate_check = (
263+ min_failed_trials_for_failure_rate_check
264+ ),
265+ non_default_advanced_options = non_default_advanced_options ,
266+ uses_merge_multiple_curves = uses_merge_multiple_curves ,
267+ )
268+
269+
270+ class TestCheckIfInWheelhouse (TestCase ):
271+ """Tests for check_if_in_wheelhouse."""
272+
273+ def setUp (self ) -> None :
274+ super ().setUp ()
275+ self .base_summary = get_experiment_summary ()
276+
277+ def test_wheelhouse_tier_for_simple_experiment (self ) -> None :
278+ """Test that a simple experiment is classified as Wheelhouse tier."""
279+ tier , why_not_wheelhouse , why_not_supported = check_if_in_wheelhouse (
280+ self .base_summary
281+ )
282+
283+ self .assertEqual (tier , "Wheelhouse" )
284+ self .assertEqual (why_not_wheelhouse , None )
285+ self .assertEqual (why_not_supported , None )
286+
287+ def test_advanced_tier_conditions (self ) -> None :
288+ """Test conditions that result in Advanced tier."""
289+ test_cases : list [tuple [OptimizationSummary , str ]] = [
290+ (get_experiment_summary (max_trials = 250 ), "250 total trials" ),
291+ (get_experiment_summary (num_params = 60 ), "60 tunable parameter(s)" ),
292+ (get_experiment_summary (num_binary = 75 ), "75 binary tunable parameter(s)" ),
293+ (
294+ get_experiment_summary (num_categorical_3_5 = 1 ),
295+ "1 unordered choice parameter(s)" ,
296+ ),
297+ (
298+ get_experiment_summary (num_parameter_constraints = 4 ),
299+ "4 parameter constraints" ,
300+ ),
301+ (get_experiment_summary (num_objectives = 3 ), "3 objectives" ),
302+ (
303+ get_experiment_summary (num_outcome_constraints = 3 ),
304+ "3 outcome constraints" ,
305+ ),
306+ (
307+ get_experiment_summary (uses_early_stopping = True ),
308+ "Early stopping is enabled" ,
309+ ),
310+ (
311+ get_experiment_summary (uses_global_stopping = True ),
312+ "Global stopping is enabled" ,
313+ ),
314+ ]
315+
316+ for summary , expected_msg in test_cases :
317+ with self .subTest (expected_msg = expected_msg ):
318+ tier , why_not_wheelhouse , why_not_supported = check_if_in_wheelhouse (
319+ summary
320+ )
321+
322+ self .assertEqual (tier , "Advanced" )
323+ self .assertIsNotNone (why_not_wheelhouse )
324+ self .assertIn (expected_msg , why_not_wheelhouse [0 ])
325+ self .assertEqual (why_not_supported , None )
326+
327+ def test_unsupported_tier_conditions (self ) -> None :
328+ """Test conditions that result in Unsupported tier."""
329+ test_cases : list [tuple [OptimizationSummary , str ]] = [
330+ (get_experiment_summary (max_trials = 510 ), "510 total trials" ),
331+ (get_experiment_summary (num_params = 201 ), "201 tunable parameter(s)" ),
332+ (get_experiment_summary (num_binary = 101 ), "101 binary tunable parameter(s)" ),
333+ (
334+ get_experiment_summary (num_categorical_3_5 = 6 ),
335+ "unordered choice parameters with more than 3 options" ,
336+ ),
337+ (
338+ get_experiment_summary (num_categorical_6_inf = 2 ),
339+ "unordered choice parameters with more than 5 options" ,
340+ ),
341+ (
342+ get_experiment_summary (num_parameter_constraints = 6 ),
343+ "6 parameter constraints" ,
344+ ),
345+ (get_experiment_summary (num_objectives = 5 ), "5 objectives" ),
346+ (
347+ get_experiment_summary (num_outcome_constraints = 6 ),
348+ "6 outcome constraints" ,
349+ ),
350+ (
351+ get_experiment_summary (all_inputs_are_configs = False ),
352+ "Ax abstractions" ,
353+ ),
354+ (
355+ get_experiment_summary (tolerated_trial_failure_rate = 0.99 ),
356+ "tolerated_trial_failure_rate=0.99" ,
357+ ),
358+ (
359+ get_experiment_summary (non_default_advanced_options = True ),
360+ "Non-default advanced_options" ,
361+ ),
362+ (
363+ get_experiment_summary (uses_merge_multiple_curves = True ),
364+ "merge_multiple_curves=True" ,
365+ ),
366+ (
367+ get_experiment_summary (
368+ max_pending_trials = 3 , min_failed_trials_for_failure_rate_check = 7
369+ ),
370+ "min_failed_trials_for_failure_rate_check=7" ,
371+ ),
372+ ]
373+
374+ for summary , expected_msg in test_cases :
375+ with self .subTest (expected_msg = expected_msg ):
376+ tier , _ , why_not_supported = check_if_in_wheelhouse (summary )
377+
378+ self .assertEqual (tier , "Unsupported" )
379+ self .assertIsNotNone (why_not_supported )
380+ self .assertIn (expected_msg , why_not_supported [0 ])
381+
382+ def test_max_trials_none_raises (self ) -> None :
383+ """Test max_trials=None with all_inputs_are_configs=True raises error."""
384+ summary = get_experiment_summary (all_inputs_are_configs = True , max_trials = None )
385+
386+ with self .assertRaisesRegex (UserInputError , "`max_trials` should not be None!" ):
387+ check_if_in_wheelhouse (summary )
0 commit comments