Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 140 additions & 56 deletions ax/analysis/plotly/arm_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ax.analysis.plotly.color_constants import BOTORCH_COLOR_SCALE
from ax.analysis.plotly.plotly_analysis import create_plotly_analysis_card
from ax.analysis.plotly.utils import (
generator_run_key_to_color,
get_arm_tooltip,
get_trial_statuses_with_fallback,
get_trial_trace_name,
Expand Down Expand Up @@ -42,6 +43,7 @@
from ax.core.arm import Arm
from ax.core.data import sort_by_trial_index_and_arm_name
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.trial_status import TrialStatus
from ax.generation_strategy.generation_strategy import GenerationStrategy
from plotly import graph_objects as go
Expand Down Expand Up @@ -80,6 +82,7 @@ def __init__(
trial_index: int | None = None,
trial_statuses: Sequence[TrialStatus] | None = None,
additional_arms: Sequence[Arm] | None = None,
generator_runs: Mapping[str, GeneratorRun] | None = None,
label: str | None = None,
) -> None:
"""
Expand All @@ -99,6 +102,10 @@ def __init__(
additional_arms: If present, include these arms in the plot in addition to
the arms in the experiment. These arms will be marked as belonging to a
trial with index -1.
generator_runs: If present, a mapping from name to GeneratorRun. Each
GeneratorRun's arms will be plotted as a separate group with distinct
colors and legend entries. Unnamed arms will be labeled as
``{key}_0``, ``{key}_1``, etc.
label: A label to use in the plot in place of the metric name.
"""

Expand All @@ -112,6 +119,7 @@ def __init__(
)
)
self.additional_arms = additional_arms
self.generator_runs = generator_runs
self.label = label

@override
Expand Down Expand Up @@ -190,6 +198,7 @@ def compute(
trial_index=self.trial_index,
trial_statuses=self.trial_statuses,
additional_arms=self.additional_arms,
generator_runs=self.generator_runs,
relativize=self.relativize,
)

Expand Down Expand Up @@ -259,6 +268,7 @@ def compute_arm_effects_adhoc(
trial_index: int | None = None,
trial_statuses: Sequence[TrialStatus] | None = None,
additional_arms: Sequence[Arm] | None = None,
generator_runs: Mapping[str, GeneratorRun] | None = None,
labels: Mapping[str, str] | None = None,
) -> AnalysisCardGroup:
"""
Expand Down Expand Up @@ -303,6 +313,7 @@ def compute_arm_effects_adhoc(
trial_index=trial_index,
trial_statuses=trial_statuses,
additional_arms=additional_arms,
generator_runs=generator_runs,
label=labels.get(metric_name) if labels is not None else None,
).compute_or_error_card(
experiment=experiment,
Expand All @@ -318,6 +329,58 @@ def compute_arm_effects_adhoc(
)


def _build_scatter(
    trial_df: pd.DataFrame,
    metric_name: str,
    is_relative: bool,
    status_quo_arm_name: str | None,
    color: str,
    ci_color: str,
    trace_name: str,
    showlegend: bool,
    legendgroup: str | None,
) -> tuple[go.Scatter, list[str], list[str]] | None:
    """Create one marker-only scatter trace for a single group of arms.

    Args:
        trial_df: Rows for the group being plotted. Must contain the columns
            ``{metric_name}_mean``, ``{metric_name}_sem``, ``arm_name`` and
            ``x_key_order``.
        metric_name: Metric whose mean (and SEM, if any) is plotted.
        is_relative: Whether the plotted values are relativized.
        status_quo_arm_name: Name of the status quo arm, if any. When
            ``is_relative`` is true, rows for this arm are left out.
        color: Marker color.
        ci_color: Color of the error bars.
        trace_name: Name given to the trace.
        showlegend: Passed through to the trace's ``showlegend``.
        legendgroup: Passed through to the trace's ``legendgroup``.

    Returns:
        ``(scatter, x_key_order values, arm names)`` for the rows that are
        plotted, or ``None`` when no row is left to plot.
    """
    mean_col = f"{metric_name}_mean"
    sem_col = f"{metric_name}_sem"

    # Only rows with a non-missing mean can be drawn.
    plotted = trial_df[trial_df[mean_col].notna()]
    if plotted.empty:
        return None

    # Drop the status quo arm from relativized plots; bail out if that
    # leaves nothing behind.
    if is_relative and status_quo_arm_name is not None:
        plotted = plotted[plotted["arm_name"] != status_quo_arm_name]
        if plotted.empty:
            return None

    # Whether error bars are drawn is decided on the *unfiltered* group:
    # a single non-missing SEM anywhere in ``trial_df`` enables them.
    group_has_sem = trial_df[sem_col].notna().any()
    error_bars = (
        {
            "type": "data",
            "array": Z_SCORE_95_CI * plotted[sem_col],
            "color": ci_color,
        }
        if group_has_sem
        else None
    )

    def _tooltip_for(row: pd.Series) -> str:
        return get_arm_tooltip(row=row, metric_names=[metric_name])

    hover_text = plotted.apply(_tooltip_for, axis=1)
    x_values = plotted["x_key_order"]

    trace = go.Scatter(
        x=x_values,
        y=plotted[mean_col],
        error_y=error_bars,
        mode="markers",
        marker={"color": color},
        name=trace_name,
        showlegend=showlegend,
        hoverinfo="text",
        text=hover_text,
        legendgroup=legendgroup,
    )
    return trace, x_values.to_list(), plotted["arm_name"].to_list()


def _prepare_figure(
df: pd.DataFrame,
metric_name: str,
Expand Down Expand Up @@ -354,76 +417,94 @@ def _prepare_figure(
num_non_candidate_trials = 0
candidate_trial_marker = None

# --- Trial loop ---
for trial_index in trial_indices:
trial_df = df[df["trial_index"] == trial_index]
xy_df = trial_df[~trial_df[f"{metric_name}_mean"].isna()]
# Skip trials with no valid data points as they will not end up in the plot
if xy_df.empty:
trial_df = df[
(df["trial_index"] == trial_index) & (df["generator_run_key"].isna())
]
if trial_df.empty:
continue
if is_relative and status_quo_arm_name is not None:
# Exclude status quo arms from relativized plots, since arms are relative
# with respect to the status quo.
xy_df = xy_df[xy_df["arm_name"] != status_quo_arm_name]

arm_order = arm_order + xy_df["x_key_order"].to_list()
arm_label = arm_label + xy_df["arm_name"].to_list()
if not trial_df[f"{metric_name}_sem"].isna().all():
error_y = {
"type": "data",
"array": Z_SCORE_95_CI * xy_df[f"{metric_name}_sem"],
"color": trial_index_to_color(
trial_df=trial_df,
trials_list=trials_list,
trial_index=trial_index,
transparent=True,
),
}
else:
error_y = None

marker = {
"color": trial_index_to_color(
trial_df=trial_df,
trials_list=trials_list,
trial_index=trial_index,
transparent=False,
),
}

if trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name:
color = trial_index_to_color(
trial_df=trial_df,
trials_list=trials_list,
trial_index=trial_index,
transparent=False,
)
ci_color = trial_index_to_color(
trial_df=trial_df,
trials_list=trials_list,
trial_index=trial_index,
transparent=True,
)
is_candidate = (
not trial_df.empty
and trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name
)
result = _build_scatter(
trial_df=trial_df,
metric_name=metric_name,
is_relative=is_relative,
status_quo_arm_name=status_quo_arm_name,
color=color,
ci_color=ci_color,
trace_name=get_trial_trace_name(trial_index=trial_index),
showlegend=False, # Will be set after determining use_colorscale
legendgroup="candidate_trials" if is_candidate else None,
)
if result is None:
continue
scatter, order, labels = result
scatters.append(scatter)
scatter_trial_indices.append(trial_index)
arm_order += order
arm_label += labels
if is_candidate:
num_candidate_trials += 1
candidate_trial_marker = marker
candidate_trial_marker = {"color": color}
else:
num_non_candidate_trials += 1

text = xy_df.apply(
lambda row: get_arm_tooltip(row=row, metric_names=[metric_name]), axis=1
# --- Generator run loop ---
unique_gr_keys = df["generator_run_key"].dropna().unique().tolist()
generator_run_scatters: list[go.Scatter] = []
for gr_key in unique_gr_keys:
gr_df = df[df["generator_run_key"] == gr_key]
color = generator_run_key_to_color(
generator_run_key=gr_key,
all_generator_run_keys=unique_gr_keys,
transparent=False,
)

scatters.append(
go.Scatter(
x=xy_df["x_key_order"],
y=xy_df[f"{metric_name}_mean"],
error_y=error_y,
mode="markers",
marker=marker,
name=get_trial_trace_name(trial_index=trial_index),
showlegend=False, # Will be set after determining use_colorscale
hoverinfo="text",
text=text,
legendgroup="candidate_trials"
if trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name
else None,
)
ci_color = generator_run_key_to_color(
generator_run_key=gr_key,
all_generator_run_keys=unique_gr_keys,
transparent=True,
)
scatter_trial_indices.append(trial_index)
result = _build_scatter(
trial_df=gr_df,
metric_name=metric_name,
is_relative=is_relative,
status_quo_arm_name=status_quo_arm_name,
color=color,
ci_color=ci_color,
trace_name=gr_key,
showlegend=True,
legendgroup=None,
)
if result is None:
continue
scatter, order, labels = result
generator_run_scatters.append(scatter)
arm_order += order
arm_label += labels

# Determine use_colorscale based on actual included trials
use_colorscale = num_non_candidate_trials > 10

# Update markers and legend settings based on use_colorscale
for scatter, trial_index in zip(scatters, scatter_trial_indices):
trial_df = df[df["trial_index"] == trial_index]
trial_df = df[
(df["trial_index"] == trial_index) & (df["generator_run_key"].isna())
]

if use_colorscale:
# Add colorscale settings to marker
Expand All @@ -449,6 +530,9 @@ def _prepare_figure(
trial_df["trial_status"].iloc[0] != TrialStatus.CANDIDATE.name
)

# Append generator run scatters (not subject to colorscale)
scatters.extend(generator_run_scatters)

# get the max length of x-ticker (arm name) to set the xaxis label and
# legend position
# This assumes the x-tickers are rotated 90 degrees (vertical) so legend
Expand Down
1 change: 1 addition & 0 deletions ax/analysis/plotly/color_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@
# Color used for decreases: the entry at index 2 of METRIC_CONTINUOUS_COLOR_SCALE.
COLOR_FOR_DECREASES: str = METRIC_CONTINUOUS_COLOR_SCALE[2] # brown

# Discrete palette for arms, taken from `px.colors.qualitative.Alphabet`.
# NOTE(review): `px` is presumably `plotly.express`; the import is not visible here.
DISCRETE_ARM_SCALE = px.colors.qualitative.Alphabet
# Discrete palette taken from `px.colors.qualitative.Plotly`.
# NOTE(review): judging by the name, intended for coloring generator-run
# groups; confirm against the consumer of this constant.
GENERATOR_RUN_COLOR_SCALE: list[str] = px.colors.qualitative.Plotly
48 changes: 48 additions & 0 deletions ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ax.adapter.base import Adapter
from ax.adapter.cross_validation import cross_validate, CVResult
from ax.analysis.analysis import Analysis
from ax.analysis.healthcheck.predictable_metrics import DEFAULT_MODEL_FIT_THRESHOLD
from ax.analysis.plotly.color_constants import AX_BLUE
from ax.analysis.plotly.plotly_analysis import create_plotly_analysis_card
from ax.analysis.plotly.utils import get_scatter_point_color, Z_SCORE_95_CI
Expand Down Expand Up @@ -106,6 +107,7 @@ def __init__(
self.untransform = untransform
self.trial_index = trial_index
self.labels: dict[str, str] = {**labels} if labels is not None else {}
self._r2s: dict[str, float] = {}

@override
def validate_applicable_state(
Expand Down Expand Up @@ -144,6 +146,7 @@ def compute(
relevant_adapter._experiment.signature_to_metric[signature].name
for signature in relevant_adapter._metric_signatures
]
self._r2s = {}
for metric_name in self.metric_names or relevant_adapter_metric_names:
df = _prepare_data(
metric_name=metric_name, cv_results=cv_results, adapter=relevant_adapter
Expand All @@ -162,6 +165,7 @@ def compute(
y_obs=df["observed"].to_numpy(),
y_pred=df["predicted"].to_numpy(),
)
self._r2s[metric_title] = r_squared

# Define the cross-validation description based on the number of folds
cv_description = (
Expand Down Expand Up @@ -202,6 +206,50 @@ def compute(

cards.append(card)

# Create a summary table of R2 values for all metrics
if self._r2s:
threshold = DEFAULT_MODEL_FIT_THRESHOLD
metric_names_list = list(self._r2s.keys())
r2_values = [f"{v:.2f}" for v in self._r2s.values()]
fill_colors = [
"rgba(0, 200, 0, 0.15)" if r2 >= threshold else "white"
for r2 in self._r2s.values()
]
r2_fig = go.Figure(
data=[
go.Table(
columnwidth=[4, 1],
header={
"values": ["Metric", "R\u00b2"],
"align": "left",
},
cells={
"values": [metric_names_list, r2_values],
"align": "left",
"fill_color": [fill_colors, fill_colors],
},
)
]
)
r2_card = create_plotly_analysis_card(
name=self.__class__.__name__,
title="Summary of model fits",
subtitle=(
"R\u00b2 (coefficient of determination) measures how well"
" the model predicts each metric. Higher values indicate"
" better model fit. Metrics with R\u00b2 >="
f" {threshold} are highlighted in green."
),
df=pd.DataFrame(
{
"Metric": metric_names_list,
"R\u00b2": list(self._r2s.values()),
}
),
fig=r2_fig,
)
cards.append(r2_card)

return self._create_analysis_card_group(
title=CV_CARDGROUP_TITLE,
subtitle=CV_CARDGROUP_SUBTITLE,
Expand Down
Loading