|
7 | 7 |
|
8 | 8 | from unittest.mock import patch |
9 | 9 |
|
10 | | -from ax.analysis.plotly.surface.utils import get_fixed_values_for_slice_or_contour |
| 10 | +from ax.analysis.plotly.surface.utils import ( |
| 11 | + _get_best_trial_info, |
| 12 | + get_fixed_values_for_slice_or_contour, |
| 13 | +) |
11 | 14 | from ax.core.arm import Arm |
12 | 15 | from ax.service.ax_client import AxClient, ObjectiveProperties |
13 | 16 | from ax.utils.common.testutils import TestCase |
@@ -121,3 +124,50 @@ def test_returns_center_when_best_point_returns_none(self) -> None: |
121 | 124 |
|
122 | 125 | self.assertEqual(fixed_values, {"x": 0.0, "y": 0.0}) |
123 | 126 | self.assertEqual(description, "the center of the search space") |
| 127 | + |
| 128 | + def test_get_best_trial_info_with_batch_trial(self) -> None: |
| 129 | + """Test that _get_best_trial_info works with BatchTrial (PTS experiments), |
| 130 | + both when a matching arm is found and when no arm matches.""" |
| 131 | + client = self._create_client() |
| 132 | + experiment = client.experiment |
| 133 | + trial = experiment.new_batch_trial() |
| 134 | + trial.add_arms_and_weights( |
| 135 | + arms=[ |
| 136 | + Arm(parameters={"x": 0.1, "y": 0.2}, name="0_0"), |
| 137 | + Arm(parameters={"x": 0.3, "y": 0.4}, name="0_1"), |
| 138 | + Arm(parameters={"x": 0.5, "y": 0.6}, name="0_2"), |
| 139 | + ], |
| 140 | + ) |
| 141 | + |
| 142 | + # Subtest 1: matching arm found |
| 143 | + with self.subTest("matching_arm_found"): |
| 144 | + with patch( |
| 145 | + "ax.analysis.plotly.surface.utils." |
| 146 | + "get_best_parameters_from_model_predictions_with_trial_index", |
| 147 | + return_value=(0, {"x": 0.3, "y": 0.4}, None), |
| 148 | + ): |
| 149 | + result = _get_best_trial_info( |
| 150 | + experiment=experiment, |
| 151 | + generation_strategy=client.generation_strategy, |
| 152 | + ) |
| 153 | + |
| 154 | + self.assertIsNotNone(result) |
| 155 | + # pyre-ignore[16]: result is not None per assertion above |
| 156 | + parameterization, trial_index, arm_name = result |
| 157 | + self.assertEqual(parameterization, {"x": 0.3, "y": 0.4}) |
| 158 | + self.assertEqual(trial_index, 0) |
| 159 | + self.assertEqual(arm_name, "0_1") |
| 160 | + |
| 161 | + # Subtest 2: no matching arm returns None |
| 162 | + with self.subTest("no_matching_arm_returns_none"): |
| 163 | + with patch( |
| 164 | + "ax.analysis.plotly.surface.utils." |
| 165 | + "get_best_parameters_from_model_predictions_with_trial_index", |
| 166 | + return_value=(0, {"x": 0.9, "y": 0.9}, None), |
| 167 | + ): |
| 168 | + result = _get_best_trial_info( |
| 169 | + experiment=experiment, |
| 170 | + generation_strategy=client.generation_strategy, |
| 171 | + ) |
| 172 | + |
| 173 | + self.assertIsNone(result) |
0 commit comments