Skip to content

Commit be9074b

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
(4/5) Port helpers to OSS for the new Complexity Rating Healthcheck - wheelhouse tier check (facebook#4629)
Summary: Pull Request resolved: facebook#4629 Differential Revision: D88600526
1 parent 2895eb8 commit be9074b

2 files changed

Lines changed: 568 additions & 55 deletions

File tree

ax/utils/common/tests/test_wheelhouse_utils.py

Lines changed: 177 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@
66

77
# pyre-strict
88

9-
from ax.exceptions.core import OptimizationNotConfiguredError
9+
from ax.exceptions.core import OptimizationNotConfiguredError, UserInputError
1010
from ax.service.orchestrator import OrchestratorOptions
11+
1112
from ax.utils.common.testutils import TestCase
1213
from ax.utils.common.wheelhouse_utils import (
1314
ADVANCED_TIER_MESSAGE,
15+
check_if_in_wheelhouse,
1416
format_tier_message,
17+
OptimizationSummary,
1518
summarize_ax_experiment_wheelhouse,
1619
UNSUPPORTED_TIER_MESSAGE,
1720
WHEELHOUSE_TIER_MESSAGE,
@@ -41,31 +44,13 @@ def test_basic_experiment_summary(self) -> None:
4144
tier_metadata=self.tier_metadata,
4245
)
4346

44-
# THEN the summary should contain all expected keys with correct values
45-
expected_keys = [
46-
"max_trials",
47-
"num_params",
48-
"num_binary",
49-
"num_categorical_3_5",
50-
"num_categorical_6_inf",
51-
"num_parameter_constraints",
52-
"num_objectives",
53-
"num_outcome_constraints",
54-
"uses_early_stopping",
55-
"uses_global_stopping",
56-
"uses_merge_multiple_curves",
57-
"all_inputs_are_configs",
58-
"tolerated_trial_failure_rate",
59-
"max_pending_trials",
60-
"min_failed_trials_for_failure_rate_check",
61-
]
62-
for key in expected_keys:
63-
self.assertIn(key, summary)
47+
# THEN the summary should be an OptimizationSummary with correct values
48+
self.assertIsInstance(summary, OptimizationSummary)
6449

6550
# Validate specific values for single-objective experiment
66-
self.assertEqual(summary["num_objectives"], 1)
67-
self.assertFalse(summary["uses_early_stopping"])
68-
self.assertFalse(summary["uses_global_stopping"])
51+
self.assertEqual(summary.num_objectives, 1)
52+
self.assertFalse(summary.uses_early_stopping)
53+
self.assertFalse(summary.uses_global_stopping)
6954

7055
def test_multi_objective_experiment(self) -> None:
7156
# GIVEN a multi-objective experiment
@@ -79,7 +64,7 @@ def test_multi_objective_experiment(self) -> None:
7964
)
8065

8166
# THEN num_objectives should be greater than 1
82-
self.assertGreater(summary["num_objectives"], 1)
67+
self.assertGreater(summary.num_objectives, 1)
8368

8469
def test_experiment_without_optimization_config_raises(self) -> None:
8570
# GIVEN an experiment without optimization config
@@ -128,10 +113,8 @@ def test_tier_metadata_extraction(self) -> None:
128113
)
129114

130115
# THEN the summary should reflect tier metadata values
131-
self.assertEqual(summary["max_trials"], expected_max_trials)
132-
self.assertEqual(
133-
summary["all_inputs_are_configs"], expected_all_configs
134-
)
116+
self.assertEqual(summary.max_trials, expected_max_trials)
117+
self.assertEqual(summary.all_inputs_are_configs, expected_all_configs)
135118

136119
def test_orchestrator_options_extraction(self) -> None:
137120
# GIVEN custom orchestrator options
@@ -149,9 +132,9 @@ def test_orchestrator_options_extraction(self) -> None:
149132
)
150133

151134
# THEN the summary should reflect orchestrator options
152-
self.assertEqual(summary["tolerated_trial_failure_rate"], 0.25)
153-
self.assertEqual(summary["max_pending_trials"], 5)
154-
self.assertEqual(summary["min_failed_trials_for_failure_rate_check"], 10)
135+
self.assertEqual(summary.tolerated_trial_failure_rate, 0.25)
136+
self.assertEqual(summary.max_pending_trials, 5)
137+
self.assertEqual(summary.min_failed_trials_for_failure_rate_check, 10)
155138

156139
def test_parameter_constraints_counted(self) -> None:
157140
# GIVEN an experiment with parameter constraints
@@ -165,7 +148,7 @@ def test_parameter_constraints_counted(self) -> None:
165148
)
166149

167150
# THEN num_parameter_constraints should be greater than 0
168-
self.assertGreater(summary["num_parameter_constraints"], 0)
151+
self.assertGreater(summary.num_parameter_constraints, 0)
169152

170153

171154
class TestFormatTierMessage(TestCase):
@@ -241,3 +224,164 @@ def test_unknown_tier_raises_error(self) -> None:
241224
why_not_is_in_wheelhouse=None,
242225
why_not_supported=None,
243226
)
227+
228+
229+
def get_experiment_summary(
230+
max_trials: int | None = 100,
231+
num_params: int = 10,
232+
num_binary: int = 0,
233+
num_categorical_3_5: int = 0,
234+
num_categorical_6_inf: int = 0,
235+
num_parameter_constraints: int = 0,
236+
num_objectives: int = 1,
237+
num_outcome_constraints: int = 0,
238+
uses_early_stopping: bool = False,
239+
uses_global_stopping: bool = False,
240+
all_inputs_are_configs: bool = True,
241+
tolerated_trial_failure_rate: float | None = 0.5,
242+
max_pending_trials: int | None = 5,
243+
min_failed_trials_for_failure_rate_check: int | None = 5,
244+
non_default_advanced_options: bool | None = None,
245+
uses_merge_multiple_curves: bool | None = None,
246+
) -> OptimizationSummary:
247+
"""Create an OptimizationSummary for testing."""
248+
return OptimizationSummary(
249+
max_trials=max_trials,
250+
num_params=num_params,
251+
num_binary=num_binary,
252+
num_categorical_3_5=num_categorical_3_5,
253+
num_categorical_6_inf=num_categorical_6_inf,
254+
num_parameter_constraints=num_parameter_constraints,
255+
num_objectives=num_objectives,
256+
num_outcome_constraints=num_outcome_constraints,
257+
uses_early_stopping=uses_early_stopping,
258+
uses_global_stopping=uses_global_stopping,
259+
all_inputs_are_configs=all_inputs_are_configs,
260+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
261+
max_pending_trials=max_pending_trials,
262+
min_failed_trials_for_failure_rate_check=(
263+
min_failed_trials_for_failure_rate_check
264+
),
265+
non_default_advanced_options=non_default_advanced_options,
266+
uses_merge_multiple_curves=uses_merge_multiple_curves,
267+
)
268+
269+
270+
class TestCheckIfInWheelhouse(TestCase):
271+
"""Tests for check_if_in_wheelhouse."""
272+
273+
def setUp(self) -> None:
274+
super().setUp()
275+
self.base_summary = get_experiment_summary()
276+
277+
def test_wheelhouse_tier_for_simple_experiment(self) -> None:
278+
"""Test that a simple experiment is classified as Wheelhouse tier."""
279+
tier, why_not_wheelhouse, why_not_supported = check_if_in_wheelhouse(
280+
self.base_summary
281+
)
282+
283+
self.assertEqual(tier, "Wheelhouse")
284+
self.assertEqual(why_not_wheelhouse, None)
285+
self.assertEqual(why_not_supported, None)
286+
287+
def test_advanced_tier_conditions(self) -> None:
288+
"""Test conditions that result in Advanced tier."""
289+
test_cases: list[tuple[OptimizationSummary, str]] = [
290+
(get_experiment_summary(max_trials=250), "250 total trials"),
291+
(get_experiment_summary(num_params=60), "60 tunable parameter(s)"),
292+
(get_experiment_summary(num_binary=75), "75 binary tunable parameter(s)"),
293+
(
294+
get_experiment_summary(num_categorical_3_5=1),
295+
"1 unordered choice parameter(s)",
296+
),
297+
(
298+
get_experiment_summary(num_parameter_constraints=4),
299+
"4 parameter constraints",
300+
),
301+
(get_experiment_summary(num_objectives=3), "3 objectives"),
302+
(
303+
get_experiment_summary(num_outcome_constraints=3),
304+
"3 outcome constraints",
305+
),
306+
(
307+
get_experiment_summary(uses_early_stopping=True),
308+
"Early stopping is enabled",
309+
),
310+
(
311+
get_experiment_summary(uses_global_stopping=True),
312+
"Global stopping is enabled",
313+
),
314+
]
315+
316+
for summary, expected_msg in test_cases:
317+
with self.subTest(expected_msg=expected_msg):
318+
tier, why_not_wheelhouse, why_not_supported = check_if_in_wheelhouse(
319+
summary
320+
)
321+
322+
self.assertEqual(tier, "Advanced")
323+
self.assertIsNotNone(why_not_wheelhouse)
324+
self.assertIn(expected_msg, why_not_wheelhouse[0])
325+
self.assertEqual(why_not_supported, None)
326+
327+
def test_unsupported_tier_conditions(self) -> None:
328+
"""Test conditions that result in Unsupported tier."""
329+
test_cases: list[tuple[OptimizationSummary, str]] = [
330+
(get_experiment_summary(max_trials=510), "510 total trials"),
331+
(get_experiment_summary(num_params=201), "201 tunable parameter(s)"),
332+
(get_experiment_summary(num_binary=101), "101 binary tunable parameter(s)"),
333+
(
334+
get_experiment_summary(num_categorical_3_5=6),
335+
"unordered choice parameters with more than 3 options",
336+
),
337+
(
338+
get_experiment_summary(num_categorical_6_inf=2),
339+
"unordered choice parameters with more than 5 options",
340+
),
341+
(
342+
get_experiment_summary(num_parameter_constraints=6),
343+
"6 parameter constraints",
344+
),
345+
(get_experiment_summary(num_objectives=5), "5 objectives"),
346+
(
347+
get_experiment_summary(num_outcome_constraints=6),
348+
"6 outcome constraints",
349+
),
350+
(
351+
get_experiment_summary(all_inputs_are_configs=False),
352+
"Ax abstractions",
353+
),
354+
(
355+
get_experiment_summary(tolerated_trial_failure_rate=0.99),
356+
"tolerated_trial_failure_rate=0.99",
357+
),
358+
(
359+
get_experiment_summary(non_default_advanced_options=True),
360+
"Non-default advanced_options",
361+
),
362+
(
363+
get_experiment_summary(uses_merge_multiple_curves=True),
364+
"merge_multiple_curves=True",
365+
),
366+
(
367+
get_experiment_summary(
368+
max_pending_trials=3, min_failed_trials_for_failure_rate_check=7
369+
),
370+
"min_failed_trials_for_failure_rate_check=7",
371+
),
372+
]
373+
374+
for summary, expected_msg in test_cases:
375+
with self.subTest(expected_msg=expected_msg):
376+
tier, _, why_not_supported = check_if_in_wheelhouse(summary)
377+
378+
self.assertEqual(tier, "Unsupported")
379+
self.assertIsNotNone(why_not_supported)
380+
self.assertIn(expected_msg, why_not_supported[0])
381+
382+
def test_max_trials_none_raises(self) -> None:
383+
"""Test max_trials=None with all_inputs_are_configs=True raises error."""
384+
summary = get_experiment_summary(all_inputs_are_configs=True, max_trials=None)
385+
386+
with self.assertRaisesRegex(UserInputError, "`max_trials` should not be None!"):
387+
check_if_in_wheelhouse(summary)

0 commit comments

Comments
 (0)