Skip to content

Commit 10dac10

Browse files
eonofreyfacebook-github-bot
authored andcommitted
Consolidate ax/analysis tests
Summary: Part of a 19-diff stack to consolidate repetitive tests across Ax and PTS using `subTest`. Consolidate 11 test files in ax/analysis/ (plotly, healthcheck, graphviz) — adds subTest to top surfaces, metric summary, search space summary, and summary tests. Differential Revision: D95603642
1 parent bb0530f commit 10dac10

11 files changed

Lines changed: 227 additions & 366 deletions

ax/analysis/graphviz/tests/test_generation_strategy_graph.py

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -96,24 +96,23 @@ def setUp(self) -> None:
9696
],
9797
)
9898

99-
def test_validate_applicable_state_no_gs(self) -> None:
100-
"""Test that validation fails when no GenerationStrategy is provided."""
99+
def test_validate_applicable_state(self) -> None:
100+
"""Test validation with and without a GenerationStrategy."""
101101
analysis = GenerationStrategyGraph()
102-
result = analysis.validate_applicable_state(
103-
experiment=None,
104-
generation_strategy=None,
105-
)
106-
self.assertIsNotNone(result)
107-
self.assertIn("requires a GenerationStrategy", result)
108-
109-
def test_validate_applicable_state_valid(self) -> None:
110-
"""Test that validation passes with a valid GenerationStrategy."""
111-
analysis = GenerationStrategyGraph()
112-
result = analysis.validate_applicable_state(
113-
experiment=None,
114-
generation_strategy=self.node_gs,
115-
)
116-
self.assertIsNone(result)
102+
for label, gs, expect_error in [
103+
("no_gs", None, True),
104+
("valid", self.node_gs, False),
105+
]:
106+
with self.subTest(label=label):
107+
result = analysis.validate_applicable_state(
108+
experiment=None,
109+
generation_strategy=gs,
110+
)
111+
if expect_error:
112+
self.assertIsNotNone(result)
113+
self.assertIn("requires a GenerationStrategy", result)
114+
else:
115+
self.assertIsNone(result)
117116

118117
def test_compute_step_based_gs(self) -> None:
119118
"""Test computing the graph for a step-based GenerationStrategy."""
@@ -218,26 +217,22 @@ def test_generation_strategy_to_graphviz(self) -> None:
218217
self.assertIn("Sobol", dot.source)
219218
self.assertIn("MBM", dot.source)
220219

221-
def test_add_node_to_graph_current(self) -> None:
222-
"""Test adding a current node to the graph."""
223-
dot = Digraph()
224-
node = self.node_gs._nodes[0]
225-
_add_node_to_graph(dot=dot, node=node, is_current=True)
226-
227-
source = dot.source
228-
self.assertIn("Sobol", source)
229-
self.assertIn("lightblue", source)
230-
self.assertIn("bold", source)
231-
232-
def test_add_node_to_graph_non_current(self) -> None:
233-
"""Test adding a non-current node to the graph."""
234-
dot = Digraph()
235-
node = self.node_gs._nodes[1]
236-
_add_node_to_graph(dot=dot, node=node, is_current=False)
237-
238-
source = dot.source
239-
self.assertIn("MBM", source)
240-
self.assertIn("rounded", source)
220+
def test_add_node_to_graph(self) -> None:
221+
"""Test adding current and non-current nodes to the graph."""
222+
for is_current, node_index, expected_name, expected_style in [
223+
(True, 0, "Sobol", "lightblue"),
224+
(False, 1, "MBM", "rounded"),
225+
]:
226+
with self.subTest(is_current=is_current):
227+
dot = Digraph()
228+
node = self.node_gs._nodes[node_index]
229+
_add_node_to_graph(dot=dot, node=node, is_current=is_current)
230+
231+
source = dot.source
232+
self.assertIn(expected_name, source)
233+
self.assertIn(expected_style, source)
234+
if is_current:
235+
self.assertIn("bold", source)
241236

242237
def test_add_edges_for_node(self) -> None:
243238
"""Test adding edges for a node."""

ax/analysis/graphviz/tests/test_hierarchical_search_space_graph.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -148,22 +148,17 @@ def test_compute(self) -> None:
148148
for dependent_parameter_name in dependent_parameter_names:
149149
self.assertIn(dependent_parameter_name, source)
150150

151-
def test_online(self) -> None:
151+
def test_online_and_offline(self) -> None:
152152
analysis = HierarchicalSearchSpaceGraph()
153-
for experiment in get_online_experiments():
154-
# If validation fails (i.e. this Experiment is not applicable to HSSGraph)
155-
# then skip it in the tests
156-
if analysis.validate_applicable_state(experiment=experiment) is not None:
157-
continue
158-
159-
_ = analysis.compute(experiment=experiment)
160-
161-
def test_offline(self) -> None:
162-
analysis = HierarchicalSearchSpaceGraph()
163-
for experiment in get_offline_experiments():
164-
# If validation fails (i.e. this Experiment is not applicable to HSSGraph)
165-
# then skip it in the tests
166-
if analysis.validate_applicable_state(experiment=experiment) is not None:
167-
continue
168-
169-
_ = analysis.compute(experiment=experiment)
153+
for label, get_experiments in [
154+
("online", get_online_experiments),
155+
("offline", get_offline_experiments),
156+
]:
157+
with self.subTest(setting=label):
158+
for experiment in get_experiments():
159+
if (
160+
analysis.validate_applicable_state(experiment=experiment)
161+
is not None
162+
):
163+
continue
164+
_ = analysis.compute(experiment=experiment)

ax/analysis/healthcheck/tests/test_predictable_metrics.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -212,40 +212,43 @@ def test_random_adapter_not_applicable(self) -> None:
212212
self.assertIn("RandomAdapter", error)
213213
self.assertIn("no model to evaluate", error)
214214

215-
def test_validate_applicable_state_no_experiment(self) -> None:
216-
"""Test that validation fails when experiment is None."""
215+
def test_validate_applicable_state(self) -> None:
216+
"""Test validation for various input combinations."""
217217
healthcheck = PredictableMetricsAnalysis()
218218

219-
error = healthcheck.validate_applicable_state(
220-
experiment=None,
221-
generation_strategy=self.generation_strategy,
222-
)
223-
224-
self.assertIsNotNone(error)
225-
self.assertIn("Experiment", error)
226-
227-
def test_validate_applicable_state_no_generation_strategy(self) -> None:
228-
"""Test that validation fails when generation_strategy is None."""
229-
healthcheck = PredictableMetricsAnalysis()
230-
231-
error = healthcheck.validate_applicable_state(
232-
experiment=self.experiment,
233-
generation_strategy=None,
234-
)
235-
236-
self.assertIsNotNone(error)
237-
self.assertIn("GenerationStrategy", error)
238-
239-
def test_validate_applicable_state_valid_inputs(self) -> None:
240-
"""Test that validation passes with valid experiment and generation strategy."""
241-
healthcheck = PredictableMetricsAnalysis()
242-
243-
error = healthcheck.validate_applicable_state(
244-
experiment=self.experiment,
245-
generation_strategy=self.generation_strategy,
246-
)
247-
248-
self.assertIsNone(error)
219+
for label, experiment, generation_strategy, expect_error, expected_substr in [
220+
(
221+
"no_experiment",
222+
None,
223+
self.generation_strategy,
224+
True,
225+
"Experiment",
226+
),
227+
(
228+
"no_generation_strategy",
229+
self.experiment,
230+
None,
231+
True,
232+
"GenerationStrategy",
233+
),
234+
(
235+
"valid_inputs",
236+
self.experiment,
237+
self.generation_strategy,
238+
False,
239+
None,
240+
),
241+
]:
242+
with self.subTest(label=label):
243+
error = healthcheck.validate_applicable_state(
244+
experiment=experiment,
245+
generation_strategy=generation_strategy,
246+
)
247+
if expect_error:
248+
self.assertIsNotNone(error)
249+
self.assertIn(expected_substr, error)
250+
else:
251+
self.assertIsNone(error)
249252

250253
@mock_botorch_optimize
251254
def test_adapter_resolved_from_generation_strategy(self) -> None:

ax/analysis/healthcheck/tests/test_should_generate_candidates.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,17 @@
1313

1414

1515
class TestShouldGenerateCandidates(TestCase):
16-
def test_should(self) -> None:
17-
trial_index = randint(0, 10)
18-
card = ShouldGenerateCandidates(
19-
should_generate=True,
20-
reason="Something reassuring",
21-
trial_index=trial_index,
22-
).compute()
23-
self.assertEqual(card.get_status(), HealthcheckStatus.PASS)
24-
self.assertEqual(card.subtitle, "Something reassuring")
25-
26-
def test_should_not(self) -> None:
27-
trial_index = randint(0, 10)
28-
card = ShouldGenerateCandidates(
29-
should_generate=False,
30-
reason="Something concerning",
31-
trial_index=trial_index,
32-
).compute()
33-
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
34-
self.assertEqual(card.subtitle, "Something concerning")
16+
def test_should_generate_candidates(self) -> None:
17+
for should_generate, reason, expected_status in [
18+
(True, "Something reassuring", HealthcheckStatus.PASS),
19+
(False, "Something concerning", HealthcheckStatus.WARNING),
20+
]:
21+
with self.subTest(should_generate=should_generate):
22+
trial_index = randint(0, 10)
23+
card = ShouldGenerateCandidates(
24+
should_generate=should_generate,
25+
reason=reason,
26+
trial_index=trial_index,
27+
).compute()
28+
self.assertEqual(card.get_status(), expected_status)
29+
self.assertEqual(card.subtitle, reason)

ax/analysis/plotly/tests/test_cross_validation.py

Lines changed: 35 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -175,64 +175,38 @@ def test_compute_adhoc(self, mock_r2: mock.Mock) -> None:
175175
)
176176
)
177177
@mock_botorch_optimize
178-
def test_online(self) -> None:
179-
# Test CrossValidationPlot can be computed for a variety of experiments which
180-
# resemble those we see in an online setting.
181-
182-
for experiment in get_online_experiments():
183-
for untransform in [True, False]:
184-
for refined_metric_name in [None, "foo"]:
185-
generation_strategy = get_default_generation_strategy_at_MBM_node(
186-
experiment=experiment
187-
)
188-
189-
# Pick an arbitrary metric from the experiment's optimization config
190-
metric_name = none_throws(
191-
experiment.optimization_config
192-
).objective.metric_names[0]
193-
194-
analysis = CrossValidationPlot(
195-
metric_names=[metric_name],
196-
untransform=untransform,
197-
labels={metric_name: refined_metric_name}
198-
if refined_metric_name
199-
else None,
200-
)
201-
202-
_ = analysis.compute(
203-
experiment=experiment, generation_strategy=generation_strategy
204-
)
205-
206-
@TestCase.ax_long_test(
207-
reason=(
208-
"cross_validate still too slow under @mock_botorch_optimize for this test"
209-
)
210-
)
211-
@mock_botorch_optimize
212-
def test_offline(self) -> None:
213-
# Test CrossValidationPlot can be computed for a variety of experiments which
214-
# resemble those we see in an online setting.
215-
216-
for experiment in get_offline_experiments():
217-
for untransform in [True, False]:
218-
for refined_metric_name in [None, "foo"]:
219-
generation_strategy = get_default_generation_strategy_at_MBM_node(
220-
experiment=experiment
221-
)
222-
223-
# Pick an arbitrary metric from the experiment's optimization config
224-
metric_name = none_throws(
225-
experiment.optimization_config
226-
).objective.metric_names[0]
227-
228-
analysis = CrossValidationPlot(
229-
metric_names=[metric_name],
230-
untransform=untransform,
231-
labels={metric_name: refined_metric_name}
232-
if refined_metric_name
233-
else None,
234-
)
235-
236-
_ = analysis.compute(
237-
experiment=experiment, generation_strategy=generation_strategy
238-
)
178+
def test_online_and_offline(self) -> None:
179+
for label, get_experiments in [
180+
("online", get_online_experiments),
181+
("offline", get_offline_experiments),
182+
]:
183+
for experiment in get_experiments():
184+
for untransform in [True, False]:
185+
for refined_metric_name in [None, "foo"]:
186+
with self.subTest(
187+
setting=label,
188+
untransform=untransform,
189+
refined_metric_name=refined_metric_name,
190+
):
191+
generation_strategy = (
192+
get_default_generation_strategy_at_MBM_node(
193+
experiment=experiment
194+
)
195+
)
196+
197+
metric_name = none_throws(
198+
experiment.optimization_config
199+
).objective.metric_names[0]
200+
201+
analysis = CrossValidationPlot(
202+
metric_names=[metric_name],
203+
untransform=untransform,
204+
labels={metric_name: refined_metric_name}
205+
if refined_metric_name
206+
else None,
207+
)
208+
209+
_ = analysis.compute(
210+
experiment=experiment,
211+
generation_strategy=generation_strategy,
212+
)

ax/analysis/plotly/tests/test_parallel_coordinates.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -105,30 +105,16 @@ def test_get_parameter_dimension(self) -> None:
105105
},
106106
)
107107

108-
def test_online(self) -> None:
109-
# Test ParallelCoordinatesPlot can be computed for a variety of experiments
110-
# which resemble those we see in an online setting.
111-
112-
for experiment in get_online_experiments():
113-
analysis = ParallelCoordinatesPlot(
114-
# Select and arbitrary metric from the optimization config
115-
metric_name=none_throws(
116-
experiment.optimization_config
117-
).objective.metric_names[0]
118-
)
119-
120-
_ = analysis.compute(experiment=experiment)
121-
122-
def test_offline(self) -> None:
123-
# Test ParallelCoordinatesPlot can be computed for a variety of experiments
124-
# which resemble those we see in an offline setting.
125-
126-
for experiment in get_offline_experiments():
127-
analysis = ParallelCoordinatesPlot(
128-
# Select and arbitrary metric from the optimization config
129-
metric_name=none_throws(
130-
experiment.optimization_config
131-
).objective.metric_names[0]
132-
)
133-
134-
_ = analysis.compute(experiment=experiment)
108+
def test_online_and_offline(self) -> None:
109+
for label, get_experiments in [
110+
("online", get_online_experiments),
111+
("offline", get_offline_experiments),
112+
]:
113+
with self.subTest(setting=label):
114+
for experiment in get_experiments():
115+
analysis = ParallelCoordinatesPlot(
116+
metric_name=none_throws(
117+
experiment.optimization_config
118+
).objective.metric_names[0]
119+
)
120+
_ = analysis.compute(experiment=experiment)

0 commit comments

Comments
 (0)