Improve robustness in InteractionAnalysis #3407

Open · wants to merge 4 commits into base: main
83 changes: 63 additions & 20 deletions ax/analysis/plotly/interaction.py
@@ -29,6 +29,7 @@
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Generators
from ax.modelbridge.torch import TorchAdapter
from ax.modelbridge.transforms.one_hot import OH_PARAM_INFIX
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.utils.common.logger import get_logger
from ax.utils.sensitivity.sobol_measures import ax_parameter_sens
@@ -39,10 +40,12 @@
from gpytorch.priors import LogNormalPrior
from plotly import express as px, graph_objects as go
from plotly.subplots import make_subplots
from pyre_extensions import assert_is_instance
from pyre_extensions import assert_is_instance, none_throws

logger: Logger = get_logger(__name__)

DISPLAY_SAMPLED_THRESHOLD: int = 50


class InteractionPlot(PlotlyAnalysis):
"""
@@ -63,6 +66,7 @@ def __init__(
metric_name: str | None = None,
fit_interactions: bool = True,
most_important: bool = True,
use_oak_model: bool = False,
seed: int = 0,
torch_device: torch.device | None = None,
) -> None:
@@ -74,6 +78,8 @@ def __init__(
most_important: Whether to sort by most or least important features in the
bar subplot. Also controls whether the six most or least important
features are plotted in the surface subplots.
use_oak_model: Whether to use an OAK model for the analysis. If False, use
Adapter from the current GenerationNode.
seed: The seed with which to fit the model. Defaults to 0. Used
to ensure that the model fit is identical across the generation of
various plots.
@@ -83,6 +89,7 @@ def __init__(
self.metric_name = metric_name
self.fit_interactions = fit_interactions
self.most_important = most_important
self.use_oak_model = use_oak_model
self.seed = seed
self.torch_device = torch_device

@@ -103,26 +110,55 @@ def compute(
if experiment is None:
raise UserInputError("InteractionPlot requires an Experiment")

if generation_strategy is None and not self.use_oak_model:
raise UserInputError(
"InteractionPlot requires a GenerationStrategy when use_oak_model is "
"False"
)

metric_name = self.metric_name or select_metric(experiment=experiment)

# Fix the seed to ensure that the model is fit identically across different
# analyses of the same experiment.
with torch.random.fork_rng():
torch.torch.manual_seed(self.seed)

# Fit the OAK model.
oak_model = self._get_oak_model(
experiment=experiment, metric_name=metric_name
)
if self.use_oak_model:
adapter = self._get_oak_model(
experiment=experiment, metric_name=metric_name
)
else:
gs = none_throws(generation_strategy)
if gs.model is None:
gs._fit_current_model(None)

# Calculate first- or second-order Sobol indices.
sens = ax_parameter_sens(
model_bridge=oak_model,
metrics=[metric_name],
order="second" if self.fit_interactions else "first",
signed=not self.fit_interactions,
)[metric_name]
adapter = assert_is_instance(gs.model, TorchAdapter)

try:
# Calculate first- or second-order Sobol indices.
sens = ax_parameter_sens(
model_bridge=adapter,
metrics=[metric_name],
order="second" if self.fit_interactions else "first",
signed=not self.fit_interactions,
)[metric_name]
except Exception as e:
logger.exception(
f"Failed to compute sensitivity analysis with {e}. Falling back "
"on the surrogate model's feature importances."
)

sens = {
metric_name: adapter.feature_importances(metric_name)
for metric_name in adapter.metric_names
}
# Filter out any parameters that have been added to the search space via one-hot
# encoding -- these make the sensitivity analysis less interpretable and break
# the surface plots.
# TODO: Do something more principled here.
sens = {k: v for k, v in sens.items() if OH_PARAM_INFIX not in k}

# Create a DataFrame with the sensitivity analysis.
sensitivity_df = pd.DataFrame(
[*sens.items()], columns=["feature", "sensitivity"]
).sort_values(by="sensitivity", key=abs, ascending=self.most_important)
@@ -138,13 +174,16 @@ def compute(
by="sensitivity", ascending=self.most_important, inplace=True
)

plotly_blue = px.colors.qualitative.Plotly[0]
plotly_orange = px.colors.qualitative.Plotly[4]

sensitivity_fig = px.bar(
plotting_df,
x="sensitivity",
y="feature",
color="direction",
# Increase gets blue, decrease gets orange.
color_discrete_sequence=["orange", "blue"],
color_discrete_sequence=[plotly_blue, plotly_orange],
orientation="h",
)

@@ -158,7 +197,7 @@ def compute(
surface_figs.append(
_prepare_surface_plot(
experiment=experiment,
model=oak_model,
model=adapter,
feature_name=feature_name,
metric_name=metric_name,
)
@@ -245,16 +284,18 @@ def compute(
width=1000,
)

subtitle_substring = (
"one- or two-dimensional" if self.fit_interactions else "one-dimensional"
)
subtitle_substring = ", or pairs of parameters" if self.fit_interactions else ""

return self._create_plotly_analysis_card(
title=f"Interaction Analysis for {metric_name}",
subtitle=(
f"Understand an Experiment's data as {subtitle_substring} additive "
"components with sparsity. Important components are visualized through "
"slice or contour plots"
f"Understand how changes to your parameters affect {metric_name}. "
f"Parameters{subtitle_substring} which rank higher here explain more "
f"of the observed variation in {metric_name}. The direction of the "
"effect is indicated by the color of the bar plot. Additionally, the "
"six most important parameters are visualized through surface plots "
f"which show the predicted outcomes for {metric_name} as a function "
"of the plotted parameters with the other parameters held fixed."
),
level=AnalysisCardLevel.MID,
df=sensitivity_df,
@@ -333,6 +374,7 @@ def _prepare_surface_plot(
log_y=is_axis_log_scale(
parameter=experiment.search_space.parameters[y_parameter_name]
),
display_sampled=df["sampled"].sum() <= DISPLAY_SAMPLED_THRESHOLD,
)

# If the feature is a first-order component, plot a slice plot.
@@ -350,4 +392,5 @@
log_x=is_axis_log_scale(
parameter=experiment.search_space.parameters[feature_name]
),
display_sampled=df["sampled"].sum() <= DISPLAY_SAMPLED_THRESHOLD,
)
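For reference, a minimal sketch of how the new `use_oak_model` flag might be exercised; the `experiment`, `generation_strategy`, and metric name below are assumed for illustration and are not part of this diff:

```python
from ax.analysis.plotly.interaction import InteractionPlot

# Default path: reuse the adapter from the GenerationStrategy's current node,
# so a GenerationStrategy must be passed to compute().
analysis = InteractionPlot(metric_name="objective", fit_interactions=True)
card = analysis.compute(
    experiment=experiment, generation_strategy=generation_strategy
)

# New path: fit a dedicated OAK model instead; no GenerationStrategy is required.
oak_analysis = InteractionPlot(metric_name="objective", use_oak_model=True)
oak_card = oak_analysis.compute(experiment=experiment)
```

With the default `use_oak_model=False`, `compute` now raises a `UserInputError` when no `GenerationStrategy` is supplied, and if the Sobol sensitivity computation fails it falls back to the surrogate's feature importances instead of erroring out.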
63 changes: 52 additions & 11 deletions ax/analysis/plotly/surface/contour.py
@@ -36,19 +36,23 @@ class ContourPlot(PlotlyAnalysis):
- PARAMETER_NAME: The value of the x parameter specified
- PARAMETER_NAME: The value of the y parameter specified
- METRIC_NAME: The predicted mean of the metric specified
- sampled: Whether the parameter values were sampled in at least one trial
"""

def __init__(
self,
x_parameter_name: str,
y_parameter_name: str,
metric_name: str | None = None,
display_sampled: bool = True,
) -> None:
"""
Args:
x_parameter_name: The name of the parameter to plot on the x-axis.
y_parameter_name: The name of the parameter to plot on the y-axis.
metric_name: The name of the metric to plot
display_sampled: If True, plot "x"s at x coordinates which have been
sampled in at least one trial.
"""
# TODO: Add a flag to specify whether or not to plot markers at the (x, y)
# coordinates of arms (with hover text). This is fine to exclude for now because
@@ -57,6 +61,7 @@ def __init__(
self.x_parameter_name = x_parameter_name
self.y_parameter_name = y_parameter_name
self.metric_name = metric_name
self._display_sampled = display_sampled

def compute(
self,
@@ -93,6 +98,7 @@ def compute(
log_y=is_axis_log_scale(
parameter=experiment.search_space.parameters[self.y_parameter_name]
),
display_sampled=self._display_sampled,
)

return self._create_plotly_analysis_card(
@@ -116,14 +122,23 @@ def _prepare_data(
y_parameter_name: str,
metric_name: str,
) -> pd.DataFrame:
sampled = [
(arm.parameters[x_parameter_name], arm.parameters[y_parameter_name])
for trial in experiment.trials.values()
for arm in trial.arms
]

# Choose which parameter values to predict points for.
xs = get_parameter_values(
unsampled_xs = get_parameter_values(
parameter=experiment.search_space.parameters[x_parameter_name], density=10
)
ys = get_parameter_values(
unsampled_ys = get_parameter_values(
parameter=experiment.search_space.parameters[y_parameter_name], density=10
)

xs = [*[sample[0] for sample in sampled], *unsampled_xs]
ys = [*[sample[1] for sample in sampled], *unsampled_ys]

# Construct observation features for each parameter value previously chosen by
# fixing all other parameters to their status-quo value or mean.
features = [
@@ -147,15 +162,22 @@

predictions = model.predict(observation_features=features)

return pd.DataFrame.from_records(
[
{
x_parameter_name: features[i].parameters[x_parameter_name],
y_parameter_name: features[i].parameters[y_parameter_name],
f"{metric_name}_mean": predictions[0][metric_name][i],
}
for i in range(len(features))
]
return none_throws(
pd.DataFrame.from_records(
[
{
x_parameter_name: features[i].parameters[x_parameter_name],
y_parameter_name: features[i].parameters[y_parameter_name],
f"{metric_name}_mean": predictions[0][metric_name][i],
"sampled": (
features[i].parameters[x_parameter_name],
features[i].parameters[y_parameter_name],
)
in sampled,
}
for i in range(len(features))
]
).drop_duplicates()
)


@@ -166,6 +188,7 @@ def _prepare_plot(
metric_name: str,
log_x: bool,
log_y: bool,
display_sampled: bool,
) -> go.Figure:
z_grid = df.pivot(
index=y_parameter_name, columns=x_parameter_name, values=f"{metric_name}_mean"
Expand All @@ -185,6 +208,24 @@ def _prepare_plot(
),
)

if display_sampled:
x_sampled = df[df["sampled"]][x_parameter_name].tolist()
y_sampled = df[df["sampled"]][y_parameter_name].tolist()

samples = go.Scatter(
x=x_sampled,
y=y_sampled,
mode="markers",
marker={
"symbol": "x",
"color": "black",
},
name="Sampled",
showlegend=False,
)

fig.add_trace(samples)

# Set the x-axis scale to log if relevant
if log_x:
fig.update_xaxes(
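For reference, the sampled-points overlay added to `_prepare_plot` amounts to a scatter trace of black "x" markers drawn on top of the contour, and it is only requested when at most `DISPLAY_SAMPLED_THRESHOLD` (50) points were sampled. Below is a self-contained sketch with made-up grid data; the parameter and metric names are illustrative, not taken from this diff:

```python
import numpy as np
import pandas as pd
from plotly import graph_objects as go

# Made-up predictions on a 10x10 grid, in the same shape as _prepare_data's output.
xs, ys = np.linspace(0, 1, 10), np.linspace(0, 1, 10)
grid = pd.DataFrame(
    [
        {"x1": x, "x2": y, "objective_mean": (x - 0.3) ** 2 + (y - 0.7) ** 2}
        for x in xs
        for y in ys
    ]
)
# Made-up parameterizations that were actually run in trials.
sampled = pd.DataFrame({"x1": [0.2, 0.6, 0.9], "x2": [0.1, 0.8, 0.4]})

# Contour of the predicted mean, as in _prepare_plot.
z_grid = grid.pivot(index="x2", columns="x1", values="objective_mean")
fig = go.Figure(go.Contour(x=z_grid.columns, y=z_grid.index, z=z_grid.values))

# Overlay "x" markers for sampled points, gated on how many points were sampled.
if len(sampled) <= 50:  # DISPLAY_SAMPLED_THRESHOLD
    fig.add_trace(
        go.Scatter(
            x=sampled["x1"],
            y=sampled["x2"],
            mode="markers",
            marker={"symbol": "x", "color": "black"},
            name="Sampled",
            showlegend=False,
        )
    )
```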