Skip to content

Commit e32c80c

Browse files
bletham authored and meta-codesync[bot] committed
add metric prediction summary to cross validation analysis
Differential Revision: D94553707
1 parent b254ab8 commit e32c80c

File tree

5 files changed

+112
-12
lines changed

5 files changed

+112
-12
lines changed

ax/analysis/diagnostics.py

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from ax.analysis.analysis import Analysis
1212
from ax.analysis.graphviz.generation_strategy_graph import GenerationStrategyGraph
1313
from ax.analysis.plotly.cross_validation import CrossValidationPlot
14+
from ax.analysis.plotly.metric_r2 import create_metric_r2_analysis_card
1415
from ax.analysis.utils import validate_experiment
1516
from ax.core.analysis_card import AnalysisCardGroup
1617
from ax.core.experiment import Experiment
@@ -76,17 +77,18 @@ def compute(
7677
generation_strategy_name=generation_strategy.name
7778
)
7879

79-
cross_validation_plots = (
80-
[
81-
CrossValidationPlot(metric_names=metric_names).compute_or_error_card(
82-
experiment=experiment,
83-
generation_strategy=generation_strategy,
84-
adapter=adapter,
85-
)
86-
]
87-
if not is_bandit
88-
else []
89-
)
80+
cross_validation_plots = []
81+
metric_r2_card = []
82+
if not is_bandit:
83+
cv_analysis = CrossValidationPlot(metric_names=metric_names)
84+
cv_card = cv_analysis.compute_or_error_card(
85+
experiment=experiment,
86+
generation_strategy=generation_strategy,
87+
adapter=adapter,
88+
)
89+
cross_validation_plots = [cv_card]
90+
if cv_analysis._r2s:
91+
metric_r2_card = [create_metric_r2_analysis_card(r2s=cv_analysis._r2s)]
9092

9193
generation_strategy_graph = (
9294
[
@@ -103,5 +105,9 @@ def compute(
103105
return self._create_analysis_card_group(
104106
title=DIAGNOSTICS_CARDGROUP_TITLE,
105107
subtitle=DIAGNOSTICS_CARDGROUP_SUBTITLE,
106-
children=[*cross_validation_plots, *generation_strategy_graph],
108+
children=[
109+
*cross_validation_plots,
110+
*metric_r2_card,
111+
*generation_strategy_graph,
112+
],
107113
)

ax/analysis/plotly/cross_validation.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,7 @@ def __init__(
106106
self.untransform = untransform
107107
self.trial_index = trial_index
108108
self.labels: dict[str, str] = {**labels} if labels is not None else {}
109+
self._r2s: dict[str, float] = {}
109110

110111
@override
111112
def validate_applicable_state(
@@ -144,6 +145,7 @@ def compute(
144145
relevant_adapter._experiment.signature_to_metric[signature].name
145146
for signature in relevant_adapter._metric_signatures
146147
]
148+
self._r2s = {}
147149
for metric_name in self.metric_names or relevant_adapter_metric_names:
148150
df = _prepare_data(
149151
metric_name=metric_name, cv_results=cv_results, adapter=relevant_adapter
@@ -162,6 +164,7 @@ def compute(
162164
y_obs=df["observed"].to_numpy(),
163165
y_pred=df["predicted"].to_numpy(),
164166
)
167+
self._r2s[metric_title] = r_squared
165168

166169
# Define the cross-validation description based on the number of folds
167170
cv_description = (

ax/analysis/plotly/metric_r2.py

Lines changed: 67 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
import pandas as pd
9+
from ax.analysis.healthcheck.predictable_metrics import DEFAULT_MODEL_FIT_THRESHOLD
10+
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard
11+
from plotly import graph_objects as go, io as pio
12+
13+
14+
class MetricR2AnalysisCard(PlotlyAnalysisCard):
    """Plotly-backed analysis card presenting per-metric R² values as a table.

    Rows for metrics whose R² meets the model fit threshold are highlighted
    in green.
    """
17+
18+
19+
def create_metric_r2_analysis_card(
    r2s: dict[str, float],
    threshold: float = DEFAULT_MODEL_FIT_THRESHOLD,
) -> MetricR2AnalysisCard:
    """Build a table card summarizing the R² value of each metric.

    Args:
        r2s: Mapping from metric name to that metric's R² value. Rows appear
            in the mapping's iteration order.
        threshold: Minimum R² for a row to be highlighted in green. Defaults
            to DEFAULT_MODEL_FIT_THRESHOLD.

    Returns:
        A MetricR2AnalysisCard whose blob is the JSON-serialized Plotly table
        and whose dataframe holds the metric names with their raw R² values.
    """
    names = list(r2s)
    scores = list(r2s.values())

    # One background color per row. The comparison uses the raw score, not the
    # two-decimal string shown in the table.
    highlight = "rgba(0, 200, 0, 0.15)"
    row_colors = [highlight if score >= threshold else "white" for score in scores]

    table = go.Table(
        columnwidth=[4, 1],
        header={"values": ["Metric", "R\u00b2"], "align": "left"},
        cells={
            "values": [names, [f"{score:.2f}" for score in scores]],
            "align": "left",
            # The same per-row colors are applied to both columns.
            "fill_color": [row_colors, row_colors],
        },
    )
    figure = go.Figure(data=[table])

    subtitle = (
        "R\u00b2 (coefficient of determination) measures how well the model"
        " predicts each metric. Higher values indicate better model fit."
        f" Metrics with R\u00b2 >= {threshold} are highlighted in green."
    )

    return MetricR2AnalysisCard(
        name="MetricR2Summary",
        title="Summary of model fits",
        subtitle=subtitle,
        df=pd.DataFrame({"Metric": names, "R\u00b2": scores}),
        blob=pio.to_json(figure),
    )

ax/analysis/plotly/tests/test_cross_validation.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,10 @@ def test_compute(self, mock_r2: mock.Mock) -> None:
106106
)
107107
self.assertIsNotNone(card.blob)
108108

109+
# Assert that _r2s is populated after compute
110+
self.assertIn("bar", analysis._r2s)
111+
self.assertAlmostEqual(analysis._r2s["bar"], 0.85)
112+
109113
# Assert that all arms are in the cross validation df
110114
# because trial index is not specified
111115
for t in self.client.experiment.trials.values():

ax/analysis/tests/test_diagnostics.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
DIAGNOSTICS_CARDGROUP_TITLE,
1717
)
1818
from ax.analysis.plotly.cross_validation import CrossValidationPlot
19+
from ax.analysis.plotly.metric_r2 import MetricR2AnalysisCard
1920
from ax.api.client import Client
2021
from ax.api.configs import RangeParameterConfig
2122
from ax.core.analysis_card import ErrorAnalysisCard
@@ -118,6 +119,13 @@ def test_compute(self) -> None:
118119
for card in card_group.flatten():
119120
self.assertNotIsInstance(card, ErrorAnalysisCard)
120121

122+
# Should have a MetricR2AnalysisCard with the expected title
123+
r2_cards = [
124+
c for c in card_group.flatten() if isinstance(c, MetricR2AnalysisCard)
125+
]
126+
self.assertEqual(len(r2_cards), 1)
127+
self.assertEqual(r2_cards[0].title, "Summary of model fits")
128+
121129
# --- Verify metric_names via patching CrossValidationPlot ---
122130
original_cv_init: Callable[..., None] = CrossValidationPlot.__init__
123131

@@ -163,6 +171,12 @@ def capturing_init(self: CrossValidationPlot, **kwargs: object) -> None:
163171
self.assertIn("CrossValidationPlot", child_names_no_gs)
164172
self.assertNotIn("GenerationStrategyGraph", child_names_no_gs)
165173

174+
# MetricR2AnalysisCard not present when CV errors (no adapter available)
175+
r2_cards_no_gs = [
176+
c for c in card_group_no_gs.flatten() if isinstance(c, MetricR2AnalysisCard)
177+
]
178+
self.assertEqual(len(r2_cards_no_gs), 0)
179+
166180
def test_compute_bandit(self) -> None:
167181
experiment = Experiment(
168182
name="bandit_test",
@@ -250,6 +264,12 @@ def test_compute_bandit(self) -> None:
250264
# Bandit experiment should NOT include CrossValidationPlot
251265
self.assertNotIn("CrossValidationPlot", child_names)
252266

267+
# Bandit experiment should NOT include MetricR2AnalysisCard
268+
r2_cards = [
269+
c for c in card_group.flatten() if isinstance(c, MetricR2AnalysisCard)
270+
]
271+
self.assertEqual(len(r2_cards), 0)
272+
253273
# GenerationStrategyGraph should still be included (GS is provided)
254274
self.assertIn("GenerationStrategyGraph", child_names)
255275

0 commit comments

Comments
 (0)