29
29
from ax .generation_strategy .generation_strategy import GenerationStrategy
30
30
from ax .modelbridge .registry import Generators
31
31
from ax .modelbridge .torch import TorchAdapter
32
+ from ax .modelbridge .transforms .one_hot import OH_PARAM_INFIX
32
33
from ax .models .torch .botorch_modular .surrogate import Surrogate
33
34
from ax .utils .common .logger import get_logger
34
35
from ax .utils .sensitivity .sobol_measures import ax_parameter_sens
39
40
from gpytorch .priors import LogNormalPrior
40
41
from plotly import express as px , graph_objects as go
41
42
from plotly .subplots import make_subplots
42
- from pyre_extensions import assert_is_instance
43
+ from pyre_extensions import assert_is_instance , none_throws
43
44
44
45
logger : Logger = get_logger (__name__ )
45
46
47
+ DISPLAY_SAMPLED_THRESHOLD : int = 50
48
+
46
49
47
50
class InteractionPlot (PlotlyAnalysis ):
48
51
"""
@@ -63,6 +66,7 @@ def __init__(
63
66
metric_name : str | None = None ,
64
67
fit_interactions : bool = True ,
65
68
most_important : bool = True ,
69
+ use_oak_model : bool = False ,
66
70
seed : int = 0 ,
67
71
torch_device : torch .device | None = None ,
68
72
) -> None :
@@ -74,6 +78,8 @@ def __init__(
74
78
most_important: Whether to sort by most or least important features in the
75
79
bar subplot. Also controls whether the six most or least important
76
80
features are plotted in the surface subplots.
81
+ use_oak_model: Whether to use an OAK model for the analysis. If False, use
82
+ Adapter from the current GenerationNode.
77
83
seed: The seed with which to fit the model. Defaults to 0. Used
78
84
to ensure that the model fit is identical across the generation of
79
85
various plots.
@@ -83,6 +89,7 @@ def __init__(
83
89
self .metric_name = metric_name
84
90
self .fit_interactions = fit_interactions
85
91
self .most_important = most_important
92
+ self .use_oak_model = use_oak_model
86
93
self .seed = seed
87
94
self .torch_device = torch_device
88
95
@@ -103,26 +110,55 @@ def compute(
103
110
if experiment is None :
104
111
raise UserInputError ("InteractionPlot requires an Experiment" )
105
112
113
+ if generation_strategy is None and not self .use_oak_model :
114
+ raise UserInputError (
115
+ "InteractionPlot requires a GenerationStrategy when use_oak_model is "
116
+ "False"
117
+ )
118
+
106
119
metric_name = self .metric_name or select_metric (experiment = experiment )
107
120
108
121
# Fix the seed to ensure that the model is fit identically across different
109
122
# analyses of the same experiment.
110
123
with torch .random .fork_rng ():
111
124
torch .torch .manual_seed (self .seed )
112
125
113
- # Fit the OAK model.
114
- oak_model = self ._get_oak_model (
115
- experiment = experiment , metric_name = metric_name
116
- )
126
+ if self .use_oak_model :
127
+ adapter = self ._get_oak_model (
128
+ experiment = experiment , metric_name = metric_name
129
+ )
130
+ else :
131
+ gs = none_throws (generation_strategy )
132
+ if gs .model is None :
133
+ gs ._fit_current_model (None )
117
134
118
- # Calculate first- or second-order Sobol indices.
119
- sens = ax_parameter_sens (
120
- model_bridge = oak_model ,
121
- metrics = [metric_name ],
122
- order = "second" if self .fit_interactions else "first" ,
123
- signed = not self .fit_interactions ,
124
- )[metric_name ]
135
+ adapter = assert_is_instance (gs .model , TorchAdapter )
125
136
137
+ try :
138
+ # Calculate first- or second-order Sobol indices.
139
+ sens = ax_parameter_sens (
140
+ model_bridge = adapter ,
141
+ metrics = [metric_name ],
142
+ order = "second" if self .fit_interactions else "first" ,
143
+ signed = not self .fit_interactions ,
144
+ )[metric_name ]
145
+ except Exception as e :
146
+ logger .exception (
147
+ f"Failed to compute sensitivity analysis with { e } . Falling back "
148
+ "on the surrogate model's feature importances."
149
+ )
150
+
151
+ sens = {
152
+ metric_name : adapter .feature_importances (metric_name )
153
+ for metric_name in adapter .metric_names
154
+ }
155
+ # Filter out an parameters that have been added to the search space via one-hot
156
+ # encoding -- these make the sensitivity analysis less interpretable and break
157
+ # the surface plots.
158
+ # TODO: Do something more principled here.
159
+ sens = {k : v for k , v in sens .items () if OH_PARAM_INFIX not in k }
160
+
161
+ # Create a DataFrame with the sensitivity analysis.
126
162
sensitivity_df = pd .DataFrame (
127
163
[* sens .items ()], columns = ["feature" , "sensitivity" ]
128
164
).sort_values (by = "sensitivity" , key = abs , ascending = self .most_important )
@@ -138,13 +174,16 @@ def compute(
138
174
by = "sensitivity" , ascending = self .most_important , inplace = True
139
175
)
140
176
177
+ plotly_blue = px .colors .qualitative .Plotly [0 ]
178
+ plotly_orange = px .colors .qualitative .Plotly [4 ]
179
+
141
180
sensitivity_fig = px .bar (
142
181
plotting_df ,
143
182
x = "sensitivity" ,
144
183
y = "feature" ,
145
184
color = "direction" ,
146
185
# Increase gets blue, decrease gets orange.
147
- color_discrete_sequence = ["orange" , "blue" ],
186
+ color_discrete_sequence = [plotly_blue , plotly_orange ],
148
187
orientation = "h" ,
149
188
)
150
189
@@ -158,7 +197,7 @@ def compute(
158
197
surface_figs .append (
159
198
_prepare_surface_plot (
160
199
experiment = experiment ,
161
- model = oak_model ,
200
+ model = adapter ,
162
201
feature_name = feature_name ,
163
202
metric_name = metric_name ,
164
203
)
@@ -245,16 +284,18 @@ def compute(
245
284
width = 1000 ,
246
285
)
247
286
248
- subtitle_substring = (
249
- "one- or two-dimensional" if self .fit_interactions else "one-dimensional"
250
- )
287
+ subtitle_substring = ", or pairs of parameters" if self .fit_interactions else ""
251
288
252
289
return self ._create_plotly_analysis_card (
253
290
title = f"Interaction Analysis for { metric_name } " ,
254
291
subtitle = (
255
- f"Understand an Experiment's data as { subtitle_substring } additive "
256
- "components with sparsity. Important components are visualized through "
257
- "slice or contour plots"
292
+ f"Understand how changes to your parameters affect { metric_name } . "
293
+ f"Parameters{ subtitle_substring } which rank higher here explain more "
294
+ f"of the observed variation in { metric_name } . The direction of the "
295
+ "effect is indicated by the color of the bar plot. Additionally, the "
296
+ "six most important parameters are visualized through surface plots "
297
+ f"which show the predicted outcomes for { metric_name } as a function "
298
+ "of the plotted parameters with the other parameters held fixed."
258
299
),
259
300
level = AnalysisCardLevel .MID ,
260
301
df = sensitivity_df ,
@@ -333,7 +374,7 @@ def _prepare_surface_plot(
333
374
log_y = is_axis_log_scale (
334
375
parameter = experiment .search_space .parameters [y_parameter_name ]
335
376
),
336
- display_sampled = True ,
377
+ display_sampled = df [ "sampled" ]. sum () <= DISPLAY_SAMPLED_THRESHOLD ,
337
378
)
338
379
339
380
# If the feature is a first-order component, plot a slice plot.
@@ -351,5 +392,5 @@ def _prepare_surface_plot(
351
392
log_x = is_axis_log_scale (
352
393
parameter = experiment .search_space .parameters [feature_name ]
353
394
),
354
- display_sampled = True ,
395
+ display_sampled = df [ "sampled" ]. sum () <= DISPLAY_SAMPLED_THRESHOLD ,
355
396
)
0 commit comments