|
11 | 11 | import numpy as np
|
12 | 12 |
|
13 | 13 | from ax.core.types import TParamValue
|
| 14 | +from ax.exceptions.core import UnsupportedError |
14 | 15 | from ax.service.ax_client import AxClient, ObjectiveProperties
|
15 | 16 | from ax.telemetry.ax_client import AxClientCompletedRecord, AxClientCreatedRecord
|
16 | 17 | from ax.telemetry.experiment import ExperimentCompletedRecord, ExperimentCreatedRecord
|
@@ -130,6 +131,25 @@ def test_ax_client_completed_record_from_ax_client(self) -> None:
|
130 | 131 | )
|
131 | 132 | self._compare_axclient_completed_records(record, expected)
|
132 | 133 |
|
| 134 | + def test_batch_trial_warning(self) -> None: |
| 135 | + ax_client = AxClient() |
| 136 | + error_msg = ( |
| 137 | + "AxClient API does not support batch trials yet." |
| 138 | + " We plan to add this support in coming versions." |
| 139 | + ) |
| 140 | + with self.assertRaisesRegex(UnsupportedError, error_msg): |
| 141 | + ax_client.create_experiment( |
| 142 | + name="test_experiment", |
| 143 | + parameters=[ |
| 144 | + {"name": "x", "type": "range", "bounds": [-5.0, 10.0]}, |
| 145 | + ], |
| 146 | + objectives={"branin": ObjectiveProperties(minimize=True)}, |
| 147 | + is_test=True, |
| 148 | + choose_generation_strategy_kwargs={ |
| 149 | + "use_batch_trials": True, |
| 150 | + }, |
| 151 | + ) |
| 152 | + |
133 | 153 | def _compare_axclient_completed_records(
|
134 | 154 | self, record: AxClientCompletedRecord, expected: AxClientCompletedRecord
|
135 | 155 | ) -> None:
|
|
0 commit comments