diff --git a/ax/analysis/plotly/cross_validation.py b/ax/analysis/plotly/cross_validation.py index 23abfcabd49..9c63e9423a2 100644 --- a/ax/analysis/plotly/cross_validation.py +++ b/ax/analysis/plotly/cross_validation.py @@ -13,6 +13,7 @@ from ax.adapter.base import Adapter from ax.adapter.cross_validation import cross_validate, CVResult from ax.analysis.analysis import Analysis +from ax.analysis.healthcheck.predictable_metrics import DEFAULT_MODEL_FIT_THRESHOLD from ax.analysis.plotly.color_constants import AX_BLUE from ax.analysis.plotly.plotly_analysis import create_plotly_analysis_card from ax.analysis.plotly.utils import get_scatter_point_color, Z_SCORE_95_CI @@ -106,6 +107,7 @@ def __init__( self.untransform = untransform self.trial_index = trial_index self.labels: dict[str, str] = {**labels} if labels is not None else {} + self._r2s: dict[str, float] = {} @override def validate_applicable_state( @@ -144,6 +146,7 @@ def compute( relevant_adapter._experiment.signature_to_metric[signature].name for signature in relevant_adapter._metric_signatures ] + self._r2s = {} for metric_name in self.metric_names or relevant_adapter_metric_names: df = _prepare_data( metric_name=metric_name, cv_results=cv_results, adapter=relevant_adapter @@ -162,6 +165,7 @@ def compute( y_obs=df["observed"].to_numpy(), y_pred=df["predicted"].to_numpy(), ) + self._r2s[metric_title] = r_squared # Define the cross-validation description based on the number of folds cv_description = ( @@ -202,6 +206,50 @@ def compute( cards.append(card) + # Create a summary table of R2 values for all metrics + if self._r2s: + threshold = DEFAULT_MODEL_FIT_THRESHOLD + metric_names_list = list(self._r2s.keys()) + r2_values = [f"{v:.2f}" for v in self._r2s.values()] + fill_colors = [ + "rgba(0, 200, 0, 0.15)" if r2 >= threshold else "white" + for r2 in self._r2s.values() + ] + r2_fig = go.Figure( + data=[ + go.Table( + columnwidth=[4, 1], + header={ + "values": ["Metric", "R\u00b2"], + "align": "left", + }, + cells={ + "values": [metric_names_list, r2_values], + "align": "left", + "fill_color": [fill_colors, fill_colors], + }, + ) + ] + ) + r2_card = create_plotly_analysis_card( + name=self.__class__.__name__, + title="Summary of model fits", + subtitle=( + "R\u00b2 (coefficient of determination) measures how well" + " the model predicts each metric. Higher values indicate" + " better model fit. Metrics with R\u00b2 >=" + f" {threshold} are highlighted in green." + ), + df=pd.DataFrame( + { + "Metric": metric_names_list, + "R\u00b2": list(self._r2s.values()), + } + ), + fig=r2_fig, + ) + cards.append(r2_card) + return self._create_analysis_card_group( title=CV_CARDGROUP_TITLE, subtitle=CV_CARDGROUP_SUBTITLE, diff --git a/ax/analysis/plotly/tests/test_cross_validation.py b/ax/analysis/plotly/tests/test_cross_validation.py index b7004f819e9..f186b469167 100644 --- a/ax/analysis/plotly/tests/test_cross_validation.py +++ b/ax/analysis/plotly/tests/test_cross_validation.py @@ -65,9 +65,12 @@ def test_compute(self, mock_r2: mock.Mock) -> None: ): analysis.compute() - (card,) = analysis.compute( + cards = analysis.compute( generation_strategy=self.client.generation_strategy ).flatten() + # Should have the CV plot card and the R2 summary card + self.assertEqual(len(cards), 2) + card = cards[0] self.assertEqual( card.name, "CrossValidationPlot", @@ -106,6 +109,15 @@ def test_compute(self, mock_r2: mock.Mock) -> None: ) self.assertIsNotNone(card.blob) + # Assert that _r2s is populated after compute + self.assertIn("bar", analysis._r2s) + self.assertAlmostEqual(analysis._r2s["bar"], 0.85) + + # Assert the R2 summary card + r2_card = cards[1] + self.assertEqual(r2_card.name, "CrossValidationPlot") + self.assertEqual(r2_card.title, "Summary of model fits") + # Assert that all arms are in the cross validation df # because trial index is not specified for t in self.client.experiment.trials.values(): @@ -121,9 +133,10 @@ def test_compute(self, mock_r2: mock.Mock) -> None: def test_it_can_specify_trial_index_correctly(self) -> None: analysis = CrossValidationPlot(metric_names=["bar"], trial_index=9) - (card,) = analysis.compute( + cards = analysis.compute( generation_strategy=self.client.generation_strategy ).flatten() + card = cards[0] for t in self.client.experiment.trials.values(): # Skip the last trial because the model was used to generate it # and therefore hasn't observed it @@ -159,15 +172,17 @@ def test_compute_adhoc(self, mock_r2: mock.Mock) -> None: cards = compute_cross_validation_adhoc( adapter=adapter, labels=metric_mapping ).flatten() - self.assertEqual(len(cards), 2) + self.assertEqual(len(cards), 3) titles = { "Cross Validation for spunky (R\u00b2 = 0.85)", "Cross Validation for foo2 (R\u00b2 = 0.85)", } - for card in cards: + for card in cards[:2]: self.assertEqual(card.name, "CrossValidationPlot") self.assertIn(card.title, titles) titles.remove(card.title) + # The last card is the R2 summary + self.assertEqual(cards[2].title, "Summary of model fits") @TestCase.ax_long_test( reason=(