88from typing import final , Literal
99
1010from ax .adapter .base import Adapter
11+ from ax .adapter .torch import TorchAdapter
1112from ax .analysis .analysis import Analysis
1213from ax .analysis .analysis_card import (
1314 AnalysisCard ,
1415 AnalysisCardBase ,
1516 AnalysisCardGroup ,
1617 ErrorAnalysisCard ,
1718)
19+ from ax .analysis .plotly .plotly_analysis import PlotlyAnalysisCard
1820from ax .analysis .plotly .sensitivity import SensitivityAnalysisPlot
1921from ax .analysis .plotly .surface .contour import (
2022 CONTOUR_CARDGROUP_SUBTITLE ,
2729 SlicePlot ,
2830)
2931from ax .analysis .plotly .utils import select_metric
30- from ax .analysis .utils import validate_experiment
32+ from ax .analysis .utils import (
33+ extract_relevant_adapter ,
34+ validate_experiment ,
35+ validate_experiment_has_trials ,
36+ )
3137from ax .core .experiment import Experiment
38+ from ax .exceptions .core import UserInputError
3239from ax .generation_strategy .generation_strategy import GenerationStrategy
3340from pyre_extensions import assert_is_instance , none_throws , override
3441
@@ -68,14 +75,46 @@ def validate_applicable_state(
6875 adapter : Adapter | None = None ,
6976 ) -> str | None :
7077 """
71- TopSurfacesAnalysis requires an experiment with trials and data.
78+ TopSurfacesAnalysis requires an experiment with trials and data as well as
79+ a TorchAdapter.
7280 """
73- if self . metric_name is None :
74- return validate_experiment (
81+ if (
82+ experiment_invalid_reason := validate_experiment (
7583 experiment = experiment ,
7684 require_trials = True ,
7785 require_data = True ,
7886 )
87+ ) is not None :
88+ return experiment_invalid_reason
89+
90+ metric_name = (
91+ self .metric_name
92+ if self .metric_name is not None
93+ else select_metric (experiment = none_throws (experiment ))
94+ )
95+
96+ if (
97+ experiment_invalid_reason := validate_experiment_has_trials (
98+ experiment = none_throws (experiment ),
99+ required_metric_names = [metric_name ],
100+ # Any trial indices and statuses will do since we use all data here
101+ trial_indices = None ,
102+ trial_statuses = None ,
103+ )
104+ ) is not None :
105+ return experiment_invalid_reason
106+
107+ try :
108+ relevant_adapter = extract_relevant_adapter (
109+ experiment = experiment ,
110+ generation_strategy = generation_strategy ,
111+ adapter = adapter ,
112+ )
113+
114+ if not isinstance (relevant_adapter , TorchAdapter ):
115+ return f"TorchAdapter is required, found { type (relevant_adapter )} ."
116+ except UserInputError as e :
117+ return e .message
79118
80119 @override
81120 def compute (
@@ -93,15 +132,27 @@ def compute(
93132 # Process the sensitivity analysis card to find the top K surfaces which
94133 # consist exclusively of tunable parameters (i.e. no fixed parameters, task
95134 # parameters, or OneHot parameters).
96- sensitivity_analysis_card = SensitivityAnalysisPlot (
135+ maybe_sensitivity_analysis_card = SensitivityAnalysisPlot (
97136 metric_name = metric_name ,
98137 order = self .order ,
99138 top_k = self .top_k ,
100- ).compute (
139+ ).compute_result (
101140 experiment = experiment ,
102141 generation_strategy = generation_strategy ,
103142 adapter = adapter ,
104143 )
144+
145+ if maybe_sensitivity_analysis_card .is_err ():
146+ err = none_throws (maybe_sensitivity_analysis_card .err )
147+ raise err .exception or RuntimeError (
148+ "Failed to compute SensitivityAnalysisPlot"
149+ f"({ metric_name = } , { self .order = } , { self .top_k = } )"
150+ )
151+
152+ sensitivity_analysis_card = assert_is_instance (
153+ maybe_sensitivity_analysis_card .ok , PlotlyAnalysisCard
154+ )
155+
105156 children : list [AnalysisCardBase ] = [sensitivity_analysis_card ]
106157
107158 sensitivity_df = sensitivity_analysis_card .df .copy ()
0 commit comments