Skip to content

Commit cb88fc5

Browse files
blethamfacebook-github-bot
authored andcommitted
add metric prediction summary to cross validation analysis (facebook#4995)
Summary: Pull Request resolved: facebook#4995 Reviewed By: mpolson64 Differential Revision: D94553707
1 parent 0d3535e commit cb88fc5

2 files changed

Lines changed: 67 additions & 4 deletions

File tree

ax/analysis/plotly/cross_validation.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ax.adapter.base import Adapter
1414
from ax.adapter.cross_validation import cross_validate, CVResult
1515
from ax.analysis.analysis import Analysis
16+
from ax.analysis.healthcheck.predictable_metrics import DEFAULT_MODEL_FIT_THRESHOLD
1617
from ax.analysis.plotly.color_constants import AX_BLUE
1718
from ax.analysis.plotly.plotly_analysis import create_plotly_analysis_card
1819
from ax.analysis.plotly.utils import get_scatter_point_color, Z_SCORE_95_CI
@@ -106,6 +107,7 @@ def __init__(
106107
self.untransform = untransform
107108
self.trial_index = trial_index
108109
self.labels: dict[str, str] = {**labels} if labels is not None else {}
110+
self._r2s: dict[str, float] = {}
109111

110112
@override
111113
def validate_applicable_state(
@@ -144,6 +146,7 @@ def compute(
144146
relevant_adapter._experiment.signature_to_metric[signature].name
145147
for signature in relevant_adapter._metric_signatures
146148
]
149+
self._r2s = {}
147150
for metric_name in self.metric_names or relevant_adapter_metric_names:
148151
df = _prepare_data(
149152
metric_name=metric_name, cv_results=cv_results, adapter=relevant_adapter
@@ -162,6 +165,7 @@ def compute(
162165
y_obs=df["observed"].to_numpy(),
163166
y_pred=df["predicted"].to_numpy(),
164167
)
168+
self._r2s[metric_title] = r_squared
165169

166170
# Define the cross-validation description based on the number of folds
167171
cv_description = (
@@ -202,6 +206,50 @@ def compute(
202206

203207
cards.append(card)
204208

209+
# Create a summary table of R2 values for all metrics
210+
if self._r2s:
211+
threshold = DEFAULT_MODEL_FIT_THRESHOLD
212+
metric_names_list = list(self._r2s.keys())
213+
r2_values = [f"{v:.2f}" for v in self._r2s.values()]
214+
fill_colors = [
215+
"rgba(0, 200, 0, 0.15)" if r2 >= threshold else "white"
216+
for r2 in self._r2s.values()
217+
]
218+
r2_fig = go.Figure(
219+
data=[
220+
go.Table(
221+
columnwidth=[4, 1],
222+
header={
223+
"values": ["Metric", "R\u00b2"],
224+
"align": "left",
225+
},
226+
cells={
227+
"values": [metric_names_list, r2_values],
228+
"align": "left",
229+
"fill_color": [fill_colors, fill_colors],
230+
},
231+
)
232+
]
233+
)
234+
r2_card = create_plotly_analysis_card(
235+
name=self.__class__.__name__,
236+
title="Summary of model fits",
237+
subtitle=(
238+
"R\u00b2 (coefficient of determination) measures how well"
239+
" the model predicts each metric. Higher values indicate"
240+
" better model fit. Metrics with R\u00b2 >="
241+
f" {threshold} are highlighted in green."
242+
),
243+
df=pd.DataFrame(
244+
{
245+
"Metric": metric_names_list,
246+
"R\u00b2": list(self._r2s.values()),
247+
}
248+
),
249+
fig=r2_fig,
250+
)
251+
cards.append(r2_card)
252+
205253
return self._create_analysis_card_group(
206254
title=CV_CARDGROUP_TITLE,
207255
subtitle=CV_CARDGROUP_SUBTITLE,

ax/analysis/plotly/tests/test_cross_validation.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,12 @@ def test_compute(self, mock_r2: mock.Mock) -> None:
6565
):
6666
analysis.compute()
6767

68-
(card,) = analysis.compute(
68+
cards = analysis.compute(
6969
generation_strategy=self.client.generation_strategy
7070
).flatten()
71+
# Should have the CV plot card and the R2 summary card
72+
self.assertEqual(len(cards), 2)
73+
card = cards[0]
7174
self.assertEqual(
7275
card.name,
7376
"CrossValidationPlot",
@@ -106,6 +109,15 @@ def test_compute(self, mock_r2: mock.Mock) -> None:
106109
)
107110
self.assertIsNotNone(card.blob)
108111

112+
# Assert that _r2s is populated after compute
113+
self.assertIn("bar", analysis._r2s)
114+
self.assertAlmostEqual(analysis._r2s["bar"], 0.85)
115+
116+
# Assert the R2 summary card
117+
r2_card = cards[1]
118+
self.assertEqual(r2_card.name, "CrossValidationPlot")
119+
self.assertEqual(r2_card.title, "Summary of model fits")
120+
109121
# Assert that all arms are in the cross validation df
110122
# because trial index is not specified
111123
for t in self.client.experiment.trials.values():
@@ -121,9 +133,10 @@ def test_compute(self, mock_r2: mock.Mock) -> None:
121133

122134
def test_it_can_specify_trial_index_correctly(self) -> None:
123135
analysis = CrossValidationPlot(metric_names=["bar"], trial_index=9)
124-
(card,) = analysis.compute(
136+
cards = analysis.compute(
125137
generation_strategy=self.client.generation_strategy
126138
).flatten()
139+
card = cards[0]
127140
for t in self.client.experiment.trials.values():
128141
# Skip the last trial because the model was used to generate it
129142
# and therefore hasn't observed it
@@ -159,15 +172,17 @@ def test_compute_adhoc(self, mock_r2: mock.Mock) -> None:
159172
cards = compute_cross_validation_adhoc(
160173
adapter=adapter, labels=metric_mapping
161174
).flatten()
162-
self.assertEqual(len(cards), 2)
175+
self.assertEqual(len(cards), 3)
163176
titles = {
164177
"Cross Validation for spunky (R\u00b2 = 0.85)",
165178
"Cross Validation for foo2 (R\u00b2 = 0.85)",
166179
}
167-
for card in cards:
180+
for card in cards[:2]:
168181
self.assertEqual(card.name, "CrossValidationPlot")
169182
self.assertIn(card.title, titles)
170183
titles.remove(card.title)
184+
# The last card is the R2 summary
185+
self.assertEqual(cards[2].title, "Summary of model fits")
171186

172187
@TestCase.ax_long_test(
173188
reason=(

0 commit comments

Comments
 (0)