Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ax.adapter.base import Adapter
from ax.adapter.cross_validation import cross_validate, CVResult
from ax.analysis.analysis import Analysis
from ax.analysis.healthcheck.predictable_metrics import DEFAULT_MODEL_FIT_THRESHOLD
from ax.analysis.plotly.color_constants import AX_BLUE
from ax.analysis.plotly.plotly_analysis import create_plotly_analysis_card
from ax.analysis.plotly.utils import get_scatter_point_color, Z_SCORE_95_CI
Expand Down Expand Up @@ -106,6 +107,7 @@ def __init__(
self.untransform = untransform
self.trial_index = trial_index
self.labels: dict[str, str] = {**labels} if labels is not None else {}
self._r2s: dict[str, float] = {}

@override
def validate_applicable_state(
Expand Down Expand Up @@ -144,6 +146,7 @@ def compute(
relevant_adapter._experiment.signature_to_metric[signature].name
for signature in relevant_adapter._metric_signatures
]
self._r2s = {}
for metric_name in self.metric_names or relevant_adapter_metric_names:
df = _prepare_data(
metric_name=metric_name, cv_results=cv_results, adapter=relevant_adapter
Expand All @@ -162,6 +165,7 @@ def compute(
y_obs=df["observed"].to_numpy(),
y_pred=df["predicted"].to_numpy(),
)
self._r2s[metric_title] = r_squared

# Define the cross-validation description based on the number of folds
cv_description = (
Expand Down Expand Up @@ -202,6 +206,50 @@ def compute(

cards.append(card)

# Create a summary table of R2 values for all metrics
if self._r2s:
threshold = DEFAULT_MODEL_FIT_THRESHOLD
metric_names_list = list(self._r2s.keys())
r2_values = [f"{v:.2f}" for v in self._r2s.values()]
fill_colors = [
"rgba(0, 200, 0, 0.15)" if r2 >= threshold else "white"
for r2 in self._r2s.values()
]
r2_fig = go.Figure(
data=[
go.Table(
columnwidth=[4, 1],
header={
"values": ["Metric", "R\u00b2"],
"align": "left",
},
cells={
"values": [metric_names_list, r2_values],
"align": "left",
"fill_color": [fill_colors, fill_colors],
},
)
]
)
r2_card = create_plotly_analysis_card(
name=self.__class__.__name__,
title="Summary of model fits",
subtitle=(
"R\u00b2 (coefficient of determination) measures how well"
" the model predicts each metric. Higher values indicate"
" better model fit. Metrics with R\u00b2 >="
f" {threshold} are highlighted in green."
),
df=pd.DataFrame(
{
"Metric": metric_names_list,
"R\u00b2": list(self._r2s.values()),
}
),
fig=r2_fig,
)
cards.append(r2_card)

return self._create_analysis_card_group(
title=CV_CARDGROUP_TITLE,
subtitle=CV_CARDGROUP_SUBTITLE,
Expand Down
23 changes: 19 additions & 4 deletions ax/analysis/plotly/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,12 @@ def test_compute(self, mock_r2: mock.Mock) -> None:
):
analysis.compute()

(card,) = analysis.compute(
cards = analysis.compute(
generation_strategy=self.client.generation_strategy
).flatten()
# Should have the CV plot card and the R2 summary card
self.assertEqual(len(cards), 2)
card = cards[0]
self.assertEqual(
card.name,
"CrossValidationPlot",
Expand Down Expand Up @@ -106,6 +109,15 @@ def test_compute(self, mock_r2: mock.Mock) -> None:
)
self.assertIsNotNone(card.blob)

# Assert that _r2s is populated after compute
self.assertIn("bar", analysis._r2s)
self.assertAlmostEqual(analysis._r2s["bar"], 0.85)

# Assert the R2 summary card
r2_card = cards[1]
self.assertEqual(r2_card.name, "CrossValidationPlot")
self.assertEqual(r2_card.title, "Summary of model fits")

# Assert that all arms are in the cross validation df
# because trial index is not specified
for t in self.client.experiment.trials.values():
Expand All @@ -121,9 +133,10 @@ def test_compute(self, mock_r2: mock.Mock) -> None:

def test_it_can_specify_trial_index_correctly(self) -> None:
analysis = CrossValidationPlot(metric_names=["bar"], trial_index=9)
(card,) = analysis.compute(
cards = analysis.compute(
generation_strategy=self.client.generation_strategy
).flatten()
card = cards[0]
for t in self.client.experiment.trials.values():
# Skip the last trial because the model was used to generate it
# and therefore hasn't observed it
Expand Down Expand Up @@ -159,15 +172,17 @@ def test_compute_adhoc(self, mock_r2: mock.Mock) -> None:
cards = compute_cross_validation_adhoc(
adapter=adapter, labels=metric_mapping
).flatten()
self.assertEqual(len(cards), 2)
self.assertEqual(len(cards), 3)
titles = {
"Cross Validation for spunky (R\u00b2 = 0.85)",
"Cross Validation for foo2 (R\u00b2 = 0.85)",
}
for card in cards:
for card in cards[:2]:
self.assertEqual(card.name, "CrossValidationPlot")
self.assertIn(card.title, titles)
titles.remove(card.title)
# The last card is the R2 summary
self.assertEqual(cards[2].title, "Summary of model fits")

@TestCase.ax_long_test(
reason=(
Expand Down