Skip to content

Commit fb71734

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Move wheelhouse tier check to OSS (facebook#4629)
Summary: Pull Request resolved: facebook#4629 Differential Revision: D88600526
1 parent c57e6a8 commit fb71734

2 files changed

Lines changed: 831 additions & 0 deletions

File tree

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from ax.exceptions.core import UserInputError
10+
from ax.utils.common.testutils import TestCase
11+
from ax.utils.common.wheelhouse_utils import (
12+
ADVANCED_TIER_MESSAGE,
13+
check_if_in_wheelhouse,
14+
DEFAULT_TIER_MESSAGES,
15+
ExperimentSummary,
16+
format_tier_message,
17+
TierMessages,
18+
UNSUPPORTED_TIER_MESSAGE,
19+
ValidationMessages,
20+
WHEELHOUSE_TIER_MESSAGE,
21+
)
22+
23+
24+
def get_experiment_summary(
25+
max_trials: int | None = 100,
26+
num_params: int = 10,
27+
num_binary: int = 0,
28+
num_categorical_3_5: int = 0,
29+
num_categorical_6_inf: int = 0,
30+
num_parameter_constraints: int = 0,
31+
num_objectives: int = 1,
32+
num_outcome_constraints: int = 0,
33+
uses_early_stopping: bool = False,
34+
uses_global_stopping: bool = False,
35+
all_inputs_are_configs: bool = True,
36+
tolerated_trial_failure_rate: float | None = 0.5,
37+
max_pending_trials: int | None = 5,
38+
min_failed_trials_for_failure_rate_check: int | None = 5,
39+
non_default_advanced_options: bool | None = None,
40+
uses_merge_multiple_curves: bool | None = None,
41+
) -> ExperimentSummary:
42+
"""Create an ExperimentSummary for testing."""
43+
return {
44+
"max_trials": max_trials,
45+
"num_params": num_params,
46+
"num_binary": num_binary,
47+
"num_categorical_3_5": num_categorical_3_5,
48+
"num_categorical_6_inf": num_categorical_6_inf,
49+
"num_parameter_constraints": num_parameter_constraints,
50+
"num_objectives": num_objectives,
51+
"num_outcome_constraints": num_outcome_constraints,
52+
"uses_early_stopping": uses_early_stopping,
53+
"uses_global_stopping": uses_global_stopping,
54+
"all_inputs_are_configs": all_inputs_are_configs,
55+
"tolerated_trial_failure_rate": tolerated_trial_failure_rate,
56+
"max_pending_trials": max_pending_trials,
57+
"min_failed_trials_for_failure_rate_check": (
58+
min_failed_trials_for_failure_rate_check
59+
),
60+
"non_default_advanced_options": non_default_advanced_options,
61+
"uses_merge_multiple_curves": uses_merge_multiple_curves,
62+
}
63+
64+
65+
class TestCheckIfInWheelhouse(TestCase):
66+
"""Tests for check_if_in_wheelhouse."""
67+
68+
def test_wheelhouse_tier_for_simple_experiment(self) -> None:
69+
"""Test that a simple experiment is classified as Wheelhouse tier."""
70+
summary = get_experiment_summary(
71+
max_trials=100,
72+
num_params=30,
73+
num_binary=10,
74+
num_parameter_constraints=1,
75+
num_objectives=2,
76+
num_outcome_constraints=1,
77+
)
78+
tier, why_not_wheelhouse, why_not_supported = check_if_in_wheelhouse(summary)
79+
80+
self.assertEqual(tier, "Wheelhouse")
81+
self.assertIsNone(why_not_wheelhouse)
82+
self.assertIsNone(why_not_supported)
83+
84+
def test_advanced_tier_conditions(self) -> None:
85+
"""Test conditions that result in Advanced tier."""
86+
test_cases: list[tuple[ExperimentSummary, str]] = [
87+
(get_experiment_summary(max_trials=250), "250 total trials"),
88+
(get_experiment_summary(num_params=60), "60 tunable parameter(s)"),
89+
(get_experiment_summary(num_binary=75), "75 binary tunable parameter(s)"),
90+
(
91+
get_experiment_summary(num_categorical_3_5=1),
92+
"1 unordered choice parameter(s)",
93+
),
94+
(
95+
get_experiment_summary(num_parameter_constraints=4),
96+
"4 parameter constraints",
97+
),
98+
(get_experiment_summary(num_objectives=3), "3 objectives"),
99+
(
100+
get_experiment_summary(num_outcome_constraints=3),
101+
"3 outcome constraints",
102+
),
103+
(
104+
get_experiment_summary(uses_early_stopping=True),
105+
"Early stopping is enabled",
106+
),
107+
(
108+
get_experiment_summary(uses_global_stopping=True),
109+
"Global stopping is enabled",
110+
),
111+
]
112+
113+
for summary, expected_msg in test_cases:
114+
with self.subTest(expected_msg=expected_msg):
115+
tier, why_not_wheelhouse, why_not_supported = check_if_in_wheelhouse(
116+
summary
117+
)
118+
119+
self.assertEqual(tier, "Advanced")
120+
self.assertIsNotNone(why_not_wheelhouse)
121+
self.assertIn(expected_msg, why_not_wheelhouse[0])
122+
self.assertIsNone(why_not_supported)
123+
124+
def test_unsupported_tier_conditions(self) -> None:
125+
"""Test conditions that result in Unsupported tier."""
126+
test_cases: list[tuple[ExperimentSummary, str]] = [
127+
(get_experiment_summary(max_trials=510), "510 total trials"),
128+
(get_experiment_summary(num_params=201), "201 tunable parameter(s)"),
129+
(get_experiment_summary(num_binary=101), "101 binary tunable parameter(s)"),
130+
(
131+
get_experiment_summary(num_categorical_3_5=6),
132+
"unordered choice parameters with more than 3 options",
133+
),
134+
(
135+
get_experiment_summary(num_categorical_6_inf=2),
136+
"unordered choice parameters with more than 5 options",
137+
),
138+
(
139+
get_experiment_summary(num_parameter_constraints=6),
140+
"6 parameter constraints",
141+
),
142+
(get_experiment_summary(num_objectives=5), "5 objectives"),
143+
(
144+
get_experiment_summary(num_outcome_constraints=6),
145+
"6 outcome constraints",
146+
),
147+
(
148+
get_experiment_summary(all_inputs_are_configs=False),
149+
"all_inputs_are_configs=False",
150+
),
151+
(
152+
get_experiment_summary(tolerated_trial_failure_rate=0.99),
153+
"tolerated_trial_failure_rate=0.99",
154+
),
155+
(
156+
get_experiment_summary(non_default_advanced_options=True),
157+
"Non-default advanced_options",
158+
),
159+
(
160+
get_experiment_summary(uses_merge_multiple_curves=True),
161+
"merge_multiple_curves=True",
162+
),
163+
]
164+
165+
for summary, expected_msg in test_cases:
166+
with self.subTest(expected_msg=expected_msg):
167+
tier, _, why_not_supported = check_if_in_wheelhouse(summary)
168+
169+
self.assertEqual(tier, "Unsupported")
170+
self.assertIsNotNone(why_not_supported)
171+
self.assertIn(expected_msg, why_not_supported[0])
172+
173+
def test_unsupported_tier_for_invalid_min_failed_trials(self) -> None:
174+
"""Test min_failed_trials exceeding threshold results in Unsupported tier."""
175+
summary = get_experiment_summary(
176+
max_pending_trials=3, min_failed_trials_for_failure_rate_check=7
177+
)
178+
tier, _, why_not_supported = check_if_in_wheelhouse(summary)
179+
180+
self.assertEqual(tier, "Unsupported")
181+
self.assertIsNotNone(why_not_supported)
182+
self.assertIn(
183+
"min_failed_trials_for_failure_rate_check=7", why_not_supported[0]
184+
)
185+
186+
def test_max_trials_none_raises(self) -> None:
187+
"""Test max_trials=None with all_inputs_are_configs=True raises error."""
188+
summary = get_experiment_summary(all_inputs_are_configs=True, max_trials=None)
189+
190+
with self.assertRaisesRegex(UserInputError, "`max_trials` should not be None!"):
191+
check_if_in_wheelhouse(summary)
192+
193+
def test_custom_validation_messages(self) -> None:
194+
"""Test custom ValidationMessages is used correctly."""
195+
custom_msg = "Custom message."
196+
summary = get_experiment_summary(all_inputs_are_configs=False)
197+
tier, _, why_not_supported = check_if_in_wheelhouse(
198+
summary,
199+
validation_messages=ValidationMessages(not_simple_inputs=custom_msg),
200+
)
201+
202+
self.assertEqual(tier, "Unsupported")
203+
self.assertIsNotNone(why_not_supported)
204+
self.assertIn(custom_msg, why_not_supported[0])
205+
206+
207+
class TestFormatTierMessage(TestCase):
208+
"""Tests for format_tier_message."""
209+
210+
def test_wheelhouse_tier_message(self) -> None:
211+
"""Test formatting of Wheelhouse tier message."""
212+
msg = format_tier_message(
213+
tier="Wheelhouse",
214+
why_not_is_in_wheelhouse=None,
215+
why_not_supported=None,
216+
)
217+
218+
self.assertIn("'Wheelhouse' tier", msg)
219+
self.assertIn(WHEELHOUSE_TIER_MESSAGE, msg)
220+
221+
def test_advanced_tier_message_with_reasons(self) -> None:
222+
"""Test formatting of Advanced tier message with reasons."""
223+
why_not_wheelhouse = ["51 tunable parameters", "Early stopping is enabled"]
224+
msg = format_tier_message(
225+
tier="Advanced",
226+
why_not_is_in_wheelhouse=why_not_wheelhouse,
227+
why_not_supported=None,
228+
)
229+
230+
self.assertIn("'Advanced' tier", msg)
231+
self.assertIn(ADVANCED_TIER_MESSAGE, msg)
232+
self.assertIn("Why this experiment is not in the 'Wheelhouse' tier:", msg)
233+
self.assertIn("51 tunable parameters", msg)
234+
self.assertIn("Early stopping is enabled", msg)
235+
236+
def test_unsupported_tier_message_with_reasons(self) -> None:
237+
"""Test formatting of Unsupported tier message with both reasons."""
238+
why_not_wheelhouse = ["51 tunable parameters"]
239+
why_not_supported = ["201 tunable parameters"]
240+
msg = format_tier_message(
241+
tier="Unsupported",
242+
why_not_is_in_wheelhouse=why_not_wheelhouse,
243+
why_not_supported=why_not_supported,
244+
)
245+
246+
self.assertIn("'Unsupported' tier", msg)
247+
self.assertIn(UNSUPPORTED_TIER_MESSAGE, msg)
248+
self.assertIn("Why this experiment is not in the 'Wheelhouse' tier:", msg)
249+
self.assertIn("51 tunable parameters", msg)
250+
self.assertIn("Why this experiment is not in the 'Advanced' tier:", msg)
251+
self.assertIn("201 tunable parameters", msg)
252+
253+
def test_unknown_tier_raises_error(self) -> None:
254+
"""Test that unknown tier raises UserInputError."""
255+
with self.assertRaisesRegex(UserInputError, 'Got unexpected tier "BadTier"'):
256+
format_tier_message(
257+
tier="BadTier",
258+
why_not_is_in_wheelhouse=None,
259+
why_not_supported=None,
260+
)
261+
262+
def test_custom_tier_messages(self) -> None:
263+
"""Test format_tier_message with custom TierMessages."""
264+
custom_messages = TierMessages(
265+
wheelhouse="Custom wheelhouse message.",
266+
advanced="Custom advanced message.",
267+
unsupported="Custom unsupported message.",
268+
unknown="Custom unknown message.",
269+
wiki_url="https://example.com/wiki",
270+
)
271+
272+
msg = format_tier_message(
273+
tier="Wheelhouse",
274+
why_not_is_in_wheelhouse=None,
275+
why_not_supported=None,
276+
tier_messages=custom_messages,
277+
)
278+
279+
self.assertIn("Custom wheelhouse message.", msg)
280+
self.assertIn("https://example.com/wiki", msg)
281+
282+
def test_custom_tier_messages_advanced(self) -> None:
283+
"""Test custom messages for Advanced tier."""
284+
custom_messages = TierMessages(
285+
wheelhouse="Custom wheelhouse.",
286+
advanced="Custom advanced.",
287+
unsupported="Custom unsupported.",
288+
unknown="Custom unknown.",
289+
wiki_url=None,
290+
)
291+
292+
msg = format_tier_message(
293+
tier="Advanced",
294+
why_not_is_in_wheelhouse=["Some reason"],
295+
why_not_supported=None,
296+
tier_messages=custom_messages,
297+
)
298+
299+
self.assertIn("Custom advanced.", msg)
300+
self.assertIn("Some reason", msg)
301+
# No wiki URL should be appended
302+
self.assertNotIn("For more information", msg)
303+
304+
def test_default_tier_messages_used_when_none_provided(self) -> None:
305+
"""Test that DEFAULT_TIER_MESSAGES is used for tier_messages when not provided."""
306+
msg = format_tier_message(
307+
tier="Wheelhouse",
308+
why_not_is_in_wheelhouse=None,
309+
why_not_supported=None,
310+
)
311+
312+
self.assertIn(DEFAULT_TIER_MESSAGES.wheelhouse, msg)

0 commit comments

Comments
 (0)