|
8 | 8 |
|
9 | 9 | from ax.exceptions.core import OptimizationNotConfiguredError |
10 | 10 | from ax.service.orchestrator import OrchestratorOptions |
11 | | -from ax.utils.common.complexity_utils import summarize_ax_optimization_complexity |
| 11 | +from ax.utils.common.complexity_utils import ( |
| 12 | + ADVANCED_TIER_MESSAGE, |
| 13 | + format_tier_message, |
| 14 | + summarize_ax_optimization_complexity, |
| 15 | + UNSUPPORTED_TIER_MESSAGE, |
| 16 | + WHEELHOUSE_TIER_MESSAGE, |
| 17 | +) |
12 | 18 | from ax.utils.common.testutils import TestCase |
13 | 19 | from ax.utils.testing.core_stubs import ( |
14 | 20 | get_experiment, |
@@ -160,3 +166,78 @@ def test_parameter_constraints_counted(self) -> None: |
160 | 166 |
|
161 | 167 | # THEN num_parameter_constraints should be greater than 0 |
162 | 168 | self.assertGreater(summary["num_parameter_constraints"], 0) |
| 169 | + |
| 170 | + |
| 171 | +class TestFormatTierMessage(TestCase): |
| 172 | + """Tests for format_tier_message.""" |
| 173 | + |
| 174 | + def test_tier_messages(self) -> None: |
| 175 | + """Test formatting of tier messages for all tiers.""" |
| 176 | + test_cases: list[ |
| 177 | + tuple[ |
| 178 | + str, |
| 179 | + list[str] | None, |
| 180 | + list[str] | None, |
| 181 | + str, |
| 182 | + list[str], |
| 183 | + ] |
| 184 | + ] = [ |
| 185 | + ( |
| 186 | + "Wheelhouse", |
| 187 | + None, |
| 188 | + None, |
| 189 | + WHEELHOUSE_TIER_MESSAGE, |
| 190 | + ["tier 'Wheelhouse'"], |
| 191 | + ), |
| 192 | + ( |
| 193 | + "Advanced", |
| 194 | + ["51 tunable parameters", "Early stopping is enabled"], |
| 195 | + None, |
| 196 | + ADVANCED_TIER_MESSAGE, |
| 197 | + [ |
| 198 | + "tier 'Advanced'", |
| 199 | + "Why this experiment is not in the 'Wheelhouse' tier:", |
| 200 | + "51 tunable parameters", |
| 201 | + "Early stopping is enabled", |
| 202 | + ], |
| 203 | + ), |
| 204 | + ( |
| 205 | + "Unsupported", |
| 206 | + ["51 tunable parameters"], |
| 207 | + ["201 tunable parameters"], |
| 208 | + UNSUPPORTED_TIER_MESSAGE, |
| 209 | + [ |
| 210 | + "tier 'Unsupported'", |
| 211 | + "Why this experiment is not in the 'Wheelhouse' tier:", |
| 212 | + "51 tunable parameters", |
| 213 | + "Why this experiment is not in the 'Advanced' tier:", |
| 214 | + "201 tunable parameters", |
| 215 | + ], |
| 216 | + ), |
| 217 | + ] |
| 218 | + |
| 219 | + for ( |
| 220 | + tier, |
| 221 | + why_not_wheelhouse, |
| 222 | + why_not_supported, |
| 223 | + expected_message, |
| 224 | + expected_contents, |
| 225 | + ) in test_cases: |
| 226 | + with self.subTest(tier=tier): |
| 227 | + msg = format_tier_message( |
| 228 | + tier=tier, |
| 229 | + why_not_is_in_wheelhouse=why_not_wheelhouse, |
| 230 | + why_not_supported=why_not_supported, |
| 231 | + ) |
| 232 | + self.assertIn(expected_message, msg) |
| 233 | + for content in expected_contents: |
| 234 | + self.assertIn(content, msg) |
| 235 | + |
| 236 | + def test_unknown_tier_raises_error(self) -> None: |
| 237 | + """Test that unknown tier raises ValueError.""" |
| 238 | + with self.assertRaisesRegex(ValueError, 'Got unexpected tier "BadTier"'): |
| 239 | + format_tier_message( |
| 240 | + tier="BadTier", |
| 241 | + why_not_is_in_wheelhouse=None, |
| 242 | + why_not_supported=None, |
| 243 | + ) |
0 commit comments