Skip to content

Commit 327b230

Browse files
bletham and meta-codesync[bot]
authored and committed
enable cross validation on tracking metrics (facebook#4961)
Summary: Pull Request resolved: facebook#4961 Reviewed By: saitcakmak Differential Revision: D94553709 fbshipit-source-id: be0ab7c0d3cb4431dcb5e77ca1d9715183629ad9
1 parent 6d2249b commit 327b230

File tree

2 files changed

+272
-2
lines changed

2 files changed

+272
-2
lines changed

ax/analysis/diagnostics.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@ class DiagnosticAnalysis(Analysis):
3535
of leave-one-out cross validation.
3636
"""
3737

38+
def __init__(self, include_tracking_metrics: bool = False) -> None:
    """Store the metric-selection option for this DiagnosticAnalysis.

    Args:
        include_tracking_metrics: If True, the analysis considers every
            metric on the experiment (tracking metrics included); if False,
            only the metrics of the optimization config are used.
    """
    self.include_tracking_metrics: bool = include_tracking_metrics
46+
3847
@override
3948
def validate_applicable_state(
4049
self,
@@ -57,8 +66,11 @@ def compute(
5766
) -> AnalysisCardGroup:
5867
experiment = none_throws(experiment)
5968

60-
# Extract all metric names from the OptimizationConfig.
61-
metric_names = [*none_throws(experiment.optimization_config).metrics.keys()]
69+
if self.include_tracking_metrics:
70+
metric_names = list(experiment.metrics.keys())
71+
else:
72+
# Extract all metric names from the OptimizationConfig.
73+
metric_names = [*none_throws(experiment.optimization_config).metrics.keys()]
6274

6375
is_bandit = generation_strategy and is_bandit_experiment(
6476
generation_strategy_name=generation_strategy.name
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
from collections.abc import Callable, Sequence
9+
from unittest.mock import patch
10+
11+
import pandas as pd
12+
from ax.adapter.registry import Generators
13+
from ax.analysis.diagnostics import (
14+
DiagnosticAnalysis,
15+
DIAGNOSTICS_CARDGROUP_SUBTITLE,
16+
DIAGNOSTICS_CARDGROUP_TITLE,
17+
)
18+
from ax.analysis.plotly.cross_validation import CrossValidationPlot
19+
from ax.api.client import Client
20+
from ax.api.configs import RangeParameterConfig
21+
from ax.core.analysis_card import ErrorAnalysisCard
22+
from ax.core.arm import Arm
23+
from ax.core.data import Data
24+
from ax.core.experiment import Experiment
25+
from ax.core.metric import Metric
26+
from ax.core.optimization_config import Objective, OptimizationConfig
27+
from ax.core.parameter import ChoiceParameter, ParameterType
28+
from ax.core.search_space import SearchSpace
29+
from ax.generation_strategy.generation_strategy import (
30+
GenerationNode,
31+
GenerationStrategy,
32+
)
33+
from ax.generation_strategy.generator_spec import GeneratorSpec
34+
from ax.generation_strategy.transition_criterion import MinTrials
35+
from ax.utils.common.constants import Keys
36+
from ax.utils.common.testutils import TestCase
37+
from ax.utils.testing.mock import mock_botorch_optimize
38+
from pyre_extensions import none_throws
39+
40+
41+
class DiagnosticAnalysisTest(TestCase):
    """Tests for DiagnosticAnalysis: state validation and compute()."""

    def test_validate_applicable_state(self) -> None:
        """Validation gives a string without an experiment and None with one."""
        analysis = DiagnosticAnalysis()

        # With no experiment supplied, a (string) error message is expected.
        missing_experiment_result = analysis.validate_applicable_state()
        self.assertIsNotNone(missing_experiment_result)
        self.assertIsInstance(missing_experiment_result, str)

        # A bare experiment (no trials, no data) is enough to pass validation.
        choice_parameter = ChoiceParameter(
            name="x",
            parameter_type=ParameterType.FLOAT,
            values=[0.0, 1.0],
        )
        search_space = SearchSpace(parameters=[choice_parameter])
        experiment = Experiment(name="test", search_space=search_space)
        self.assertIsNone(analysis.validate_applicable_state(experiment=experiment))

    @mock_botorch_optimize
    def test_compute(self) -> None:
        """Check compute() cards and the metric names given to cross validation."""
        # Basic optimization driven through the Client, with one objective
        # metric ("booth") and one tracking metric ("tracking_m").
        client = Client()
        client.configure_experiment(
            name="booth_function",
            parameters=[
                RangeParameterConfig(
                    name=parameter_name,
                    bounds=(-10.0, 10.0),
                    parameter_type="float",
                )
                for parameter_name in ("x1", "x2")
            ],
        )
        client.configure_optimization(objective="-1 * booth")
        client.configure_tracking_metrics(["tracking_m"])

        # Iterate well into the BO phase, one trial per round.
        for _ in range(10):
            next_trials = client.get_next_trials(max_trials=1)
            for trial_index, parameters in next_trials.items():
                x1, x2 = parameters["x1"], parameters["x2"]
                # pyre-ignore[58]
                first_term = (x1 + 2.0 * x2 - 7) ** 2.0
                # pyre-ignore[58]
                second_term = (2.0 * x1 + x2 - 5.0) ** 2.0
                client.complete_trial(
                    trial_index=trial_index,
                    raw_data={
                        "booth": first_term + second_term,
                        "tracking_m": 0.0,
                    },
                )

        experiment = client._experiment
        generation_strategy = client._generation_strategy

        # --- Explicit include_tracking_metrics=False (same as the default) ---
        default_group = DiagnosticAnalysis(include_tracking_metrics=False).compute(
            experiment=experiment,
            generation_strategy=generation_strategy,
        )

        self.assertEqual(default_group.title, DIAGNOSTICS_CARDGROUP_TITLE)
        self.assertEqual(default_group.subtitle, DIAGNOSTICS_CARDGROUP_SUBTITLE)

        # Both the cross validation plot and the GS graph must be children.
        default_child_names = [card.name for card in default_group.children]
        self.assertIn("CrossValidationPlot", default_child_names)
        self.assertIn("GenerationStrategyGraph", default_child_names)

        # None of the produced cards may be an error card.
        for card in default_group.flatten():
            self.assertNotIsInstance(card, ErrorAnalysisCard)

        # --- Record the metric_names handed to CrossValidationPlot ---
        unpatched_init: Callable[..., None] = CrossValidationPlot.__init__
        seen_metric_names: list[Sequence[str] | None] = []

        def recording_init(self: CrossValidationPlot, **kwargs: object) -> None:
            # pyre-ignore[6]: metric_names is Sequence[str] | None
            seen_metric_names.append(kwargs.get("metric_names"))
            unpatched_init(self, **kwargs)

        # Tracking metrics excluded: only optimization config metrics appear.
        with patch.object(CrossValidationPlot, "__init__", recording_init):
            DiagnosticAnalysis(include_tracking_metrics=False).compute(
                experiment=experiment,
                generation_strategy=generation_strategy,
            )
        self.assertEqual(len(seen_metric_names), 1)
        objective_only_names = none_throws(seen_metric_names[0])
        self.assertIn("booth", objective_only_names)
        self.assertNotIn("tracking_m", objective_only_names)

        # Tracking metrics included: every experiment metric appears.
        # clear() (not rebinding) keeps the closure pointing at the same list.
        seen_metric_names.clear()
        with patch.object(CrossValidationPlot, "__init__", recording_init):
            tracking_group = DiagnosticAnalysis(
                include_tracking_metrics=True
            ).compute(
                experiment=experiment,
                generation_strategy=generation_strategy,
            )
        self.assertEqual(len(seen_metric_names), 1)
        all_metric_names = none_throws(seen_metric_names[0])
        self.assertIn("booth", all_metric_names)
        self.assertIn("tracking_m", all_metric_names)

        # The card group is still produced with the usual title.
        self.assertEqual(tracking_group.title, DIAGNOSTICS_CARDGROUP_TITLE)

        # --- No generation strategy: the GS graph must be absent ---
        no_gs_group = DiagnosticAnalysis().compute(
            experiment=experiment,
            generation_strategy=None,
        )
        no_gs_child_names = [card.name for card in no_gs_group.children]
        self.assertIn("CrossValidationPlot", no_gs_child_names)
        self.assertNotIn("GenerationStrategyGraph", no_gs_child_names)

    def test_compute_bandit(self) -> None:
        """Bandit experiments skip the cross validation plot but keep the GS graph."""
        search_space = SearchSpace(
            parameters=[
                ChoiceParameter(
                    name=parameter_name,
                    parameter_type=ParameterType.FLOAT,
                    values=[-10.0, -5.0, 0.0, 5.0, 10.0],
                )
                for parameter_name in ("x1", "x2")
            ]
        )
        optimization_config = OptimizationConfig(
            objective=Objective(metric=Metric(name="booth"), minimize=True)
        )
        experiment = Experiment(
            name="bandit_test",
            search_space=search_space,
            optimization_config=optimization_config,
        )

        # Two batch trials of three arms each, given as (x1, x2) pairs.
        batches = [
            [(-10.0, -10.0), (0.0, 0.0), (10.0, 10.0)],
            [(-10.0, 10.0), (10.0, -10.0), (5.0, 5.0)],
        ]

        rows = []
        for batch in batches:
            batch_arms = [Arm(parameters={"x1": a, "x2": b}) for a, b in batch]
            trial = experiment.new_batch_trial()
            populated_trial = trial.add_arms_and_weights(arms=batch_arms)
            populated_trial.mark_running(no_runner_required=True)

            # One "booth" observation per arm, with zero SEM.
            for arm in trial.arms:
                x1 = float(arm.parameters["x1"])
                x2 = float(arm.parameters["x2"])
                booth_value = (x1 + 2 * x2 - 7) ** 2 + (2 * x1 + x2 - 5) ** 2
                rows.append(
                    {
                        "trial_index": trial.index,
                        "arm_name": arm.name,
                        "metric_name": "booth",
                        "metric_signature": "booth",
                        "mean": booth_value,
                        "sem": 0.0,
                    }
                )

        observations = pd.DataFrame(rows)
        experiment.attach_data(Data(df=observations))

        # Bandit generation strategy: factorial node, then EB Thompson sampling.
        factorial_node = GenerationNode(
            name="FACTORIAL",
            generator_specs=[GeneratorSpec(generator_enum=Generators.FACTORIAL)],
            transition_criteria=[
                MinTrials(
                    threshold=1,
                    transition_to="EMPIRICAL_BAYES_THOMPSON_SAMPLING",
                )
            ],
        )
        thompson_node = GenerationNode(
            name="EMPIRICAL_BAYES_THOMPSON_SAMPLING",
            generator_specs=[
                GeneratorSpec(generator_enum=Generators.EMPIRICAL_BAYES_THOMPSON),
            ],
            transition_criteria=None,
        )
        bandit_gs = GenerationStrategy(
            name=Keys.FACTORIAL_PLUS_EMPIRICAL_BAYES_THOMPSON_SAMPLING,
            nodes=[factorial_node, thompson_node],
        )

        bandit_group = DiagnosticAnalysis().compute(
            experiment=experiment,
            generation_strategy=bandit_gs,
        )
        bandit_child_names = [card.name for card in bandit_group.children]

        # No cross validation plot for a bandit experiment.
        self.assertNotIn("CrossValidationPlot", bandit_child_names)

        # The GS graph is still present because a generation strategy was given.
        self.assertIn("GenerationStrategyGraph", bandit_child_names)

        # Card group metadata is unchanged.
        self.assertEqual(bandit_group.title, DIAGNOSTICS_CARDGROUP_TITLE)
        self.assertEqual(bandit_group.subtitle, DIAGNOSTICS_CARDGROUP_SUBTITLE)

0 commit comments

Comments
 (0)