|
13 | 13 | )
|
14 | 14 | from ax.benchmark.methods.sobol import get_sobol_generation_strategy
|
15 | 15 | from ax.core.experiment import Experiment
|
| 16 | +from ax.modelbridge.factory import get_sobol |
16 | 17 | from ax.utils.common.testutils import TestCase
|
| 18 | +from ax.utils.testing.core_stubs import get_experiment_with_observations |
17 | 19 | from pyre_extensions import none_throws
|
18 | 20 |
|
19 | 21 |
|
@@ -71,8 +73,36 @@ def test_get_best_parameters(self) -> None:
|
71 | 73 | experiment=experiment, optimization_config=soo_config, n_points=2
|
72 | 74 | )
|
73 | 75 |
|
74 |
| - with self.subTest("Empty experiment"): |
75 |
| - result = method.get_best_parameters( |
| 76 | + with self.subTest("Empty experiment"), self.assertRaisesRegex( |
| 77 | + ValueError, "Cannot identify a best point if experiment has no trials" |
| 78 | + ): |
| 79 | + method.get_best_parameters( |
76 | 80 | experiment=experiment, optimization_config=soo_config, n_points=1
|
77 | 81 | )
|
78 |
| - self.assertEqual(result, []) |
| 82 | + |
| 83 | + with self.subTest("All constraints violated"): |
| 84 | + experiment = get_experiment_with_observations( |
| 85 | + observations=[[1, -1], [2, -1]], |
| 86 | + constrained=True, |
| 87 | + ) |
| 88 | + best_point = method.get_best_parameters( |
| 89 | + n_points=1, |
| 90 | + experiment=experiment, |
| 91 | + optimization_config=none_throws(experiment.optimization_config), |
| 92 | + ) |
| 93 | + self.assertEqual(len(best_point), 1) |
| 94 | + self.assertEqual(best_point[0], experiment.trials[1].arms[0].parameters) |
| 95 | + |
| 96 | + with self.subTest("No completed trials"): |
| 97 | + experiment = get_experiment_with_observations(observations=[]) |
| 98 | + sobol_generator = get_sobol(search_space=experiment.search_space) |
| 99 | + for _ in range(3): |
| 100 | + trial = experiment.new_trial(generator_run=sobol_generator.gen(n=1)) |
| 101 | + trial.run() |
| 102 | + best_point = method.get_best_parameters( |
| 103 | + n_points=1, |
| 104 | + experiment=experiment, |
| 105 | + optimization_config=none_throws(experiment.optimization_config), |
| 106 | + ) |
| 107 | + self.assertEqual(len(best_point), 1) |
| 108 | + self.assertEqual(best_point[0], experiment.trials[2].arms[0].parameters) |
0 commit comments