Skip to content

Commit 459f0bf

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Tighten TopSurfacesAnalysis.validate_applicable_state (facebook#4541)
Summary: Pull Request resolved: facebook#4541 Add checks for experiment having trials, data, that the metric name is present on the data, and that the current adapter is the TorchAdapter since Sensitivity will fail if the adapter is anything but the TorchAdapter Reviewed By: mgarrard Differential Revision: D87089811 fbshipit-source-id: 8717ff4d19663c36d8f90e30a830335f40addd90
1 parent 674c017 commit 459f0bf

2 files changed

Lines changed: 120 additions & 6 deletions

File tree

ax/analysis/plotly/tests/test_top_surfaces.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,69 @@ def test_validate_applicable_state(self) -> None:
2424
none_throws(TopSurfacesAnalysis().validate_applicable_state()),
2525
)
2626

27+
client = Client()
28+
client.configure_experiment(
29+
name="foo",
30+
parameters=[
31+
RangeParameterConfig(
32+
name="x1",
33+
parameter_type="float",
34+
bounds=(0, 1),
35+
),
36+
RangeParameterConfig(
37+
name="x2",
38+
parameter_type="float",
39+
bounds=(0, 1),
40+
),
41+
],
42+
)
43+
client.configure_optimization(objective="bar")
44+
45+
for _ in range(1):
46+
for trial_index, parameterization in client.get_next_trials(
47+
max_trials=1
48+
).items():
49+
client.complete_trial(
50+
trial_index=trial_index,
51+
raw_data={
52+
"bar": assert_is_instance(parameterization["x1"], float)
53+
- 2 * assert_is_instance(parameterization["x2"], float)
54+
},
55+
)
56+
57+
self.assertIn(
58+
"Ax has not yet reached a GenerationNode",
59+
none_throws(
60+
TopSurfacesAnalysis(
61+
metric_name="bar", order="first"
62+
).validate_applicable_state(
63+
client._experiment, client._generation_strategy
64+
)
65+
),
66+
)
67+
for _ in range(5):
68+
for trial_index, parameterization in client.get_next_trials(
69+
max_trials=1
70+
).items():
71+
client.complete_trial(
72+
trial_index=trial_index,
73+
raw_data={
74+
"bar": assert_is_instance(parameterization["x1"], float)
75+
- 2 * assert_is_instance(parameterization["x2"], float)
76+
},
77+
)
78+
79+
self.assertIn(
80+
"no data for metrics {'baz'}",
81+
none_throws(
82+
TopSurfacesAnalysis(
83+
metric_name="baz", order="first"
84+
).validate_applicable_state(
85+
client._experiment, client._generation_strategy
86+
)
87+
),
88+
)
89+
2790
@mock_botorch_optimize
2891
def test_compute(self) -> None:
2992
client = Client()

ax/analysis/plotly/top_surfaces.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from typing import final, Literal
99

1010
from ax.adapter.base import Adapter
11+
from ax.adapter.torch import TorchAdapter
1112
from ax.analysis.analysis import Analysis
1213
from ax.analysis.analysis_card import (
1314
AnalysisCard,
1415
AnalysisCardBase,
1516
AnalysisCardGroup,
1617
ErrorAnalysisCard,
1718
)
19+
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard
1820
from ax.analysis.plotly.sensitivity import SensitivityAnalysisPlot
1921
from ax.analysis.plotly.surface.contour import (
2022
CONTOUR_CARDGROUP_SUBTITLE,
@@ -27,8 +29,13 @@
2729
SlicePlot,
2830
)
2931
from ax.analysis.plotly.utils import select_metric
30-
from ax.analysis.utils import validate_experiment
32+
from ax.analysis.utils import (
33+
extract_relevant_adapter,
34+
validate_experiment,
35+
validate_experiment_has_trials,
36+
)
3137
from ax.core.experiment import Experiment
38+
from ax.exceptions.core import UserInputError
3239
from ax.generation_strategy.generation_strategy import GenerationStrategy
3340
from pyre_extensions import assert_is_instance, none_throws, override
3441

@@ -68,14 +75,46 @@ def validate_applicable_state(
6875
adapter: Adapter | None = None,
6976
) -> str | None:
7077
"""
71-
TopSurfacesAnalysis requires an experiment with trials and data.
78+
TopSurfacesAnalysis requires an experiment with trials and data as well as
79+
a TorchAdapter.
7280
"""
73-
if self.metric_name is None:
74-
return validate_experiment(
81+
if (
82+
experiment_invalid_reason := validate_experiment(
7583
experiment=experiment,
7684
require_trials=True,
7785
require_data=True,
7886
)
87+
) is not None:
88+
return experiment_invalid_reason
89+
90+
metric_name = (
91+
self.metric_name
92+
if self.metric_name is not None
93+
else select_metric(experiment=none_throws(experiment))
94+
)
95+
96+
if (
97+
experiment_invalid_reason := validate_experiment_has_trials(
98+
experiment=none_throws(experiment),
99+
required_metric_names=[metric_name],
100+
# Any trial indices and statuses will do since we use all data here
101+
trial_indices=None,
102+
trial_statuses=None,
103+
)
104+
) is not None:
105+
return experiment_invalid_reason
106+
107+
try:
108+
relevant_adapter = extract_relevant_adapter(
109+
experiment=experiment,
110+
generation_strategy=generation_strategy,
111+
adapter=adapter,
112+
)
113+
114+
if not isinstance(relevant_adapter, TorchAdapter):
115+
return f"TorchAdapter is required, found {type(relevant_adapter)}."
116+
except UserInputError as e:
117+
return e.message
79118

80119
@override
81120
def compute(
@@ -93,15 +132,27 @@ def compute(
93132
# Process the sensitivity analysis card to find the top K surfaces which
94133
# consist exclusively of tunable parameters (i.e. no fixed parameters, task
95134
# parameters, or OneHot parameters).
96-
sensitivity_analysis_card = SensitivityAnalysisPlot(
135+
maybe_sensitivity_analysis_card = SensitivityAnalysisPlot(
97136
metric_name=metric_name,
98137
order=self.order,
99138
top_k=self.top_k,
100-
).compute(
139+
).compute_result(
101140
experiment=experiment,
102141
generation_strategy=generation_strategy,
103142
adapter=adapter,
104143
)
144+
145+
if maybe_sensitivity_analysis_card.is_err():
146+
err = none_throws(maybe_sensitivity_analysis_card.err)
147+
raise err.exception or RuntimeError(
148+
"Failed to compute SensitivityAnalysisPlot"
149+
f"({metric_name=}, {self.order=}, {self.top_k=})"
150+
)
151+
152+
sensitivity_analysis_card = assert_is_instance(
153+
maybe_sensitivity_analysis_card.ok, PlotlyAnalysisCard
154+
)
155+
105156
children: list[AnalysisCardBase] = [sensitivity_analysis_card]
106157

107158
sensitivity_df = sensitivity_analysis_card.df.copy()

0 commit comments

Comments
 (0)