|
16 | 16 | DIAGNOSTICS_CARDGROUP_TITLE, |
17 | 17 | ) |
18 | 18 | from ax.analysis.plotly.cross_validation import CrossValidationPlot |
| 19 | +from ax.analysis.plotly.metric_r2 import MetricR2AnalysisCard |
19 | 20 | from ax.api.client import Client |
20 | 21 | from ax.api.configs import RangeParameterConfig |
21 | 22 | from ax.core.analysis_card import ErrorAnalysisCard |
@@ -118,6 +119,13 @@ def test_compute(self) -> None: |
118 | 119 | for card in card_group.flatten(): |
119 | 120 | self.assertNotIsInstance(card, ErrorAnalysisCard) |
120 | 121 |
|
| 122 | + # Should have a MetricR2AnalysisCard with the expected title |
| 123 | + r2_cards = [ |
| 124 | + c for c in card_group.flatten() if isinstance(c, MetricR2AnalysisCard) |
| 125 | + ] |
| 126 | + self.assertEqual(len(r2_cards), 1) |
| 127 | + self.assertEqual(r2_cards[0].title, "Summary of model fits") |
| 128 | + |
121 | 129 | # --- Verify metric_names via patching CrossValidationPlot --- |
122 | 130 | original_cv_init: Callable[..., None] = CrossValidationPlot.__init__ |
123 | 131 |
|
@@ -163,6 +171,12 @@ def capturing_init(self: CrossValidationPlot, **kwargs: object) -> None: |
163 | 171 | self.assertIn("CrossValidationPlot", child_names_no_gs) |
164 | 172 | self.assertNotIn("GenerationStrategyGraph", child_names_no_gs) |
165 | 173 |
|
| 174 | + # MetricR2AnalysisCard not present when CV errors (no adapter available) |
| 175 | + r2_cards_no_gs = [ |
| 176 | + c for c in card_group_no_gs.flatten() if isinstance(c, MetricR2AnalysisCard) |
| 177 | + ] |
| 178 | + self.assertEqual(len(r2_cards_no_gs), 0) |
| 179 | + |
166 | 180 | def test_compute_bandit(self) -> None: |
167 | 181 | experiment = Experiment( |
168 | 182 | name="bandit_test", |
@@ -250,6 +264,12 @@ def test_compute_bandit(self) -> None: |
250 | 264 | # Bandit experiment should NOT include CrossValidationPlot |
251 | 265 | self.assertNotIn("CrossValidationPlot", child_names) |
252 | 266 |
|
| 267 | + # Bandit experiment should NOT include MetricR2AnalysisCard |
| 268 | + r2_cards = [ |
| 269 | + c for c in card_group.flatten() if isinstance(c, MetricR2AnalysisCard) |
| 270 | + ] |
| 271 | + self.assertEqual(len(r2_cards), 0) |
| 272 | + |
253 | 273 | # GenerationStrategyGraph should still be included (GS is provided) |
254 | 274 | self.assertIn("GenerationStrategyGraph", child_names) |
255 | 275 |
|
|
0 commit comments