Commit 53a1945

mpolson64 authored and facebook-github-bot committed
Improve robustness in InteractionAnalysis
Summary: A number of changes to improve the reliability of InteractionAnalysis:

* Hide the OAK kernel behind a flag which defaults to False. When False, use the current GenerationNode's adapter instead.
* If ax_parameter_sens fails, log the exception and fall back to the surrogate's feature_importances.
* Do not plot samples on the slice and contour plots if there are more than 50 samples (they get too cluttered).
* Change the orange and blue colors on the importance bar chart to colors from the Plotly color scheme.
* Make the plot not error out on unordered choice parameters.
* Improve the subtitle.

Differential Revision: D69993111
1 parent 5135c83 commit 53a1945
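
Below is a minimal usage sketch of the new behavior (not part of this commit's diff; the experiment, generation strategy, and the metric name "objective" are hypothetical placeholders), based on the constructor and compute() signatures shown in the diff:

from ax.analysis.plotly.interaction import InteractionPlot

# Default after this change (use_oak_model=False): the analysis reuses the
# adapter from the current GenerationNode, so compute() must be passed a
# GenerationStrategy or it raises a UserInputError.
analysis = InteractionPlot(metric_name="objective", fit_interactions=True)
# card = analysis.compute(experiment=experiment, generation_strategy=generation_strategy)

# Opting back into the sparse OAK model is now explicit:
oak_analysis = InteractionPlot(metric_name="objective", use_oak_model=True)
# card = oak_analysis.compute(experiment=experiment)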

File tree

1 file changed: +63 -22 lines changed


ax/analysis/plotly/interaction.py

Lines changed: 63 additions & 22 deletions
@@ -29,6 +29,7 @@
 from ax.generation_strategy.generation_strategy import GenerationStrategy
 from ax.modelbridge.registry import Generators
 from ax.modelbridge.torch import TorchAdapter
+from ax.modelbridge.transforms.one_hot import OH_PARAM_INFIX
 from ax.models.torch.botorch_modular.surrogate import Surrogate
 from ax.utils.common.logger import get_logger
 from ax.utils.sensitivity.sobol_measures import ax_parameter_sens
@@ -39,10 +40,12 @@
 from gpytorch.priors import LogNormalPrior
 from plotly import express as px, graph_objects as go
 from plotly.subplots import make_subplots
-from pyre_extensions import assert_is_instance
+from pyre_extensions import assert_is_instance, none_throws

 logger: Logger = get_logger(__name__)

+DISPLAY_SAMPLED_THRESHOLD: int = 50
+

 class InteractionPlot(PlotlyAnalysis):
     """
@@ -63,6 +66,7 @@ def __init__(
         metric_name: str | None = None,
         fit_interactions: bool = True,
         most_important: bool = True,
+        use_oak_model: bool = False,
         seed: int = 0,
         torch_device: torch.device | None = None,
     ) -> None:
@@ -74,6 +78,8 @@ def __init__(
             most_important: Whether to sort by most or least important features in the
                 bar subplot. Also controls whether the six most or least important
                 features are plotted in the surface subplots.
+            use_oak_model: Whether to use an OAK model for the analysis. If False, use
+                Adapter from the current GenerationNode.
             seed: The seed with which to fit the model. Defaults to 0. Used
                 to ensure that the model fit is identical across the generation of
                 various plots.
@@ -83,6 +89,7 @@
         self.metric_name = metric_name
         self.fit_interactions = fit_interactions
         self.most_important = most_important
+        self.use_oak_model = use_oak_model
         self.seed = seed
         self.torch_device = torch_device

@@ -103,26 +110,55 @@ def compute(
         if experiment is None:
             raise UserInputError("InteractionPlot requires an Experiment")

+        if generation_strategy is None and not self.use_oak_model:
+            raise UserInputError(
+                "InteractionPlot requires a GenerationStrategy when use_oak_model is "
+                "False"
+            )
+
         metric_name = self.metric_name or select_metric(experiment=experiment)

         # Fix the seed to ensure that the model is fit identically across different
         # analyses of the same experiment.
         with torch.random.fork_rng():
             torch.torch.manual_seed(self.seed)

-            # Fit the OAK model.
-            oak_model = self._get_oak_model(
-                experiment=experiment, metric_name=metric_name
-            )
+            if self.use_oak_model:
+                adapter = self._get_oak_model(
+                    experiment=experiment, metric_name=metric_name
+                )
+            else:
+                gs = none_throws(generation_strategy)
+                if gs.model is None:
+                    gs._fit_current_model(None)

-            # Calculate first- or second-order Sobol indices.
-            sens = ax_parameter_sens(
-                model_bridge=oak_model,
-                metrics=[metric_name],
-                order="second" if self.fit_interactions else "first",
-                signed=not self.fit_interactions,
-            )[metric_name]
+                adapter = assert_is_instance(gs.model, TorchAdapter)

+        try:
+            # Calculate first- or second-order Sobol indices.
+            sens = ax_parameter_sens(
+                model_bridge=adapter,
+                metrics=[metric_name],
+                order="second" if self.fit_interactions else "first",
+                signed=not self.fit_interactions,
+            )[metric_name]
+        except Exception as e:
+            logger.exception(
+                f"Failed to compute sensitivity analysis with {e}. Falling back "
+                "on the surrogate model's feature importances."
+            )
+
+            sens = {
+                metric_name: adapter.feature_importances(metric_name)
+                for metric_name in adapter.metric_names
+            }
+        # Filter out an parameters that have been added to the search space via one-hot
+        # encoding -- these make the sensitivity analysis less interpretable and break
+        # the surface plots.
+        # TODO: Do something more principled here.
+        sens = {k: v for k, v in sens.items() if OH_PARAM_INFIX not in k}
+
+        # Create a DataFrame with the sensitivity analysis.
         sensitivity_df = pd.DataFrame(
             [*sens.items()], columns=["feature", "sensitivity"]
         ).sort_values(by="sensitivity", key=abs, ascending=self.most_important)
@@ -138,13 +174,16 @@ def compute(
             by="sensitivity", ascending=self.most_important, inplace=True
         )

+        plotly_blue = px.colors.qualitative.Plotly[0]
+        plotly_orange = px.colors.qualitative.Plotly[4]
+
         sensitivity_fig = px.bar(
             plotting_df,
             x="sensitivity",
             y="feature",
             color="direction",
             # Increase gets blue, decrease gets orange.
-            color_discrete_sequence=["orange", "blue"],
+            color_discrete_sequence=[plotly_blue, plotly_orange],
             orientation="h",
         )

@@ -158,7 +197,7 @@ def compute(
             surface_figs.append(
                 _prepare_surface_plot(
                     experiment=experiment,
-                    model=oak_model,
+                    model=adapter,
                     feature_name=feature_name,
                     metric_name=metric_name,
                 )
@@ -245,16 +284,18 @@ def compute(
             width=1000,
         )

-        subtitle_substring = (
-            "one- or two-dimensional" if self.fit_interactions else "one-dimensional"
-        )
+        subtitle_substring = ", or pairs of parameters" if self.fit_interactions else ""

         return self._create_plotly_analysis_card(
             title=f"Interaction Analysis for {metric_name}",
             subtitle=(
-                f"Understand an Experiment's data as {subtitle_substring} additive "
-                "components with sparsity. Important components are visualized through "
-                "slice or contour plots"
+                f"Understand how changes to your parameters affect {metric_name}. "
+                f"Parameters{subtitle_substring} which rank higher here explain more "
+                f"of the observed variation in {metric_name}. The direction of the "
+                "effect is indicated by the color of the bar plot. Additionally, the "
+                "six most important parameters are visualized through surface plots "
+                f"which show the predicted outcomes for {metric_name} as a function "
+                "of the plotted parameters with the other parameters held fixed."
             ),
             level=AnalysisCardLevel.MID,
             df=sensitivity_df,
@@ -333,7 +374,7 @@ def _prepare_surface_plot(
             log_y=is_axis_log_scale(
                 parameter=experiment.search_space.parameters[y_parameter_name]
             ),
-            display_sampled=True,
+            display_sampled=df["sampled"].sum() <= DISPLAY_SAMPLED_THRESHOLD,
         )

     # If the feature is a first-order component, plot a slice plot.
@@ -351,5 +392,5 @@ def _prepare_surface_plot(
         log_x=is_axis_log_scale(
             parameter=experiment.search_space.parameters[feature_name]
         ),
-        display_sampled=True,
+        display_sampled=df["sampled"].sum() <= DISPLAY_SAMPLED_THRESHOLD,
     )
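
As a side note on the bar-chart color change above, a quick standalone check (assuming only that plotly is installed) of what the two referenced palette entries resolve to:

from plotly import express as px

# The two entries of the qualitative "Plotly" palette used for the importance
# bar chart; these should print "#636EFA" (blue) and "#FFA15A" (orange).
print(px.colors.qualitative.Plotly[0])
print(px.colors.qualitative.Plotly[4])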
