Skip to content

Commit

Permalink
Improve robustness in InteractionAnalysis
Browse files Browse the repository at this point in the history
Summary:
A number of features to improve reliablity of the InteractionAnalysis.
* Hide OAK kernel behind a flag which defaults to False. When false, use the current GenerationNode's adapter
* If ax_parameter_sens fails log an exception and fallback to the surrogate's feature_importances
* Do not plot samples on the slice and countour plots if there are more than 50 samples (it gets too cluttered)
* Changed the orange and blue colors on the importance bar chart to be in the plotly color scheme
* Make plot not error out on unordered choice params
* Improved subtitle

Differential Revision: D69993111
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Feb 21, 2025
1 parent 5135c83 commit 53a1945
Showing 1 changed file with 63 additions and 22 deletions.
85 changes: 63 additions & 22 deletions ax/analysis/plotly/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Generators
from ax.modelbridge.torch import TorchAdapter
from ax.modelbridge.transforms.one_hot import OH_PARAM_INFIX
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.utils.common.logger import get_logger
from ax.utils.sensitivity.sobol_measures import ax_parameter_sens
Expand All @@ -39,10 +40,12 @@
from gpytorch.priors import LogNormalPrior
from plotly import express as px, graph_objects as go
from plotly.subplots import make_subplots
from pyre_extensions import assert_is_instance
from pyre_extensions import assert_is_instance, none_throws

logger: Logger = get_logger(__name__)

DISPLAY_SAMPLED_THRESHOLD: int = 50


class InteractionPlot(PlotlyAnalysis):
"""
Expand All @@ -63,6 +66,7 @@ def __init__(
metric_name: str | None = None,
fit_interactions: bool = True,
most_important: bool = True,
use_oak_model: bool = False,
seed: int = 0,
torch_device: torch.device | None = None,
) -> None:
Expand All @@ -74,6 +78,8 @@ def __init__(
most_important: Whether to sort by most or least important features in the
bar subplot. Also controls whether the six most or least important
features are plotted in the surface subplots.
use_oak_model: Whether to use an OAK model for the analysis. If False, use
Adapter from the current GenerationNode.
seed: The seed with which to fit the model. Defaults to 0. Used
to ensure that the model fit is identical across the generation of
various plots.
Expand All @@ -83,6 +89,7 @@ def __init__(
self.metric_name = metric_name
self.fit_interactions = fit_interactions
self.most_important = most_important
self.use_oak_model = use_oak_model
self.seed = seed
self.torch_device = torch_device

Expand All @@ -103,26 +110,55 @@ def compute(
if experiment is None:
raise UserInputError("InteractionPlot requires an Experiment")

if generation_strategy is None and not self.use_oak_model:
raise UserInputError(
"InteractionPlot requires a GenerationStrategy when use_oak_model is "
"False"
)

metric_name = self.metric_name or select_metric(experiment=experiment)

# Fix the seed to ensure that the model is fit identically across different
# analyses of the same experiment.
with torch.random.fork_rng():
torch.torch.manual_seed(self.seed)

# Fit the OAK model.
oak_model = self._get_oak_model(
experiment=experiment, metric_name=metric_name
)
if self.use_oak_model:
adapter = self._get_oak_model(
experiment=experiment, metric_name=metric_name
)
else:
gs = none_throws(generation_strategy)
if gs.model is None:
gs._fit_current_model(None)

# Calculate first- or second-order Sobol indices.
sens = ax_parameter_sens(
model_bridge=oak_model,
metrics=[metric_name],
order="second" if self.fit_interactions else "first",
signed=not self.fit_interactions,
)[metric_name]
adapter = assert_is_instance(gs.model, TorchAdapter)

try:
# Calculate first- or second-order Sobol indices.
sens = ax_parameter_sens(
model_bridge=adapter,
metrics=[metric_name],
order="second" if self.fit_interactions else "first",
signed=not self.fit_interactions,
)[metric_name]
except Exception as e:
logger.exception(
f"Failed to compute sensitivity analysis with {e}. Falling back "
"on the surrogate model's feature importances."
)

sens = {
metric_name: adapter.feature_importances(metric_name)
for metric_name in adapter.metric_names
}
# Filter out an parameters that have been added to the search space via one-hot
# encoding -- these make the sensitivity analysis less interpretable and break
# the surface plots.
# TODO: Do something more principled here.
sens = {k: v for k, v in sens.items() if OH_PARAM_INFIX not in k}

# Create a DataFrame with the sensitivity analysis.
sensitivity_df = pd.DataFrame(
[*sens.items()], columns=["feature", "sensitivity"]
).sort_values(by="sensitivity", key=abs, ascending=self.most_important)
Expand All @@ -138,13 +174,16 @@ def compute(
by="sensitivity", ascending=self.most_important, inplace=True
)

plotly_blue = px.colors.qualitative.Plotly[0]
plotly_orange = px.colors.qualitative.Plotly[4]

sensitivity_fig = px.bar(
plotting_df,
x="sensitivity",
y="feature",
color="direction",
# Increase gets blue, decrease gets orange.
color_discrete_sequence=["orange", "blue"],
color_discrete_sequence=[plotly_blue, plotly_orange],
orientation="h",
)

Expand All @@ -158,7 +197,7 @@ def compute(
surface_figs.append(
_prepare_surface_plot(
experiment=experiment,
model=oak_model,
model=adapter,
feature_name=feature_name,
metric_name=metric_name,
)
Expand Down Expand Up @@ -245,16 +284,18 @@ def compute(
width=1000,
)

subtitle_substring = (
"one- or two-dimensional" if self.fit_interactions else "one-dimensional"
)
subtitle_substring = ", or pairs of parameters" if self.fit_interactions else ""

return self._create_plotly_analysis_card(
title=f"Interaction Analysis for {metric_name}",
subtitle=(
f"Understand an Experiment's data as {subtitle_substring} additive "
"components with sparsity. Important components are visualized through "
"slice or contour plots"
f"Understand how changes to your parameters affect {metric_name}. "
f"Parameters{subtitle_substring} which rank higher here explain more "
f"of the observed variation in {metric_name}. The direction of the "
"effect is indicated by the color of the bar plot. Additionally, the "
"six most important parameters are visualized through surface plots "
f"which show the predicted outcomes for {metric_name} as a function "
"of the plotted parameters with the other parameters held fixed."
),
level=AnalysisCardLevel.MID,
df=sensitivity_df,
Expand Down Expand Up @@ -333,7 +374,7 @@ def _prepare_surface_plot(
log_y=is_axis_log_scale(
parameter=experiment.search_space.parameters[y_parameter_name]
),
display_sampled=True,
display_sampled=df["sampled"].sum() <= DISPLAY_SAMPLED_THRESHOLD,
)

# If the feature is a first-order component, plot a slice plot.
Expand All @@ -351,5 +392,5 @@ def _prepare_surface_plot(
log_x=is_axis_log_scale(
parameter=experiment.search_space.parameters[feature_name]
),
display_sampled=True,
display_sampled=df["sampled"].sum() <= DISPLAY_SAMPLED_THRESHOLD,
)

0 comments on commit 53a1945

Please sign in to comment.